source: prextra/decisiontree.c @ 100

Last change on this file since 100 was 25, checked in by dtax, 13 years ago

Removed a bug in freeing a tmp variable...

File size: 9.0 KB
RevLine 
[23]1/* try to train a decision tree */
2/* I assume I have a dataset X, and labels y. Due to implementation
3 * simplification, I ask the number of classes as well (it may be that
4 * not all classes are available in y). I assume that y contains
5 * integers, from 1,2,...,K.  Also the size of the feature subset F
6 * should be given. */
7
8/* The data matrix should be nxd, where n is the number of objects, and
9 * d is the number of dimensions. */
10
11#include <stdlib.h>
12#include <stdio.h>
13#include <mex.h>
14
15/* first the tree structure */
16typedef struct dtree {
17        int class;                    /* predicted class */
18        int feat;                     /* feature to split */
19        double thres;                 /* threshold */
20        struct dtree *left, *right;   /* children */
21} dtree;
22
23/* then the data structure for sorting */
24typedef struct obj {
25        double val;
26        int class;
27        int idx;
28} obj;
29
30/* general variables (ok, global vars are bad, ok) */
31double *x;     /* data */
32double *y;     /* labels */
33size_t N,D;    /* nr objs, nr feats*/
34int K,F;       /* nr classes, nr subspaces */
35int nrnodes;   /* the size of the tree */
36double *tmp_p; /* for gini */
37int storeindex;/* for storing tree */
38
39/* All the stuff for sorting */
40int compare_objs(const void *a,const void *b)
41{
42        obj *obj_a = (obj *)a;
43        obj *obj_b = (obj *)b;
44
45        if (obj_a->val > obj_b->val)
46                return 1;
47        else
48                return -1;
49}
50int compare_doubles(const void *a,const void *b)
51{
52        double *obj_a = (double *)a;
53        double *obj_b = (double *)b;
54
55        if (obj_a > obj_b)
56                return 1;
57        else
58                return -1;
59}
60
61/* Gini */
62double gini(int *I)
63{
64        int i;
65        double out;
66
67        /* initialize to zero */
68        for (i=0;i<K;i++)
69                tmp_p[i] = 0;
70        /* count the occurance of each class */
71        /* index vector starts at 1, the class numbering as well... */
72        for (i=0;i<I[0];i++)
73                tmp_p[(int)y[I[i+1]]-1] +=1;
74        /* normalize and compute gini */
75        out = 0;
76        for (i=0;i<K;i++)
77                out = out + tmp_p[i]*(1-tmp_p[i]/I[0])/I[0];
78
79        return out;
80}
81
82/* make a tree */
83dtree *tree_train(int *I)
84{
85        double err,besterr;
86        dtree *out;
87        int *fss;
88        obj *tmp;
89        int i,j,k;
90        int bestsplit;
91        int *Ileft, *Iright, *Ileftbest, *Irightbest;
92
93        /* make the node */
94        out = (dtree *)malloc(sizeof(dtree));
95        nrnodes +=1;
96/* printf("Make NODE %d!\n",nrnodes); */
97/* printf("%d objects in this node\n",I[0]); */
98
99        /* is it good enough? */
100        err = gini(I);
101/* printf("   gini = %f\n",err); */
102        if (err==0)
103        {
104                /* leave is perfectly classified: return this */
105                out->class = y[I[1]];
106                out->feat = 0;
107                out->thres = 0;
108                out->left = NULL;
109                out->right = NULL;
110/* printf("   Node %d is leaf. Done\n",nrnodes); */
111                return out;
112        }
113        else
114        {
115                /* store illegal class number to show it is a branch */
116                out->class = -1;
117                /* what features to use? */
118                fss = (int *)malloc(D*sizeof(int));
119                if (F>0) {
120                        /* randomly permute feature indices */
121                        tmp = (obj *)malloc(D*sizeof(obj));
122                        for (i=0;i<D;i++)
123                        {
124                                tmp[i].val = random();
125                                tmp[i].idx = i;
126                                /* printf("tmp[%d]=%f,%d\n",i,tmp[i].val,tmp[i].idx); */
127                        }
128                        qsort(tmp,D,sizeof(tmp[0]),compare_objs);
129                        for (i=0;i<D;i++)
130                        {
131                                fss[i] = tmp[i].idx;
132                                /* printf("fss[%d]=%d\n",i,fss[i]); */
133                        }
[25]134                        free(tmp);
[23]135                }
136                else
[25]137                {
[23]138                        for (i=0;i<D;i++)
[25]139                        {
[23]140                                fss[i] = i;
[25]141                                /* printf("fss[%d]=%d\n",i,fss[i]); */
142                        }
143                        F = D;
144                }
[23]145
146                /* check each feature separately: */
147                besterr = 1e100;
148                tmp = (obj *)malloc(I[0]*sizeof(obj));
149                Ileft = (int *)malloc((I[0]+1)*sizeof(int));
150                Iright = (int *)malloc((I[0]+1)*sizeof(int));
151                Ileftbest = (int *)malloc((I[0]+1)*sizeof(int));
152                Irightbest = (int *)malloc((I[0]+1)*sizeof(int));
153                for (i=0;i<F;i++) {
154/* printf("Try feature %d:\n",fss[i]); */
155                        /* sort the data along feature fss[i] */
156                        for (j=0;j<I[0];j++){
157                                tmp[j].val = x[fss[i]*N+I[j+1]];
158                                tmp[j].class = y[j];
159                                tmp[j].idx = I[j+1];
160                                /* printf("   tmp[%d] = %f, idx=%d\n",j,tmp[j].val,tmp[j].idx); */
161                        }
162                        qsort((void *)tmp,I[0],sizeof(tmp[0]),compare_objs);
[25]163/* for (j=0;j<I[0];j++) printf("   -> tmp[%d] = %f, idx=%d\n",j,tmp[j].val,tmp[j].idx); */
[23]164                        /* make indices for the split */
165                        for (j=0;j<I[0];j++)
166                        {
167                                Ileft[j+1] = tmp[j].idx;
168                                Iright[I[0]-j] = tmp[j].idx;
169                        }
170/* for (k=1;k<=I[0];k++) printf("  Ileft[%d] = %d \n",k,Ileft[k]); */
171/* for (k=1;k<=I[0];k++) printf("  Iright[%d] = %d \n",k,Iright[k]); */
172/* if (nrnodes==3) return out; */
173                        /* run over all possible splits */
174                        for (j=1;j<I[0];j++)
175                        {
176/* printf("   split %d ",j); */
177                                Ileft[0]=j;  /* redefine the length of vector Ileft */
178/* for (k=1;k<=j;k++) printf("  Il[%d] = %d ",k,Ileft[k]); */
179/* printf(" -> gini left = %f\n",gini(Ileft)); */
180                                Iright[0]=I[0]-j;
181/* for (k=1;k<=I[0]-j;k++) printf("  Ir[%d] = %d ",k,Iright[k]); */
182/* printf("    gini right = %f\n",gini(Iright)); */
183                                err = j*gini(Ileft) + (I[0]-j)*gini(Iright);
184/* printf(" give err %f\n",err); */
185                                /* is this good? */
186                                if (err<besterr) {
187/* printf("   We have a better result! (%f<%f)\n",err,besterr); */
188/* printf("   Feature %d at %d ",fss[i],j); */
189                                        besterr = err;
190                                        bestsplit = j;
191                                        out->feat = fss[i];
192                                        out->thres = (tmp[j].val + tmp[j-1].val)/2;
193/* printf(" thres = %f\n",out->thres); */
194                                        for (k=0;k<=j;k++)
195                                                Ileftbest[k] = Ileft[k];
196                                        Ileftbest[0] = j;
197                                        for (k=0;k<=I[0]-j;k++)
198                                                Irightbest[k] = Iright[k];
199                                        Irightbest[0] = I[0]-j;
200                                }
201
202                        }
203
204                }
[25]205/* printf("Finally, we use feature %d on split %d, threshold %f\n",
206                out->feat,bestsplit,out->thres); */
207/*printf("Left objects:\n");
[23]208for (k=1;k<=Ileftbest[0];k++)
209        printf("   Ileft[%d] = %d\n",k,Ileftbest[k]);
210printf("Right objects:\n");
211for (k=1;k<=Irightbest[0];k++)
212        printf("   Iright[%d] = %d\n",k,Irightbest[k]);*/
213
214                /* now find the children */
215                out->left = tree_train(Ileftbest);
216                out->right = tree_train(Irightbest);
217                       
218
219                free(Ileft);
220                free(Iright);
221                free(Ileftbest);
222                free(Irightbest);
223                free(tmp);
224                free(fss);
225        }
226        return out;
227}
228
229/* Store the tree in a matrix */
230/* Order of the variables:
231 * 1. class
232 * 2. feature
233 * 3. threshold
234 * 4. left branch index
235 * 5. right branch index */
236void tree_encode(dtree *tree,double *ptr)
237{
238        int thisindex = storeindex;
239
240/* printf("Store %d \n",thisindex); */
241
242        if (tree->class<0)  /* it is branching */
243        {
244/* printf("     : split feat %d at %f\n",tree->feat,tree->thres); */
245                *(ptr+5*thisindex)    = -1;  /* encode splitting */
246                *(ptr+5*thisindex+1) = tree->feat;
247                *(ptr+5*thisindex+2) = tree->thres;
248                storeindex += 1;
249                *(ptr+5*thisindex+3) = storeindex+1; /* Matlab indexing...*/
250                tree_encode(tree->left,ptr);
251                storeindex += 1;
252                *(ptr+5*thisindex+4) = storeindex+1;
253                tree_encode(tree->right,ptr);
254        }
255        else
256        {
257/* printf("     : class %d\n",tree->class); */
258                *(ptr+5*thisindex) = tree->class;
259                *(ptr+5*thisindex+1) = 0;
260                *(ptr+5*thisindex+2) = 0;
261                *(ptr+5*thisindex+3) = 0;
262                *(ptr+5*thisindex+4) = 0;
263        }
264}
265
266void destroy_tree(dtree *tree)
267{
268        if (tree->class<0)
269        {
270                destroy_tree(tree->left);
271                destroy_tree(tree->right);
272                free(tree);
273                /* printf("removed branch\n"); */
274        }
275        else
276        {
277                free(tree);
278                /* printf("removed leave\n"); */
279        }
280}
281
282int classify_data(double *T, int idx, int obj)
283{
284        int k;
285        int feat;
286        double thres;
287
288/*      printf("Obj x(%d): [",obj);
289        for (k=0;k<D;k++)
290                printf("%f, ",x[obj+k*N]);
291        printf("]\n"); */
292
293        if (*(T+5*idx)<0) /* branching */
294        {
295                feat = (int)(*(T+5*idx+1));
296                thres = *(T+5*idx+2);
297                /* printf(" is x[%d]=%f < %f? (x=%f)\n",feat, x[obj+feat*N],thres); */
298                if (x[obj+feat*N]<thres)
299                {
300                        /* printf("left branch\n"); */
301                        return classify_data(T,*(T+5*idx+3)-1,obj);
302                }
303                else
304                {
305                        /* printf("right branch\n"); */
306                        return classify_data(T,*(T+5*idx+4)-1,obj);
307                }
308        }
309        else
310                return *(T+5*idx);
311}
312
313
314
315/* GO! */
316void mexFunction(int nlhs, mxArray *plhs[],
317                 int nrhs, const mxArray *prhs[])
318{
319        int i;
320        int *I;     /* index vector */
321        dtree *tree;
322        double *T;
323        double *ptr;
324
325        /* Four inputs: train the tree */
326        /* We require four inputs, x,y, K and F */
327        if (nrhs==4) {
328                /* Get the input and check stuff */
[25]329/* printf("get input and check\n"); */
[23]330                x = mxGetPr(prhs[0]);
331                N = mxGetM(prhs[0]);
332                D = mxGetN(prhs[0]);
333                y = mxGetPr(prhs[1]);
334                if (mxGetM(prhs[1])!=N) {
335                        printf("ERROR: Size of Y does not fit with X.\n");
336                        return;
337                }
338                K = (int)(mxGetPr(prhs[2])[0]);
339                F = (int)(mxGetPr(prhs[3])[0]);
[25]340/* printf("N=%d, D=%d, K=%d, F=%d\n",N,D,K,F); */
[23]341                /* allocate */
342                tmp_p = (double *)malloc(K*sizeof(double)); /* for gini */
343
344                /* start the tree with all data: */
345                I = (int *)malloc((N+1)*sizeof(int));
346                I[0] = N;
347                for (i=0;i<N;i++) I[i+1] = i;
348
349                /* make the tree  */
350                nrnodes = 0;
351                tree = tree_train(I);
352
353                /* store results */
354/* printf("\n\n\nStore results, of %d nodes\n",nrnodes); */
355                plhs[0] = mxCreateNumericMatrix(5,nrnodes,mxDOUBLE_CLASS,mxREAL);
356                storeindex = 0;
357                tree_encode(tree,mxGetPr(plhs[0]));
358
359                /* clean up */
360                destroy_tree(tree);
361                free(I);
362                free(tmp_p);
363        }
364        /* Two inputs: evaluate the tree */
365        /* We require the encoded tree T and inputs x */
366        else if (nrhs==2) {
367                T = mxGetPr(prhs[0]);
368                nrnodes = mxGetN(prhs[0]);
369                x = mxGetPr(prhs[1]);
370                N = mxGetM(prhs[1]);
371                D = mxGetN(prhs[1]);
372
373                plhs[0] = mxCreateNumericMatrix(N,1,mxDOUBLE_CLASS,mxREAL);
374                ptr = mxGetPr(plhs[0]);
375                for (i=0;i<N;i++) *(ptr+i) = classify_data(T,0,i);
376        }
377        else
378        {
379                printf("ERROR: only 2 or 4 inputs allowed!\n");
380                return;
381        }
382}
383
384
Note: See TracBrowser for help on using the repository browser.