1 | %SVMTRAINC Stats Support Vector Classifier (Matlab Stats Toolbox)
|
---|
2 | %
|
---|
3 | % W = SVMTRAINC(A,KERNEL,C,OPTTYPE)
|
---|
4 | % W = A*SVMTRAINC(KERNEL,C,OPTTYPE)
|
---|
5 | % D = B*W
|
---|
6 | %
|
---|
7 | % INPUT
|
---|
8 | % A A PRTools dataset used for training
|
---|
9 | % KERNEL Untrained mapping to compute kernel by A*(A*KERNEL) during
|
---|
10 | % training, or B*(A*KERNEL) during testing with dataset B.
|
---|
11 | % Default: linear kernel (PROXM('p',1))
|
---|
12 | % C Regularization ('boxconstraint' in SVMTRAIN). Default: 1
|
---|
13 | % OPTTYPE Desired optimizer, 'SMO' (default) or 'QP'.
|
---|
14 | % B PRTools dataset used for testing
|
---|
15 | %
|
---|
16 | % OUTPUT
|
---|
17 | % W Mapping: Support Vector Classifier
|
---|
18 | % D PRTools dataset with classification results
|
---|
19 | %
|
---|
20 | % DESCRIPTION
|
---|
21 | % This is a PRTools interface to the support vector classifier SVMTRAIN
|
---|
22 | % in Matlab's Stats toolbox. It is an alternative to STATSSVC.
|
---|
23 | %
|
---|
24 | % The evaluation of W in D = B*W makes use of the SVMCLASSIFY routine in
|
---|
25 | % the Stats toolbox. This routine outputs just labels. Consequently the
|
---|
26 | % classification matrix D has for every object the value one in the column
|
---|
27 | % corresponding to the output label and zeros in all other columns. For
|
---|
28 | % multi-class datasets the one-against-rest procedure is followed (MCLASSC)
|
---|
29 | % which may result in object classifications with zeros in all columns
|
---|
30 | % as well as multiple ones. A trained combiner may solve this, e.g.
|
---|
31 | % W = A*(SVMTRAINC*LDC)
|
---|
32 | %
|
---|
33 | % Use STATSSVC for a standard PRTools evaluation of the classifier.
|
---|
34 | %
|
---|
35 | % SEE ALSO
|
---|
36 | % DATASETS, MAPPINGS, STATSSVC, SVMTRAIN, SVMCLASSIFY, LDC
|
---|
37 |
|
---|
38 | % Copyright: R.P.W. Duin, r.p.w.duin@37steps.com
|
---|
39 |
|
---|
40 |
|
---|
function out = svmtrainc(varargin)

  checktoolbox('stats_svmtrain');
  mapname = 'StatsSVM';
  % Support the untrained-call form W = A*SVMTRAINC(KERNEL,C,OPTTYPE), in
  % which the dataset is absent and the first argument is a mapping or char.
  argin = shiftargin(varargin,{'prmapping','char'});
  % Defaults: no dataset, linear kernel PROXM('p',1), C = 1, optimizer 'SMO'.
  argin = setdefaults(argin,[],proxm([],'p',1),1,'SMO');

  if mapping_task(argin,'definition')

    out = define_mapping(argin,'untrained',mapname);

  else
    [a,kernel,C,opttype] = deal(argin{:});

    if ~(ismapping(kernel) && istrained(kernel)) % training
      isdataset(a);
      islabtype(a,'crisp');
      a = testdatasize(a,'objects');

      % remove too small classes, escape in case no two classes are left
      [a,m,k,c,lablist,L,out] = cleandset(a,1);
      if ~isempty(out), return; end

      if c > 2 % solve multi-class case by recursion
        % Forward C and OPTTYPE too; otherwise every one-against-rest
        % classifier silently falls back to the defaults (C = 1, 'SMO').
        u = feval(mfilename,[],kernel,C,opttype);
        out = mclassc(a,u);            % concatenation of one-against-rest
        out = allclass(out,lablist,L); % handle missing classes
      else % two class case
        labels = getlabels(a);
        ismapping(kernel);
        isuntrained(kernel);
        prkernel(kernel); % make kernel mapping known to prkernel

        pp = prrmpath('stats','svmtrain'); % check / correct path
        % The onCleanup handle must stay assigned: the path is restored
        % when it is destroyed, i.e. when this function exits.
        finishup = onCleanup(@() addpath(pp));
        if strcmpi(opttype,'SMO')
          ss = svmtrain(+a,labels,'kernel_function',@prkernel, ...
                        'boxconstraint',C);
        elseif strcmpi(opttype,'QP')
          ss = svmtrain(+a,labels,'kernel_function',@prkernel, ...
                        'method','qp', 'boxconstraint',C);
        else
          error('Unknown optimizer')
        end
        % store the Stats SVM struct together with the (untrained) kernel
        out = trained_mapping(a,{ss,kernel},getsize(a,3));
        %out = cnormc(out,a); % normalise outputs for confidences
        out = setname(out,mapname);
      end

    else % evaluation
      w = kernel;                 % trained classifier
      [ss,kernel] = getdata(w);   % get data from training
      ismapping(kernel);
      % SS uses @prkernel as its kernel function, so register the kernel
      % stored with this classifier before evaluating, just as training
      % does. Otherwise a kernel registered by some other training call
      % would presumably be used.
      prkernel(kernel);
      labout = svmclassify(ss,+a); % use stats toolbox for evaluation
      nlabout = renumlab(labout,getlabels(w)); % label indices
      n = size(a,1);
      out = zeros(n,2);            % construct 0/1 classification matrix
      out(sub2ind(size(out),(1:n)',nlabout)) = 1;
      out = setdat(a,out,w);
    end

  end

return
|
---|
104 |
|
---|
105 |
|
---|
106 |
|
---|
107 |
|
---|
108 |
|
---|