function [B, ll] = randomforestc2(A, W)
%RANDOMFORESTC2 Random forest classifier
%
%   B = RANDOMFORESTC2(A, W)
%   B = A * RANDOMFORESTC2([], W)
%   B = RANDOMFORESTC2([], W)      (untrained mapping)
%
% INPUT
%   A   Dataset with crisp labels, or [] for an untrained mapping
%   W   Number of trees (default 100; an empty W also means 100)
%
% OUTPUT
%   B   Random forest classifier (PRMAPPING), or, when a trained
%       classifier is evaluated on data, a dataset of vote outputs
%   LL  Always empty; kept only for interface compatibility
%
% DESCRIPTION
% Trains a random forest with W trees on dataset A using CLASSRF_TRAIN from
% the randomforest-matlab package. When called with an empty dataset, an
% untrained mapping is returned that stores W as its parameter. Applying it
% to data then trains with that same number of trees.
%
% When called with data A and a trained RANDOMFORESTC2 mapping W, the votes
% returned by CLASSRF_PREDICT are stored as the dataset's data. The class
% labels of W are used as the feature labels.

% The second output is never computed; assign it so [B, ll] = ... never errors.
ll = [];

if ~exist('classRF_train', 'file')
    error('Please download code from http://code.google.com/p/randomforest-matlab/');
end

if nargin == 0 || isempty(A)
    % Untrained call, e.g. randomforestc2([]) or randomforestc2([], 50).
    if nargin < 2 || isempty(W)
        W = 100;
    end
    name = ['randomforestc2_' num2str(W)];

    % Store W as the mapping parameter. Without it, A * B would call
    % randomforestc2(A) and silently train with the default tree count.
    B = prmapping(mfilename, 'untrained', {W});
    B = setname(B, name);
    return;

elseif isa(A, 'prdataset') && (nargin < 2 || isa(W, 'double'))
    % Training: randomforestc2(A), randomforestc2(A, W), A * randomforestc2([], W).
    if nargin < 2 || isempty(W)
        W = 100;
    end
    name = ['randomforestc2_' num2str(W)];

    islabtype(A, 'crisp');
    isvaldfile(A, 1, 2);   % PRTools dataset validity check (presumably >= 1 object/class, >= 2 classes)
    A = testdatasize(A, 'features');
    A = setprior(A, getprior(A));
    [~, k, c] = getsize(A);

    % Train the random forest on the plain feature matrix and numeric labels.
    rf = classRF_train(+A, getnlab(A), W);

    B = prmapping(mfilename, 'trained', rf, getlablist(A), k, c);
    B = setname(B, name);

elseif nargin >= 2 && (isa(A, 'prdataset') || isa(A, 'double')) && isa(W, 'prmapping')
    % Evaluation of a trained RANDOMFORESTC2 mapping W on data A.
    % The nargin guard keeps a single double argument from touching an
    % undefined W; such calls now reach the 'Illegal call' error below.
    [~, votes] = classRF_predict(+A, W.data);

    A = prdataset(A);
    B = setdata(A, votes, getlabels(W));

else
    % This should not happen
    error('Illegal call');
end
end