/* try to train a decision tree */
/* I assume I have a dataset X, and labels y. Due to implementation
 * simplification, I ask the number of classes as well (it may be that
 * not all classes are available in y). I assume that y contains
 * integers, from 1,2,...,K. Also the size of the feature subset F
 * should be given. */
/* The data matrix should be n x d, where n is the number of objects, and
 * d is the number of dimensions. */
/* NOTE(review): this copy of the file is damaged. Line breaks were lost,
 * and text between a '<' and the next '>' appears to have been stripped:
 * the #include header names below are gone, and so are spans such as the
 * loop bound in "for (i=0;i" inside gini and tree_train. The visible code
 * needs at least <stdlib.h> (malloc, free, NULL, size_t). The damaged
 * text is kept verbatim; restore it from the original source. */
#include #include #include
/* first the tree structure */
/* A branch node has class == -1 (set in tree_train; tree_encode and
 * destroy_tree test class < 0) and splits on feature `feat` at threshold
 * `thres`; a leaf stores the predicted label in `class`. */
typedef struct dtree {
    int class;                  /* predicted class */
    int feat;                   /* feature to split */
    double thres;               /* threshold */
    struct dtree *left, *right; /* children */
} dtree;
/* then the data structure for sorting */
/* A sortable (val, class, idx) record, ordered by val via compare_objs.
 * In tree_train it serves both to randomly permute the feature indices
 * and to hold objects whose val fields give the split thresholds. */
typedef struct obj {
    double val; /* sort key */
    int class;  /* presumably the object's label */
    int idx;    /* index of the object or feature this entry stands for */
} obj;
/* general variables (ok, global vars are bad, ok) */
double *x;      /* data */
double *y;      /* labels (integers 1..K per the header, stored as double) */
size_t N,D;     /* nr objs, nr feats*/
int K,F;        /* nr classes, nr subspaces (F: feature-subset size) */
int nrnodes;    /* the size of the tree */
double *tmp_p;  /* for gini */
int storeindex;/* for storing tree: current record index in tree_encode */
/* All the stuff for sorting */
/* qsort comparator for obj, ascending by val.
 * NOTE(review): this never returns 0, so for two objects with equal val
 * each one compares as "smaller" than the other. qsort requires the
 * comparator to define a consistent total ordering (C11 7.22.5), so ties
 * make the behaviour undefined. Consider
 * return (obj_a->val > obj_b->val) - (obj_a->val < obj_b->val); */
int compare_objs(const void *a,const void *b)
{
    obj *obj_a = (obj *)a;
    obj *obj_b = (obj *)b;
    if (obj_a->val > obj_b->val) return 1;
    else return -1;
}
/* Intended as a qsort comparator for doubles.
 * NOTE(review): this compares the pointers obj_a and obj_b (the array
 * positions), not the doubles they point to; it presumably should compare
 * the dereferenced values. Like compare_objs it never returns 0. No caller
 * is visible in this chunk. */
int compare_doubles(const void *a,const void *b)
{
    double *obj_a = (double *)a;
    double *obj_b = (double *)b;
    if (obj_a > obj_b) return 1;
    else return -1;
}
/* Gini */
/* Presumably the Gini impurity of the labels of the objects listed in I;
 * tree_train weights it by set size to score a split
 * (err = j*gini(Ileft) + (I[0]-j)*gini(Iright)). Index lists in this file
 * store the count in I[0] and object indices in I[1..I[0]].
 * NOTE(review): the body of gini and the beginning of tree_train
 * (including its signature) are missing from this copy; everything from
 * "for (i=0;i" to the end of tree_train is kept verbatim. */
double gini(int *I) { int i; double out; /* initialize to zero */ for (i=0;iclass = y[I[1]]; out->feat = 0; out->thres = 0; out->left = NULL; out->right = NULL; /* printf(" Node %d is leaf. Done\n",nrnodes); */ return out; } else { /* store illegal class number to show it is a branch */ out->class = -1; /* what features to use? 
*/ fss = (int *)malloc(D*sizeof(int)); if (F>0) { /* randomly permute feature indices */ tmp = (obj *)malloc(D*sizeof(obj)); for (i=0;i tmp[%d] = %f, idx=%d\n",j,tmp[j].val,tmp[j].idx); */ /* make indices for the split */ for (j=0;j gini left = %f\n",gini(Ileft)); */ Iright[0]=I[0]-j; /* for (k=1;k<=I[0]-j;k++) printf(" Ir[%d] = %d ",k,Iright[k]); */ /* printf(" gini right = %f\n",gini(Iright)); */ err = j*gini(Ileft) + (I[0]-j)*gini(Iright); /* printf(" give err %f\n",err); */ /* is this good? */ /* NOTE(review): text between "if (err" and "feat = fss[i]" is missing in this copy */ if (errfeat = fss[i]; out->thres = (tmp[j].val + tmp[j-1].val)/2; /* printf(" thres = %f\n",out->thres); */ for (k=0;k<=j;k++) Ileftbest[k] = Ileft[k]; Ileftbest[0] = j; for (k=0;k<=I[0]-j;k++) Irightbest[k] = Iright[k]; Irightbest[0] = I[0]-j; } } } /* printf("Finally, we use feature %d on split %d, threshold %f\n", out->feat,bestsplit,out->thres); */ /*printf("Left objects:\n"); for (k=1;k<=Ileftbest[0];k++) printf(" Ileft[%d] = %d\n",k,Ileftbest[k]); printf("Right objects:\n"); for (k=1;k<=Irightbest[0];k++) printf(" Iright[%d] = %d\n",k,Irightbest[k]);*/ /* now find the children */ out->left = tree_train(Ileftbest); out->right = tree_train(Irightbest); free(Ileft); free(Iright); free(Ileftbest); free(Irightbest); free(tmp); free(fss); } return out; } /* Store the tree in a matrix */ /* Order of the variables: * 1. class * 2. feature * 3. threshold * 4. left branch index * 5. 
right branch index
 * Child indices (fields 4 and 5) are stored 1-based, Matlab style; a leaf
 * stores its class in field 1 and 0 in fields 2-5, while a branch stores
 * -1 in field 1. */

/* Serialize the subtree rooted at `tree` into `ptr`, 5 doubles per node,
 * using the field order above.
 *
 * The node is written at record `storeindex` (a global), i.e. at
 * ptr[5*storeindex] .. ptr[5*storeindex+4]. Records are numbered in
 * pre-order: a branch bumps storeindex before encoding each child, so its
 * left child is the next record and its right child follows the last
 * record of the left subtree. A leaf does not bump storeindex, so on
 * return storeindex is the index of the last record written.
 *
 * The feature index is stored as-is (it is not shifted to 1-based like the
 * child indices). Branches (class < 0) must have non-NULL children.
 * NOTE(review): the caller presumably sets storeindex = 0 first and sizes
 * ptr for 5*nrnodes doubles; the call site is not visible here. */
void tree_encode(dtree *tree,double *ptr)
{
    int thisindex = storeindex;
    /* printf("Store %d \n",thisindex); */
    if (tree->class<0) /* it is branching */
    {
        /* printf(" : split feat %d at %f\n",tree->feat,tree->thres); */
        *(ptr+5*thisindex) = -1; /* encode splitting */
        *(ptr+5*thisindex+1) = tree->feat;
        *(ptr+5*thisindex+2) = tree->thres;
        /* the left child is the next record */
        storeindex += 1;
        *(ptr+5*thisindex+3) = storeindex+1; /* Matlab indexing...*/
        tree_encode(tree->left,ptr);
        /* storeindex now points at the last record of the left subtree;
         * the right child takes the record after it */
        storeindex += 1;
        *(ptr+5*thisindex+4) = storeindex+1;
        tree_encode(tree->right,ptr);
    }
    else
    {
        /* printf(" : class %d\n",tree->class); */
        /* leaf: only the class is meaningful, the other fields are 0 */
        *(ptr+5*thisindex) = tree->class;
        *(ptr+5*thisindex+1) = 0;
        *(ptr+5*thisindex+2) = 0;
        *(ptr+5*thisindex+3) = 0;
        *(ptr+5*thisindex+4) = 0;
    }
}

/* Free the subtree rooted at `tree`, children before their parent.
 * `tree` must not be NULL, and branch nodes (class < 0) are expected to
 * have non-NULL left and right children. */
void destroy_tree(dtree *tree)
{
    if (tree->class<0)
    {
        destroy_tree(tree->left);
        destroy_tree(tree->right);
        free(tree);
        /* printf("removed branch\n"); */
    }
    else
    {
        free(tree);
        /* printf("removed leave\n"); */
    }
}

int classify_data(double *T, int idx, int obj) { int k; int feat; double thres; /* printf("Obj x(%d): [",obj); for (k=0;k