source: prextra/randomforestc.m @ 82

Last change on this file since 82 was 23, checked in by dtax, 13 years ago

The decision tree and random forest, with compiled code!

File size: 1.7 KB
Line 
%RANDOMFORESTC Random forest classifier
%
%          W = RANDOMFORESTC(A,L,N)
%
% INPUT
%   A   Dataset for training (at least 2 objects per class, 2 classes)
%   L   Number of decision trees in the forest (default: 50)
%   N   Size of the random feature subset used in each tree node
%       (default: 1). When N=0, no feature subsets are used.
%
% OUTPUT
%   W   Trained random forest mapping (or an untrained mapping when A is
%       empty or not given)
%
% Train a decision forest on A, using L decision trees, each trained on
% a bootstrapped version of dataset A. Each decision tree uses random
% feature subsets of size N in each node.
%
% When the compiled function DECISIONTREE (a MEX file) is on the path it is
% used for training and evaluation; otherwise the Matlab implementations
% TREE_TRAIN and TREE_EVAL are used. When the mapping is applied to data,
% the per-class vote counts of the trees are normalized so that each
% object's outputs sum to one.

% Timings:
% C-code:  train:  101.3 s    110.4
%          test:     0.1 s      0.1
% matlab:  train: 1726.9 s
%          test:   520.4 s
function w = randomforestc(a,L,featsubset)

if nargin<3
    featsubset = 1;
end
if nargin<2
    L = 50;
end
if nargin<1 || isempty(a)
    % No data: return an untrained mapping that stores the parameters.
    w = mapping(mfilename,{L,featsubset});
    w = setname(w,'Random forest (L=%d)',L);
    return
end

% Check once whether the compiled decision tree is available.
% exist(...,'file') returns 3 for MEX files and, unlike exist(name),
% is not fooled by a variable of the same name.
usemex = (exist('decisiontree','file')==3);

if ~ismapping(L)
    % Training: L is the number of trees to grow.
    isvaldfile(a,2,2); % at least 2 obj/class, 2 classes
    opt = [];
    [n,dim,opt.K] = getsize(a);
    opt.featsubset = featsubset;
    v = cell(L,1);
    for i=1:L
        x = gendat(a);  % bootstrapped version of a
        if usemex
            v{i} = decisiontree(+x,getnlab(x),opt.K,opt.featsubset);
        else
            v{i} = tree_train(+x,getnlab(x),opt);
        end
    end
    w = mapping(mfilename,'trained',v,getlablist(a),dim,opt.K);
    w = setname(w,'Random forest (L=%d)',L);
else
    % Evaluation: L is the trained forest, a holds the objects to classify.
    v = getdata(L);
    n = size(a,1);   % nr objects
    K = size(L,2);   % nr of classes
    nrv = length(v); % nr of trees
    out = zeros(n,K);
    if usemex
        for j=1:nrv
            I = decisiontree(v{j},+a);
            % Pass the size [n K] explicitly: without it accumarray returns
            % an n-by-max(I) matrix, which does not match OUT when no tree
            % predicts the highest class label(s).
            out = out + accumarray([(1:n)' I(:)],ones(n,1),[n K]);
        end
    else
        % the old fashioned slow Matlab code
        for i=1:n
            x = +a(i,:);
            for j=1:nrv
                I = tree_eval(v{j},x);
                out(i,I) = out(i,I)+1;
            end
        end
    end
    % Normalize the votes in both branches so the outputs do not depend on
    % which implementation ran. Rows without any votes (e.g. an empty
    % forest) stay zero instead of becoming 0/0 = NaN.
    s = sum(out,2);
    s(s==0) = 1;
    out = out./repmat(s,1,K);
    w = setdat(a,out,L);
end
return
68       
69       
Note: See TracBrowser for help on using the repository browser.