%
%   W = RANDOMFORESTC(A,L,N)
%
% Train a decision forest on A, using L decision trees, each trained on
% a bootstrapped version of dataset A. Each decision tree uses a random
% feature subset of size N in each node. When N=0, no feature subsets
% are used.
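%
% Example (a minimal sketch; it assumes the standard PRTools routines
% gendatb and testc are available):
%
%   a = gendatb([50 50]);        % banana-shaped two-class training set
%   w = randomforestc(a,100,2);  % train 100 trees, feature subsets of size 2
%   e = a*w*testc                % classification error of the forest on a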

% Rough timing of the C (MEX) vs. the Matlab implementation:
% C-code: train:  101.3 s   110.4
%         test:     0.1 s     0.1
% matlab: train: 1726.9 s
%         test:   520.4 s
function w = randomforestc(a,L,featsubset)

% Defaults: 50 trees, random feature subsets of size 1 per node.
if nargin<3
  featsubset = 1;
end
if nargin<2
  L = 50;
end
if nargin<1 || isempty(a)
  % No data given: return an untrained mapping.
  w = prmapping(mfilename,{L,featsubset});
  w = setname(w,'Random forest (L=%d)',L);
  return
end

if ~ismapping(L)
  % Training: grow L decision trees, each on a bootstrap sample of A.
  isvaldfile(a,2,2); % at least 2 objects per class, 2 classes
  opt = [];
  [n,dim,opt.K] = getsize(a);
  opt.featsubset = featsubset;
  v = cell(L,1);
  for i=1:L
    [x,z] = gendat(a);          % bootstrapped version of a
    if exist('decisiontree')==3 % the compiled MEX implementation is available
      v{i} = decisiontree(+x,getnlab(x),opt.K,opt.featsubset);
    else
      prwarning(1,'Slow Matlab code used for training Randomforest classifier')
      v{i} = tree_train(+x,getnlab(x),opt);
    end
  end
  w = prmapping(mfilename,'trained',v,getlablist(a),dim,opt.K);
  w = setname(w,'Random forest (L=%d)',L);
else
  % Evaluation: L is the trained mapping, a is the dataset to classify.
  v = getdata(L);    % cell array of trained trees
  n = size(a,1);     % nr of objects
  K = size(L,2);     % nr of classes
  nrv = length(v);   % nr of trees
  out = zeros(n,K);
  if exist('decisiontree')==3 && false  % C evaluation explicitly disabled by '&& false'
    for j=1:nrv
      I = decisiontree(v{j},+a);        % class labels assigned by tree j
      out = out + accumarray([(1:n)' I],ones(n,1),[n K]);
    end
  else
    % the old-fashioned slow Matlab code
    prwarning(1,'Slow Matlab code used for testing Randomforest classifier')
    for i=1:n
      x = +a(i,:);
      for j=1:nrv
        I = tree_eval(v{j},x);          % class assigned to object i by tree j
        out(i,I) = out(i,I)+1;          % collect the votes
      end
    end
    out = out./repmat(sum(out,2),1,K);  % normalize votes to posterior estimates
  end
  w = setdat(a,out,L);
end
return