[101] | 1 | %SVMTRAINC Stats Support Vector Classifier (Matlab Stats Toolbox)
|
---|
| 2 | %
|
---|
| 3 | % W = SVMTRAINC(A,KERNEL,C,OPTTYPE)
|
---|
| 4 | % W = A*SVMTRAINC(KERNEL,C,OPTTYPE)
|
---|
| 5 | % D = B*W
|
---|
| 6 | %
|
---|
| 7 | % INPUT
|
---|
| 8 | % A A PRTools dataset used for training
|
---|
| 9 | % KERNEL Untrained mapping to compute kernel by A*(A*KERNEL) during
|
---|
| 10 | % training, or B*(A*KERNEL) during testing with dataset B.
|
---|
| 11 | % Default: linear kernel (PROXM('p',1))
|
---|
| 12 | % C Regularization ('boxconstraint' in SVMTRAIN)
|
---|
| 13 | % OPTTYPE Desired optimizer, 'SMO' (default) or 'QP'.
|
---|
| 14 | % B PRTools dataset used for testing
|
---|
| 15 | %
|
---|
| 16 | % OUTPUT
|
---|
| 17 | % W Mapping: Support Vector Classifier
|
---|
| 18 | % D PRTools dataset with classification results
|
---|
| 19 | %
|
---|
| 20 | % DESCRIPTION
|
---|
| 21 | % This is a PRTools interface to the support vector classifier SVMTRAIN
|
---|
| 22 | % in Matlab's Stats toolbox. It is an alternative for STATSSVC.
|
---|
| 23 | %
|
---|
| 24 | % The evaluation of W in D = B*W makes use of the SVMCLASSIFY routine in
|
---|
| 25 | % the Stats toolbox. This routine outputs just labels. Consequently the
|
---|
| 26 | % classification matrix D has for every object the value one in the column
|
---|
| 27 | % corresponding with an output label and zeros for all other columns. For
|
---|
| 28 | % multi-class datasets the one-against-rest procedure is followed (MCLASSC)
|
---|
| 29 | % which may result in object classifications with zeros in all columns
|
---|
| 30 | % as well as multiple ones. A trained combiner may solve this, e.g.
|
---|
| 31 | % W = A*(SVMTRAINC*LDC)
|
---|
| 32 | %
|
---|
| 33 | % Use STATSSVC for a standard PRTools evaluation of the classifier.
|
---|
| 34 | %
|
---|
| 35 | % SEE ALSO
|
---|
| 36 | % DATASETS, MAPPINGS, STATSSVC, SVMTRAIN, SVMCLASSIFY, LDC
|
---|
| 37 |
|
---|
| 38 | % Copyright: R.P.W. Duin, r.p.w.duin@37steps.com
|
---|
| 39 |
|
---|
| 40 |
|
---|
function out = svmtrainc(varargin)
% Single entry point for the three PRTools mapping tasks:
%   definition : no dataset supplied        -> untrained mapping
%   training   : dataset + untrained kernel -> trained mapping
%   evaluation : dataset + trained mapping (passed in the KERNEL slot)
%                                           -> dataset with a 0/1 matrix
%
% Arguments after SHIFTARGIN/SETDEFAULTS are, in order:
%   A        dataset (may be empty for a definition call)
%   KERNEL   untrained kernel mapping, default PROXM([],'p',1)
%   C        box constraint handed to SVMTRAIN, default 1
%   OPTTYPE  'SMO' (default) or 'QP', compared case-insensitively
%
% Raises an error ('Unknown optimizer') for any other OPTTYPE.

  checktoolbox('stats_svmtrain');
  mapname = 'StatsSVM';
  argin = shiftargin(varargin,{'prmapping','char'});
  argin = setdefaults(argin,[],proxm([],'p',1),1,'SMO');

  if mapping_task(argin,'definition')

    out = define_mapping(argin,'untrained',mapname);

  else
    [a,kernel,C,opttype] = deal(argin{:});

    if ~(ismapping(kernel) && istrained(kernel))     % training

      % The IS* routines are called without output arguments on purpose;
      % NOTE(review): presumably PRTools raises an error in that calling
      % form when the test fails, so these act as assertions - confirm.
      isdataset(a);
      islabtype(a,'crisp');
      a = testdatasize(a,'objects');

      % remove too small classes, escape in case no two classes are left
      [a,m,k,c,lablist,L,out] = cleandset(a,1);
      if ~isempty(out), return; end

      if c > 2                     % solve multi-class case by recursion
        % Build the untrained base classifier with ALL user settings.
        % Passing only the kernel here would let SETDEFAULTS silently
        % reset C to 1 and OPTTYPE to 'SMO' for every multi-class problem.
        u   = feval(mfilename,[],kernel,C,opttype);
        out = mclassc(a,u);              % concatenation of one-against-rest
        out = allclass(out,lablist,L);   % handle missing classes
      else                         % two-class case
        labels = getlabels(a);
        ismapping(kernel);
        isuntrained(kernel);
        prkernel(kernel);          % make kernel mapping known to prkernel

        pp = prrmpath('stats','svmtrain');       % check / correct path
        finishup = onCleanup(@() addpath(pp));   % restore afterwards

        % Only the QP optimizer needs an explicit 'method' option; for SMO
        % the option is left out so SVMTRAIN uses its own default.
        if strcmpi(opttype,'SMO')
          methodargs = {};
        elseif strcmpi(opttype,'QP')
          methodargs = {'method','qp'};
        else
          error('Unknown optimizer')
        end
        ss = svmtrain(+a,labels,'kernel_function',@prkernel, ...
                      methodargs{:},'boxconstraint',C);

        % Store the SVM structure together with the kernel mapping so the
        % evaluation branch can retrieve both through GETDATA.
        out = trained_mapping(a,{ss,kernel},getsize(a,3));
        %out = cnormc(out,a);      % normalise outputs for confidences
        out = setname(out,mapname);
      end

    else                                             % evaluation
      w = kernel;                  % trained classifier arrives in this slot
      [ss,kernel] = getdata(w);    % get data from training
      ismapping(kernel);

      % SS refers to @prkernel as its kernel function, so the kernel that
      % belongs to THIS classifier has to be registered again before
      % classifying. Otherwise the result depends on whichever kernel was
      % registered by the most recent training call.
      % NOTE(review): assumes PRKERNEL(KERNEL) only stores the mapping, as
      % in the training branch - confirm against prkernel.m.
      prkernel(kernel);

      labout  = svmclassify(ss,+a);              % stats toolbox evaluation
      nlabout = renumlab(labout,getlabels(w));   % label indices

      % Classification matrix: one row per object, a single 1 in the
      % column of the assigned label, 0 elsewhere.
      out = zeros(size(a,1),2);
      out(sub2ind(size(out),(1:size(out,1))',nlabout)) = 1;
      out = setdat(a,out,w);
    end

  end

return
| 104 |
|
---|
| 105 |
|
---|
| 106 |
|
---|
| 107 |
|
---|
| 108 |
|
---|