/* Train a decision tree (MATLAB MEX file). */

/* Training takes a dataset X and labels y. To keep the implementation
 * simple, the number of classes K must be passed as well (not every class
 * has to occur in y). The labels in y must be the integers 1,2,...,K.
 * The size F of the feature subset tried at each node must also be given
 * (F <= 0 means: use all features). */
---|
| 7 | |
---|
/* The data matrix X must be n x d, where n is the number of objects and
 * d is the number of features (dimensions). */
---|
| 10 | |
---|
| 11 | #include <stdlib.h> |
---|
| 12 | #include <stdio.h> |
---|
| 13 | #include <mex.h> |
---|
| 14 | |
---|
/* One node of a binary decision tree.
 * A node with class < 0 is an internal split: an object whose value of
 * feature `feat` is below `thres` continues in `left`, any other object in
 * `right`. A node with class >= 0 is a leaf predicting that class. */
typedef struct dtree dtree;
struct dtree {
    int class;            /* predicted class, or -1 for a split node */
    int feat;             /* index of the feature tested at a split */
    double thres;         /* split threshold */
    struct dtree *left;   /* subtree for values below thres */
    struct dtree *right;  /* subtree for all other values */
};
---|
| 22 | |
---|
/* Record sorted by qsort/compare_objs: a sort key plus the data that has
 * to travel with it through the sort. */
typedef struct obj obj;
struct obj {
    double val;   /* sort key (a feature value, or a random key) */
    int class;    /* class label of the object */
    int idx;      /* index carried along (object row or feature number) */
};
---|
| 29 | |
---|
/* File-scope state shared by the tree routines (global variables are not
 * pretty, but they are what this file uses). mexFunction fills them in
 * before training or evaluating. */
double *x;          /* data matrix: element (obj, feat) is x[obj + feat*N] */
double *y;          /* class labels, one per object */
size_t N;           /* number of objects (rows of x) */
size_t D;           /* number of features (columns of x) */
int K;              /* number of classes */
int F;              /* size of the feature subset tried per node */
int nrnodes;        /* number of nodes in the tree */
double *tmp_p;      /* scratch class counts used by gini (K entries) */
int storeindex;     /* next free column while tree_encode runs */
---|
| 38 | |
---|
| 39 | /* All the stuff for sorting */ |
---|
| 40 | int compare_objs(const void *a,const void *b) |
---|
| 41 | { |
---|
| 42 | obj *obj_a = (obj *)a; |
---|
| 43 | obj *obj_b = (obj *)b; |
---|
| 44 | |
---|
| 45 | if (obj_a->val > obj_b->val) |
---|
| 46 | return 1; |
---|
| 47 | else |
---|
| 48 | return -1; |
---|
| 49 | } |
---|
/* qsort comparator for an array of doubles: ascending order.
 * Compares the pointed-to values (not the pointers themselves) and
 * returns a negative value, 0, or a positive value as *a is less than,
 * equal to, or greater than *b. */
int compare_doubles(const void *a, const void *b)
{
    const double da = *(const double *)a;
    const double db = *(const double *)b;

    return (da > db) - (da < db);
}
---|
| 60 | |
---|
| 61 | /* Gini */ |
---|
| 62 | double gini(int *I) |
---|
| 63 | { |
---|
| 64 | int i; |
---|
| 65 | double out; |
---|
| 66 | |
---|
| 67 | /* initialize to zero */ |
---|
| 68 | for (i=0;i<K;i++) |
---|
| 69 | tmp_p[i] = 0; |
---|
| 70 | /* count the occurance of each class */ |
---|
| 71 | /* index vector starts at 1, the class numbering as well... */ |
---|
| 72 | for (i=0;i<I[0];i++) |
---|
| 73 | tmp_p[(int)y[I[i+1]]-1] +=1; |
---|
| 74 | /* normalize and compute gini */ |
---|
| 75 | out = 0; |
---|
| 76 | for (i=0;i<K;i++) |
---|
| 77 | out = out + tmp_p[i]*(1-tmp_p[i]/I[0])/I[0]; |
---|
| 78 | |
---|
| 79 | return out; |
---|
| 80 | } |
---|
| 81 | |
---|
| 82 | /* make a tree */ |
---|
| 83 | dtree *tree_train(int *I) |
---|
| 84 | { |
---|
| 85 | double err,besterr; |
---|
| 86 | dtree *out; |
---|
| 87 | int *fss; |
---|
| 88 | obj *tmp; |
---|
| 89 | int i,j,k; |
---|
| 90 | int bestsplit; |
---|
| 91 | int *Ileft, *Iright, *Ileftbest, *Irightbest; |
---|
| 92 | |
---|
| 93 | /* make the node */ |
---|
| 94 | out = (dtree *)malloc(sizeof(dtree)); |
---|
| 95 | nrnodes +=1; |
---|
| 96 | /* printf("Make NODE %d!\n",nrnodes); */ |
---|
| 97 | /* printf("%d objects in this node\n",I[0]); */ |
---|
| 98 | |
---|
| 99 | /* is it good enough? */ |
---|
| 100 | err = gini(I); |
---|
| 101 | /* printf(" gini = %f\n",err); */ |
---|
| 102 | if (err==0) |
---|
| 103 | { |
---|
| 104 | /* leave is perfectly classified: return this */ |
---|
| 105 | out->class = y[I[1]]; |
---|
| 106 | out->feat = 0; |
---|
| 107 | out->thres = 0; |
---|
| 108 | out->left = NULL; |
---|
| 109 | out->right = NULL; |
---|
| 110 | /* printf(" Node %d is leaf. Done\n",nrnodes); */ |
---|
| 111 | return out; |
---|
| 112 | } |
---|
| 113 | else |
---|
| 114 | { |
---|
| 115 | /* store illegal class number to show it is a branch */ |
---|
| 116 | out->class = -1; |
---|
| 117 | /* what features to use? */ |
---|
| 118 | fss = (int *)malloc(D*sizeof(int)); |
---|
| 119 | if (F>0) { |
---|
| 120 | /* randomly permute feature indices */ |
---|
| 121 | tmp = (obj *)malloc(D*sizeof(obj)); |
---|
| 122 | for (i=0;i<D;i++) |
---|
| 123 | { |
---|
| 124 | tmp[i].val = random(); |
---|
| 125 | tmp[i].idx = i; |
---|
| 126 | /* printf("tmp[%d]=%f,%d\n",i,tmp[i].val,tmp[i].idx); */ |
---|
| 127 | } |
---|
| 128 | qsort(tmp,D,sizeof(tmp[0]),compare_objs); |
---|
| 129 | for (i=0;i<D;i++) |
---|
| 130 | { |
---|
| 131 | fss[i] = tmp[i].idx; |
---|
| 132 | /* printf("fss[%d]=%d\n",i,fss[i]); */ |
---|
| 133 | } |
---|
[25] | 134 | free(tmp); |
---|
[23] | 135 | } |
---|
| 136 | else |
---|
[25] | 137 | { |
---|
[23] | 138 | for (i=0;i<D;i++) |
---|
[25] | 139 | { |
---|
[23] | 140 | fss[i] = i; |
---|
[25] | 141 | /* printf("fss[%d]=%d\n",i,fss[i]); */ |
---|
| 142 | } |
---|
| 143 | F = D; |
---|
| 144 | } |
---|
[23] | 145 | |
---|
| 146 | /* check each feature separately: */ |
---|
| 147 | besterr = 1e100; |
---|
| 148 | tmp = (obj *)malloc(I[0]*sizeof(obj)); |
---|
| 149 | Ileft = (int *)malloc((I[0]+1)*sizeof(int)); |
---|
| 150 | Iright = (int *)malloc((I[0]+1)*sizeof(int)); |
---|
| 151 | Ileftbest = (int *)malloc((I[0]+1)*sizeof(int)); |
---|
| 152 | Irightbest = (int *)malloc((I[0]+1)*sizeof(int)); |
---|
| 153 | for (i=0;i<F;i++) { |
---|
| 154 | /* printf("Try feature %d:\n",fss[i]); */ |
---|
| 155 | /* sort the data along feature fss[i] */ |
---|
| 156 | for (j=0;j<I[0];j++){ |
---|
| 157 | tmp[j].val = x[fss[i]*N+I[j+1]]; |
---|
| 158 | tmp[j].class = y[j]; |
---|
| 159 | tmp[j].idx = I[j+1]; |
---|
| 160 | /* printf(" tmp[%d] = %f, idx=%d\n",j,tmp[j].val,tmp[j].idx); */ |
---|
| 161 | } |
---|
| 162 | qsort((void *)tmp,I[0],sizeof(tmp[0]),compare_objs); |
---|
[25] | 163 | /* for (j=0;j<I[0];j++) printf(" -> tmp[%d] = %f, idx=%d\n",j,tmp[j].val,tmp[j].idx); */ |
---|
[23] | 164 | /* make indices for the split */ |
---|
| 165 | for (j=0;j<I[0];j++) |
---|
| 166 | { |
---|
| 167 | Ileft[j+1] = tmp[j].idx; |
---|
| 168 | Iright[I[0]-j] = tmp[j].idx; |
---|
| 169 | } |
---|
| 170 | /* for (k=1;k<=I[0];k++) printf(" Ileft[%d] = %d \n",k,Ileft[k]); */ |
---|
| 171 | /* for (k=1;k<=I[0];k++) printf(" Iright[%d] = %d \n",k,Iright[k]); */ |
---|
| 172 | /* if (nrnodes==3) return out; */ |
---|
| 173 | /* run over all possible splits */ |
---|
| 174 | for (j=1;j<I[0];j++) |
---|
| 175 | { |
---|
| 176 | /* printf(" split %d ",j); */ |
---|
| 177 | Ileft[0]=j; /* redefine the length of vector Ileft */ |
---|
| 178 | /* for (k=1;k<=j;k++) printf(" Il[%d] = %d ",k,Ileft[k]); */ |
---|
| 179 | /* printf(" -> gini left = %f\n",gini(Ileft)); */ |
---|
| 180 | Iright[0]=I[0]-j; |
---|
| 181 | /* for (k=1;k<=I[0]-j;k++) printf(" Ir[%d] = %d ",k,Iright[k]); */ |
---|
| 182 | /* printf(" gini right = %f\n",gini(Iright)); */ |
---|
| 183 | err = j*gini(Ileft) + (I[0]-j)*gini(Iright); |
---|
| 184 | /* printf(" give err %f\n",err); */ |
---|
| 185 | /* is this good? */ |
---|
| 186 | if (err<besterr) { |
---|
| 187 | /* printf(" We have a better result! (%f<%f)\n",err,besterr); */ |
---|
| 188 | /* printf(" Feature %d at %d ",fss[i],j); */ |
---|
| 189 | besterr = err; |
---|
| 190 | bestsplit = j; |
---|
| 191 | out->feat = fss[i]; |
---|
| 192 | out->thres = (tmp[j].val + tmp[j-1].val)/2; |
---|
| 193 | /* printf(" thres = %f\n",out->thres); */ |
---|
| 194 | for (k=0;k<=j;k++) |
---|
| 195 | Ileftbest[k] = Ileft[k]; |
---|
| 196 | Ileftbest[0] = j; |
---|
| 197 | for (k=0;k<=I[0]-j;k++) |
---|
| 198 | Irightbest[k] = Iright[k]; |
---|
| 199 | Irightbest[0] = I[0]-j; |
---|
| 200 | } |
---|
| 201 | |
---|
| 202 | } |
---|
| 203 | |
---|
| 204 | } |
---|
[25] | 205 | /* printf("Finally, we use feature %d on split %d, threshold %f\n", |
---|
| 206 | out->feat,bestsplit,out->thres); */ |
---|
| 207 | /*printf("Left objects:\n"); |
---|
[23] | 208 | for (k=1;k<=Ileftbest[0];k++) |
---|
| 209 | printf(" Ileft[%d] = %d\n",k,Ileftbest[k]); |
---|
| 210 | printf("Right objects:\n"); |
---|
| 211 | for (k=1;k<=Irightbest[0];k++) |
---|
| 212 | printf(" Iright[%d] = %d\n",k,Irightbest[k]);*/ |
---|
| 213 | |
---|
| 214 | /* now find the children */ |
---|
| 215 | out->left = tree_train(Ileftbest); |
---|
| 216 | out->right = tree_train(Irightbest); |
---|
| 217 | |
---|
| 218 | |
---|
| 219 | free(Ileft); |
---|
| 220 | free(Iright); |
---|
| 221 | free(Ileftbest); |
---|
| 222 | free(Irightbest); |
---|
| 223 | free(tmp); |
---|
| 224 | free(fss); |
---|
| 225 | } |
---|
| 226 | return out; |
---|
| 227 | } |
---|
| 228 | |
---|
| 229 | /* Store the tree in a matrix */ |
---|
| 230 | /* Order of the variables: |
---|
| 231 | * 1. class |
---|
| 232 | * 2. feature |
---|
| 233 | * 3. threshold |
---|
| 234 | * 4. left branch index |
---|
| 235 | * 5. right branch index */ |
---|
| 236 | void tree_encode(dtree *tree,double *ptr) |
---|
| 237 | { |
---|
| 238 | int thisindex = storeindex; |
---|
| 239 | |
---|
| 240 | /* printf("Store %d \n",thisindex); */ |
---|
| 241 | |
---|
| 242 | if (tree->class<0) /* it is branching */ |
---|
| 243 | { |
---|
| 244 | /* printf(" : split feat %d at %f\n",tree->feat,tree->thres); */ |
---|
| 245 | *(ptr+5*thisindex) = -1; /* encode splitting */ |
---|
| 246 | *(ptr+5*thisindex+1) = tree->feat; |
---|
| 247 | *(ptr+5*thisindex+2) = tree->thres; |
---|
| 248 | storeindex += 1; |
---|
| 249 | *(ptr+5*thisindex+3) = storeindex+1; /* Matlab indexing...*/ |
---|
| 250 | tree_encode(tree->left,ptr); |
---|
| 251 | storeindex += 1; |
---|
| 252 | *(ptr+5*thisindex+4) = storeindex+1; |
---|
| 253 | tree_encode(tree->right,ptr); |
---|
| 254 | } |
---|
| 255 | else |
---|
| 256 | { |
---|
| 257 | /* printf(" : class %d\n",tree->class); */ |
---|
| 258 | *(ptr+5*thisindex) = tree->class; |
---|
| 259 | *(ptr+5*thisindex+1) = 0; |
---|
| 260 | *(ptr+5*thisindex+2) = 0; |
---|
| 261 | *(ptr+5*thisindex+3) = 0; |
---|
| 262 | *(ptr+5*thisindex+4) = 0; |
---|
| 263 | } |
---|
| 264 | } |
---|
| 265 | |
---|
| 266 | void destroy_tree(dtree *tree) |
---|
| 267 | { |
---|
| 268 | if (tree->class<0) |
---|
| 269 | { |
---|
| 270 | destroy_tree(tree->left); |
---|
| 271 | destroy_tree(tree->right); |
---|
| 272 | free(tree); |
---|
| 273 | /* printf("removed branch\n"); */ |
---|
| 274 | } |
---|
| 275 | else |
---|
| 276 | { |
---|
| 277 | free(tree); |
---|
| 278 | /* printf("removed leave\n"); */ |
---|
| 279 | } |
---|
| 280 | } |
---|
| 281 | |
---|
| 282 | int classify_data(double *T, int idx, int obj) |
---|
| 283 | { |
---|
| 284 | int k; |
---|
| 285 | int feat; |
---|
| 286 | double thres; |
---|
| 287 | |
---|
| 288 | /* printf("Obj x(%d): [",obj); |
---|
| 289 | for (k=0;k<D;k++) |
---|
| 290 | printf("%f, ",x[obj+k*N]); |
---|
| 291 | printf("]\n"); */ |
---|
| 292 | |
---|
| 293 | if (*(T+5*idx)<0) /* branching */ |
---|
| 294 | { |
---|
| 295 | feat = (int)(*(T+5*idx+1)); |
---|
| 296 | thres = *(T+5*idx+2); |
---|
| 297 | /* printf(" is x[%d]=%f < %f? (x=%f)\n",feat, x[obj+feat*N],thres); */ |
---|
| 298 | if (x[obj+feat*N]<thres) |
---|
| 299 | { |
---|
| 300 | /* printf("left branch\n"); */ |
---|
| 301 | return classify_data(T,*(T+5*idx+3)-1,obj); |
---|
| 302 | } |
---|
| 303 | else |
---|
| 304 | { |
---|
| 305 | /* printf("right branch\n"); */ |
---|
| 306 | return classify_data(T,*(T+5*idx+4)-1,obj); |
---|
| 307 | } |
---|
| 308 | } |
---|
| 309 | else |
---|
| 310 | return *(T+5*idx); |
---|
| 311 | } |
---|
| 312 | |
---|
| 313 | |
---|
| 314 | |
---|
| 315 | /* GO! */ |
---|
| 316 | void mexFunction(int nlhs, mxArray *plhs[], |
---|
| 317 | int nrhs, const mxArray *prhs[]) |
---|
| 318 | { |
---|
| 319 | int i; |
---|
| 320 | int *I; /* index vector */ |
---|
| 321 | dtree *tree; |
---|
| 322 | double *T; |
---|
| 323 | double *ptr; |
---|
| 324 | |
---|
| 325 | /* Four inputs: train the tree */ |
---|
| 326 | /* We require four inputs, x,y, K and F */ |
---|
| 327 | if (nrhs==4) { |
---|
| 328 | /* Get the input and check stuff */ |
---|
[25] | 329 | /* printf("get input and check\n"); */ |
---|
[23] | 330 | x = mxGetPr(prhs[0]); |
---|
| 331 | N = mxGetM(prhs[0]); |
---|
| 332 | D = mxGetN(prhs[0]); |
---|
| 333 | y = mxGetPr(prhs[1]); |
---|
| 334 | if (mxGetM(prhs[1])!=N) { |
---|
| 335 | printf("ERROR: Size of Y does not fit with X.\n"); |
---|
| 336 | return; |
---|
| 337 | } |
---|
| 338 | K = (int)(mxGetPr(prhs[2])[0]); |
---|
| 339 | F = (int)(mxGetPr(prhs[3])[0]); |
---|
[25] | 340 | /* printf("N=%d, D=%d, K=%d, F=%d\n",N,D,K,F); */ |
---|
[23] | 341 | /* allocate */ |
---|
| 342 | tmp_p = (double *)malloc(K*sizeof(double)); /* for gini */ |
---|
| 343 | |
---|
| 344 | /* start the tree with all data: */ |
---|
| 345 | I = (int *)malloc((N+1)*sizeof(int)); |
---|
| 346 | I[0] = N; |
---|
| 347 | for (i=0;i<N;i++) I[i+1] = i; |
---|
| 348 | |
---|
| 349 | /* make the tree */ |
---|
| 350 | nrnodes = 0; |
---|
| 351 | tree = tree_train(I); |
---|
| 352 | |
---|
| 353 | /* store results */ |
---|
| 354 | /* printf("\n\n\nStore results, of %d nodes\n",nrnodes); */ |
---|
| 355 | plhs[0] = mxCreateNumericMatrix(5,nrnodes,mxDOUBLE_CLASS,mxREAL); |
---|
| 356 | storeindex = 0; |
---|
| 357 | tree_encode(tree,mxGetPr(plhs[0])); |
---|
| 358 | |
---|
| 359 | /* clean up */ |
---|
| 360 | destroy_tree(tree); |
---|
| 361 | free(I); |
---|
| 362 | free(tmp_p); |
---|
| 363 | } |
---|
| 364 | /* Two inputs: evaluate the tree */ |
---|
| 365 | /* We require the encoded tree T and inputs x */ |
---|
| 366 | else if (nrhs==2) { |
---|
| 367 | T = mxGetPr(prhs[0]); |
---|
| 368 | nrnodes = mxGetN(prhs[0]); |
---|
| 369 | x = mxGetPr(prhs[1]); |
---|
| 370 | N = mxGetM(prhs[1]); |
---|
| 371 | D = mxGetN(prhs[1]); |
---|
| 372 | |
---|
| 373 | plhs[0] = mxCreateNumericMatrix(N,1,mxDOUBLE_CLASS,mxREAL); |
---|
| 374 | ptr = mxGetPr(plhs[0]); |
---|
| 375 | for (i=0;i<N;i++) *(ptr+i) = classify_data(T,0,i); |
---|
| 376 | } |
---|
| 377 | else |
---|
| 378 | { |
---|
| 379 | printf("ERROR: only 2 or 4 inputs allowed!\n"); |
---|
| 380 | return; |
---|
| 381 | } |
---|
| 382 | } |
---|
| 383 | |
---|
| 384 | |
---|