source: prextra/decisiontree.c @ 24

Last change on this file since 24 was 23, checked in by dtax, 13 years ago

The decision tree and random forest, with compiled code!

File size: 8.9 KB
Line 
1/* try to train a decision tree */
2/* I assume I have a dataset X, and labels y. Due to implementation
3 * simplification, I ask the number of classes as well (it may be that
4 * not all classes are available in y). I assume that y contains
5 * integers, from 1,2,...,K.  Also the size of the feature subset F
6 * should be given. */
7
8/* The data matrix should be nxd, where n is the number of objects, and
9 * d is the number of dimensions. */
10
11#include <stdlib.h>
12#include <stdio.h>
13#include <mex.h>
14
15/* first the tree structure */
16typedef struct dtree {
17        int class;                    /* predicted class */
18        int feat;                     /* feature to split */
19        double thres;                 /* threshold */
20        struct dtree *left, *right;   /* children */
21} dtree;
22
23/* then the data structure for sorting */
24typedef struct obj {
25        double val;
26        int class;
27        int idx;
28} obj;
29
30/* general variables (ok, global vars are bad, ok) */
31double *x;     /* data */
32double *y;     /* labels */
33size_t N,D;    /* nr objs, nr feats*/
34int K,F;       /* nr classes, nr subspaces */
35int nrnodes;   /* the size of the tree */
36double *tmp_p; /* for gini */
37int storeindex;/* for storing tree */
38
39/* All the stuff for sorting */
40int compare_objs(const void *a,const void *b)
41{
42        obj *obj_a = (obj *)a;
43        obj *obj_b = (obj *)b;
44
45        if (obj_a->val > obj_b->val)
46                return 1;
47        else
48                return -1;
49}
50int compare_doubles(const void *a,const void *b)
51{
52        double *obj_a = (double *)a;
53        double *obj_b = (double *)b;
54
55        if (obj_a > obj_b)
56                return 1;
57        else
58                return -1;
59}
60
61/* Gini */
62double gini(int *I)
63{
64        int i;
65        double out;
66
67        /* initialize to zero */
68        for (i=0;i<K;i++)
69                tmp_p[i] = 0;
70        /* count the occurance of each class */
71        /* index vector starts at 1, the class numbering as well... */
72        for (i=0;i<I[0];i++)
73                tmp_p[(int)y[I[i+1]]-1] +=1;
74        /* normalize and compute gini */
75        out = 0;
76        for (i=0;i<K;i++)
77                out = out + tmp_p[i]*(1-tmp_p[i]/I[0])/I[0];
78
79        return out;
80}
81
82/* make a tree */
83dtree *tree_train(int *I)
84{
85        double err,besterr;
86        dtree *out;
87        int *fss;
88        obj *tmp;
89        int i,j,k;
90        int bestsplit;
91        int *Ileft, *Iright, *Ileftbest, *Irightbest;
92
93        /* make the node */
94        out = (dtree *)malloc(sizeof(dtree));
95        nrnodes +=1;
96/* printf("Make NODE %d!\n",nrnodes); */
97/* printf("%d objects in this node\n",I[0]); */
98
99        /* is it good enough? */
100        err = gini(I);
101/* printf("   gini = %f\n",err); */
102        if (err==0)
103        {
104                /* leave is perfectly classified: return this */
105                out->class = y[I[1]];
106                out->feat = 0;
107                out->thres = 0;
108                out->left = NULL;
109                out->right = NULL;
110/* printf("   Node %d is leaf. Done\n",nrnodes); */
111                return out;
112        }
113        else
114        {
115                /* store illegal class number to show it is a branch */
116                out->class = -1;
117                /* what features to use? */
118                fss = (int *)malloc(D*sizeof(int));
119                if (F>0) {
120                        /* randomly permute feature indices */
121                        tmp = (obj *)malloc(D*sizeof(obj));
122                        for (i=0;i<D;i++)
123                        {
124                                tmp[i].val = random();
125                                tmp[i].idx = i;
126                                /* printf("tmp[%d]=%f,%d\n",i,tmp[i].val,tmp[i].idx); */
127                        }
128                        qsort(tmp,D,sizeof(tmp[0]),compare_objs);
129                        for (i=0;i<D;i++)
130                        {
131                                fss[i] = tmp[i].idx;
132                                /* printf("fss[%d]=%d\n",i,fss[i]); */
133                        }
134                }
135                else
136                        for (i=0;i<D;i++)
137                                fss[i] = i;
138                free(tmp);
139
140                /* check each feature separately: */
141                besterr = 1e100;
142                tmp = (obj *)malloc(I[0]*sizeof(obj));
143                Ileft = (int *)malloc((I[0]+1)*sizeof(int));
144                Iright = (int *)malloc((I[0]+1)*sizeof(int));
145                Ileftbest = (int *)malloc((I[0]+1)*sizeof(int));
146                Irightbest = (int *)malloc((I[0]+1)*sizeof(int));
147                for (i=0;i<F;i++) {
148/* printf("Try feature %d:\n",fss[i]); */
149                        /* sort the data along feature fss[i] */
150                        for (j=0;j<I[0];j++){
151                                tmp[j].val = x[fss[i]*N+I[j+1]];
152                                tmp[j].class = y[j];
153                                tmp[j].idx = I[j+1];
154                                /* printf("   tmp[%d] = %f, idx=%d\n",j,tmp[j].val,tmp[j].idx); */
155                        }
156                        qsort((void *)tmp,I[0],sizeof(tmp[0]),compare_objs);
157                        /*for (j=0;j<I[0];j++){
158                                printf("   -> tmp[%d] = %f, idx=%d\n",j,tmp[j].val,tmp[j].idx);
159                        }*/
160                        /* make indices for the split */
161                        for (j=0;j<I[0];j++)
162                        {
163                                Ileft[j+1] = tmp[j].idx;
164                                Iright[I[0]-j] = tmp[j].idx;
165                        }
166/* for (k=1;k<=I[0];k++) printf("  Ileft[%d] = %d \n",k,Ileft[k]); */
167/* for (k=1;k<=I[0];k++) printf("  Iright[%d] = %d \n",k,Iright[k]); */
168/* if (nrnodes==3) return out; */
169                        /* run over all possible splits */
170                        for (j=1;j<I[0];j++)
171                        {
172/* printf("   split %d ",j); */
173                                Ileft[0]=j;  /* redefine the length of vector Ileft */
174/* for (k=1;k<=j;k++) printf("  Il[%d] = %d ",k,Ileft[k]); */
175/* printf(" -> gini left = %f\n",gini(Ileft)); */
176                                Iright[0]=I[0]-j;
177/* for (k=1;k<=I[0]-j;k++) printf("  Ir[%d] = %d ",k,Iright[k]); */
178/* printf("    gini right = %f\n",gini(Iright)); */
179                                err = j*gini(Ileft) + (I[0]-j)*gini(Iright);
180/* printf(" give err %f\n",err); */
181                                /* is this good? */
182                                if (err<besterr) {
183/* printf("   We have a better result! (%f<%f)\n",err,besterr); */
184/* printf("   Feature %d at %d ",fss[i],j); */
185                                        besterr = err;
186                                        bestsplit = j;
187                                        out->feat = fss[i];
188                                        out->thres = (tmp[j].val + tmp[j-1].val)/2;
189/* printf(" thres = %f\n",out->thres); */
190                                        for (k=0;k<=j;k++)
191                                                Ileftbest[k] = Ileft[k];
192                                        Ileftbest[0] = j;
193                                        for (k=0;k<=I[0]-j;k++)
194                                                Irightbest[k] = Iright[k];
195                                        Irightbest[0] = I[0]-j;
196                                }
197
198                        }
199
200                }
201/*printf("Finally, we use feature %d on split %d, threshold %f\n",
202                out->feat,bestsplit,out->thres);
203printf("Left objects:\n");
204for (k=1;k<=Ileftbest[0];k++)
205        printf("   Ileft[%d] = %d\n",k,Ileftbest[k]);
206printf("Right objects:\n");
207for (k=1;k<=Irightbest[0];k++)
208        printf("   Iright[%d] = %d\n",k,Irightbest[k]);*/
209
210                /* now find the children */
211                out->left = tree_train(Ileftbest);
212                out->right = tree_train(Irightbest);
213                       
214
215                free(Ileft);
216                free(Iright);
217                free(Ileftbest);
218                free(Irightbest);
219                free(tmp);
220                free(fss);
221        }
222        return out;
223}
224
225/* Store the tree in a matrix */
226/* Order of the variables:
227 * 1. class
228 * 2. feature
229 * 3. threshold
230 * 4. left branch index
231 * 5. right branch index */
232void tree_encode(dtree *tree,double *ptr)
233{
234        int thisindex = storeindex;
235
236/* printf("Store %d \n",thisindex); */
237
238        if (tree->class<0)  /* it is branching */
239        {
240/* printf("     : split feat %d at %f\n",tree->feat,tree->thres); */
241                *(ptr+5*thisindex)    = -1;  /* encode splitting */
242                *(ptr+5*thisindex+1) = tree->feat;
243                *(ptr+5*thisindex+2) = tree->thres;
244                storeindex += 1;
245                *(ptr+5*thisindex+3) = storeindex+1; /* Matlab indexing...*/
246                tree_encode(tree->left,ptr);
247                storeindex += 1;
248                *(ptr+5*thisindex+4) = storeindex+1;
249                tree_encode(tree->right,ptr);
250        }
251        else
252        {
253/* printf("     : class %d\n",tree->class); */
254                *(ptr+5*thisindex) = tree->class;
255                *(ptr+5*thisindex+1) = 0;
256                *(ptr+5*thisindex+2) = 0;
257                *(ptr+5*thisindex+3) = 0;
258                *(ptr+5*thisindex+4) = 0;
259        }
260}
261
262void destroy_tree(dtree *tree)
263{
264        if (tree->class<0)
265        {
266                destroy_tree(tree->left);
267                destroy_tree(tree->right);
268                free(tree);
269                /* printf("removed branch\n"); */
270        }
271        else
272        {
273                free(tree);
274                /* printf("removed leave\n"); */
275        }
276}
277
278int classify_data(double *T, int idx, int obj)
279{
280        int k;
281        int feat;
282        double thres;
283
284/*      printf("Obj x(%d): [",obj);
285        for (k=0;k<D;k++)
286                printf("%f, ",x[obj+k*N]);
287        printf("]\n"); */
288
289        if (*(T+5*idx)<0) /* branching */
290        {
291                feat = (int)(*(T+5*idx+1));
292                thres = *(T+5*idx+2);
293                /* printf(" is x[%d]=%f < %f? (x=%f)\n",feat, x[obj+feat*N],thres); */
294                if (x[obj+feat*N]<thres)
295                {
296                        /* printf("left branch\n"); */
297                        return classify_data(T,*(T+5*idx+3)-1,obj);
298                }
299                else
300                {
301                        /* printf("right branch\n"); */
302                        return classify_data(T,*(T+5*idx+4)-1,obj);
303                }
304        }
305        else
306                return *(T+5*idx);
307}
308
309
310
311/* GO! */
312void mexFunction(int nlhs, mxArray *plhs[],
313                 int nrhs, const mxArray *prhs[])
314{
315        int i;
316        int *I;     /* index vector */
317        dtree *tree;
318        double *T;
319        double *ptr;
320
321        /* Four inputs: train the tree */
322        /* We require four inputs, x,y, K and F */
323        if (nrhs==4) {
324                /* Get the input and check stuff */
325                x = mxGetPr(prhs[0]);
326                N = mxGetM(prhs[0]);
327                D = mxGetN(prhs[0]);
328                y = mxGetPr(prhs[1]);
329                if (mxGetM(prhs[1])!=N) {
330                        printf("ERROR: Size of Y does not fit with X.\n");
331                        return;
332                }
333                K = (int)(mxGetPr(prhs[2])[0]);
334                F = (int)(mxGetPr(prhs[3])[0]);
335                /* allocate */
336                tmp_p = (double *)malloc(K*sizeof(double)); /* for gini */
337
338                /* start the tree with all data: */
339                I = (int *)malloc((N+1)*sizeof(int));
340                I[0] = N;
341                for (i=0;i<N;i++) I[i+1] = i;
342
343                /* make the tree  */
344                nrnodes = 0;
345                tree = tree_train(I);
346
347                /* store results */
348/* printf("\n\n\nStore results, of %d nodes\n",nrnodes); */
349                plhs[0] = mxCreateNumericMatrix(5,nrnodes,mxDOUBLE_CLASS,mxREAL);
350                storeindex = 0;
351                tree_encode(tree,mxGetPr(plhs[0]));
352
353                /* clean up */
354                destroy_tree(tree);
355                free(I);
356                free(tmp_p);
357        }
358        /* Two inputs: evaluate the tree */
359        /* We require the encoded tree T and inputs x */
360        else if (nrhs==2) {
361                T = mxGetPr(prhs[0]);
362                nrnodes = mxGetN(prhs[0]);
363                x = mxGetPr(prhs[1]);
364                N = mxGetM(prhs[1]);
365                D = mxGetN(prhs[1]);
366
367                plhs[0] = mxCreateNumericMatrix(N,1,mxDOUBLE_CLASS,mxREAL);
368                ptr = mxGetPr(plhs[0]);
369                for (i=0;i<N;i++) *(ptr+i) = classify_data(T,0,i);
370        }
371        else
372        {
373                printf("ERROR: only 2 or 4 inputs allowed!\n");
374                return;
375        }
376}
377
378
Note: See TracBrowser for help on using the repository browser.