%
% W = RANDOMFORESTC(A,L,N)
%
% Train a decision forest on dataset A using L decision trees, each trained
% on a bootstrapped version of A. Each decision tree uses random feature
% subsets of size N in each node. When N=0, no feature subsets are used.
%
%  INPUT
%    A   Training dataset. When absent or empty, an untrained mapping is
%        returned.
%    L   Number of decision trees (default: 50)
%    N   Size of the random feature subset per node (default: 1; 0 = no
%        feature subsets). Empty L or N is treated as missing.
%
%  OUTPUT
%    W   Random forest mapping. Applied to a dataset, it returns per object
%        the fraction of the trees voting for each class.

% Timings (benchmark dataset and conditions not recorded here):
%   C (MEX) code:  train 101.3 s / 110.4 s,  test 0.1 s / 0.1 s
%   MATLAB code:   train 1726.9 s,           test 520.4 s
function w = randomforestc(a,L,featsubset)

  % Fill in defaults; empty arguments are treated as missing as well.
  if nargin<3 || isempty(featsubset)
    featsubset = 1;
  end
  if nargin<2 || isempty(L)
    L = 50;
  end
  if nargin<1 || isempty(a)
    % No data: return an untrained mapping storing the parameters.
    w = prmapping(mfilename,{L,featsubset});
    w = setname(w,'Random forest (L=%d)',L);
    return
  end

  if ~ismapping(L)
    % ---- Training: L is the number of trees ----
    isvaldfile(a,2,2); % at least 2 obj/class, 2 classes
    opt = [];
    [n,dim,opt.K] = getsize(a);
    opt.featsubset = featsubset;

    % Decide once whether the compiled (MEX) tree trainer is available,
    % instead of re-checking (and re-warning) for every single tree.
    usemex = (exist('decisiontree','file') == 3);
    if ~usemex
      prwarning(1,'Slow Matlab code used for training Randomforest classifier')
    end

    v = cell(L,1);
    for i=1:L
      x = gendat(a);             % bootstrapped version of A
      if usemex
        v{i} = decisiontree(+x,getnlab(x),opt.K,opt.featsubset);
      else
        v{i} = tree_train(+x,getnlab(x),opt);
      end
    end
    w = prmapping(mfilename,'trained',v,getlablist(a),dim,opt.K);
    w = setname(w,'Random forest (L=%d)',L);
  else
    % ---- Evaluation: L is the trained mapping ----
    v = getdata(L);              % cell array of trained trees
    n = size(a,1);               % nr objects
    K = size(L,2);               % nr of classes
    nrv = length(v);             % nr of trees
    out = zeros(n,K);            % vote counts per object and class
    % NOTE(review): the MEX evaluation path is deliberately switched off by
    % '&& false'; the reason is not recorded in this file.
    if exist('decisiontree','file')==3 && false
      for j=1:nrv
        I = decisiontree(v{j},+a);
        out = out + accumarray([(1:n)' I],ones(n,1),[n K]);
      end
    else
      % the old fashioned slow Matlab code
      prwarning(1,'Slow Matlab code used for testing Randomforest classifier')
      for i=1:n
        x = +a(i,:);
        for j=1:nrv
          I = tree_eval(v{j},x);
          out(i,I) = out(i,I)+1;
        end
      end
    end
    % Convert vote counts into per-class vote fractions, for either
    % evaluation path. Rows without any vote (a forest with no trees) stay
    % zero instead of becoming NaN through 0/0.
    s = sum(out,2);
    s(s==0) = 1;
    out = out./repmat(s,1,K);
    w = setdat(a,out,L);
  end
return
---|
| 70 | |
---|
| 71 | |
---|