source: prextra/randomforestc.m @ 160

Last change on this file since 160 was 109, checked in by bduin, 8 years ago
File size: 1.8 KB
Line 
%
%          W = RANDOMFORESTC(A,L,N)
%
% Train a decision forest on A, using L decision trees, each trained on
% a bootstrapped version of dataset A. Each decision tree uses random
% feature subsets of size N in each node.  When N=0, no feature subsets
% are used.

% C-code:  train:  101.3 s    110.4
%          test:     0.1 s      0.1
% matlab:  train: 1726.9 s
%          test:   520.4 s
function w = randomforestc(a,L,featsubset)
% The function runs in one of three modes, selected by its arguments:
%   1. No data (A missing or empty): return an untrained mapping that
%      stores the parameters {L,featsubset}.
%   2. L is not a mapping: train L trees on A and return a trained
%      mapping whose data is the cell array of trees.
%   3. L is a (trained) mapping: evaluate its trees on the objects in A
%      and return A with its data replaced by the normalised vote counts.
%
% Defaults: L = 50 trees, featsubset = 1.

if nargin<3
    featsubset = 1;
end
if nargin<2
    L = 50;
end
if nargin<1 || isempty(a)
    % Untrained mapping definition.
    w = prmapping(mfilename,{L,featsubset});
    w = setname(w,'Random forest (L=%d)',L);
    return
end

if ~ismapping(L)
    % ---- Training ----
    isvaldfile(a,2,2); % at least 2 obj/class, 2 classes
    opt = [];
    [n,dim,opt.K] = getsize(a); % n is not used further in this branch
    opt.featsubset = featsubset;

    % The availability of the compiled trainer does not change while the
    % forest is built, so look it up (and warn about the fallback) once
    % instead of once per tree.
    usemex = (exist('decisiontree','file')==3);
    if ~usemex
        prwarning(1,'Slow Matlab code used for training Randomforest classifier')
    end

    v = cell(L,1);
    for i=1:L
        % New sample of A for every tree (a bootstrap sample according to
        % the file header). The second output z is not used.
        [x,z] = gendat(a);
        if usemex
            v{i} = decisiontree(+x,getnlab(x),opt.K,opt.featsubset);
        else
            v{i} = tree_train(+x,getnlab(x),opt);
        end
    end
    w = prmapping(mfilename,'trained',v,getlablist(a),dim,opt.K);
    w = setname(w,'Random forest (L=%d)',L);
else
    % ---- Evaluation ----
    v = getdata(L);   % cell array of trees stored at training time
    n = size(a,1);    % nr objects
    K = size(L,2);    % nr of classes
    nrv = length(v);  % nr of trees
    out = zeros(n,K); % out(i,k): number of trees voting class k for object i
    % NOTE(review): the compiled evaluation path is switched off by the
    % hard-coded '&& false', so the Matlab branch below is always taken.
    % The reason is not visible in this file; confirm before re-enabling.
    if exist('decisiontree')==3 && false
        for j=1:nrv
            I = decisiontree(v{j},+a);
            out = out + accumarray([(1:n)' I],ones(n,1),[n K]);
        end
        % NOTE(review): unlike the Matlab branch, this branch does not
        % normalise the vote counts.
    else
        % the old fashioned slow Matlab code
        prwarning(1,'Slow Matlab code used for testing Randomforest classifier')
        for i=1:n
            x = +a(i,:);
            for j=1:nrv
                I = tree_eval(v{j},x);
                out(i,I) = out(i,I)+1;
            end
        end
        % Turn the vote counts into fractions that sum to one per object.
        out = out./repmat(sum(out,2),1,K);
    end
    w = setdat(a,out,L);
end
return
70       
71       
Note: See TracBrowser for help on using the repository browser.