source: prextra/svmtrainc.m @ 128

Last change on this file since 128 was 101, checked in by bduin, 9 years ago
File size: 3.9 KB
RevLine 
[101]1%SVMTRAINC Stats Support Vector Classifier (Matlab Stats Toolbox)
2%
3%   W = SVMTRAINC(A,KERNEL,C,OPTTYPE)
4%   W = A*SVMTRAINC(KERNEL,C,OPTTYPE)
5%   D = B*W
6%
7% INPUT
8%   A         A PRTools dataset used for training
9%   KERNEL  Untrained mapping to compute kernel by A*(A*KERNEL) during
10%           training, or B*(A*KERNEL) during testing with dataset B.
11%           Default: linear kernel (PROXM('p',1))
12%   C       Regularization ('boxconstraint' in SVMTRAIN). Default: 1
13%   OPTTYPE Desired optimizer, 'SMO' (default) or 'QP'.
14%   B       PRTools dataset used for testing
15%
16% OUTPUT
17%   W       Mapping: Support Vector Classifier
18%   D       PRTools dataset with classification results
19%
20% DESCRIPTION
21% This is a PRTools interface to the support vector classifier SVMTRAIN
22% in Matlab's Stats toolbox. It is an alternative for STATSSVC.
23%
24% The evaluation of W in D = B*W makes use of the SVMCLASSIFY routine in
25% the Stats toolbox. This routine outputs just labels. Consequently the
26% classification matrix D has for every object the value one in the column
27% corresponding to the output label and zeros in all other columns. For
28% multi-class datasets the one-against-rest procedure is followed (MCLASSC)
29% which may result in object classifications with zeros in all columns
30% as well as multiple ones. A trained combiner may solve this, e.g.
31% W = A*(SVMTRAINC*LDC)
32%
33% Use STATSSVC for a standard PRTools evaluation of the classifier.
34%
35% SEE ALSO
36% DATASETS, MAPPINGS, STATSSVC, SVMTRAIN, SVMCLASSIFY, LDC
37
38% Copyright: R.P.W. Duin, r.p.w.duin@37steps.com
39
40
function out = svmtrainc(varargin)

  checktoolbox('stats_svmtrain');
  mapname = 'StatsSVM';

  % Call modes, selected on the (shifted) input arguments:
  %   definition : no dataset supplied          -> untrained mapping
  %   training   : dataset + untrained kernel   -> trained mapping
  %   evaluation : dataset + trained SVMTRAINC  -> classification dataset
  % Defaults filled in here: KERNEL = PROXM([],'p',1), C = 1, OPTTYPE = 'SMO'.
  argin = shiftargin(varargin,{'prmapping','char'});
  argin = setdefaults(argin,[],proxm([],'p',1),1,'SMO');

  if mapping_task(argin,'definition')

    out = define_mapping(argin,'untrained',mapname);

  else
    [a,kernel,C,opttype] = deal(argin{:});

    if ~(ismapping(kernel) && istrained(kernel))   % training
      % The check calls below are made without outputs; presumably they
      % raise an error when the check fails (PRTools convention) - verify.
      isdataset(a);
      islabtype(a,'crisp');
      a = testdatasize(a,'objects');

      % remove too small classes, escape in case no two classes are left
      [a,~,~,c,lablist,L,out] = cleandset(a,1);
      if ~isempty(out), return; end

      if c > 2
        % Multi-class: one-against-rest through MCLASSC. C and OPTTYPE must
        % be passed on, otherwise every base classifier silently falls back
        % to the defaults (C = 1, 'SMO').
        u = feval(mfilename,[],kernel,C,opttype);
        out = mclassc(a,u);
      else
        % Two-class case: train the Stats Toolbox SVM on the raw data.
        labels = getlabels(a);
        ismapping(kernel);
        isuntrained(kernel);
        prkernel(kernel);            % make kernel mapping known to prkernel

        pp = prrmpath('stats','svmtrain');      % check / correct path
        finishup = onCleanup(@() addpath(pp));  % restore path on exit
        if strcmpi(opttype,'SMO')
          ss = svmtrain(+a,labels,'kernel_function',@prkernel, ...
               'boxconstraint',C);
        elseif strcmpi(opttype,'QP')
          ss = svmtrain(+a,labels,'kernel_function',@prkernel, ...
               'method','qp', 'boxconstraint',C);
        else
          error('Unknown optimizer')
        end
        % Store the SVMTRAIN result together with the kernel mapping; both
        % are unpacked again by GETDATA in the evaluation branch below.
        out = trained_mapping(a,{ss,kernel},getsize(a,3));
        %out = cnormc(out,a);        % normalise outputs for confidences
        out = setname(out,mapname);
      end
      % Complete the classifier with outputs for any classes CLEANDSET
      % removed (applied to both branches, as in the multi-class case).
      out = allclass(out,lablist,L);

    else                             % evaluation
      w = kernel;                    % trained classifier
      [ss,kernel] = getdata(w);      % get data stored during training
      ismapping(kernel);
      labout = svmclassify(ss,+a);   % Stats Toolbox returns labels only
      nlabout = renumlab(labout,getlabels(w)); % label indices into w's labels
      % Crisp classification matrix: a single 1 per row, in the column of
      % the assigned label.
      n = size(a,1);
      out = zeros(n,2);
      out(sub2ind([n 2],(1:n)',nlabout(:))) = 1;
      out = setdat(a,out,w);
    end

  end

return
104 
105
106
107
108       
Note: See TracBrowser for help on using the repository browser.