source: prextra/lessc.m

Last change on this file was 104, checked in by dtax, 9 years ago

More robust thresholds.

File size: 3.1 KB
RevLine 
[5]1function [w,I] = lessc(x, C, ftype, include_bias)
2%LESSC Least Error in Sparse Subspaces classifier
3%
4%     W = LESSC(X, C, FTYPE, INCLUDE_BIAS)
5%
6% Train a linear classifier which also performs feature selection.
7% In this version we do:
8%                min \sum_i w_i + C*delta_i
9%          s.t. forall_i   w^T f(x_i) > 1 - delta_i
10%                       sum_i |w_i| = 1
11% where f(x_i) is in principle free, but as a start we use the nearest
12% mean idea:
13%
14%         f(x_i) = (x-mu2).^2 - (x-mu1).^2
15% See for further definitions of f(x_i) lessfx.
16%
17% Dxd   15-3-2004
18prtrace(mfilename);
19
20if (nargin < 4)
21        % To include a bias term in the model, we extend the number of features
22        % by one:
23        include_bias = 0;
24end
25if (nargin < 3)
26        prwarning(3,'Use default function fx.');
27        ftype = 1;
28end
29if length(ftype)>1
30        include_bias = ftype(2);
31        ftype = ftype(1);
32end
33if (nargin < 2)
34        prwarning(3,'C set to one');
35        C = 1;
36end
37if (nargin < 1) | (isempty(x))
38        w = mapping(mfilename,{C,ftype,include_bias});
39        w = setname(w,'LESS classifier');
40        return
41end
42
43       
44if ~ismapping(C)   % train the mapping
45
46        % Unpack the dataset.
47        islabtype(x,'crisp');
48        isvaldset(x,1,2); % at least 1 object per class, 2 classes
49        [m,k,c] = getsize(x);
50
51        if c == 2   % two-class classifier
52
53                % get -1/+1 labels:
54                nlab = getnlab(x);
55                y = 2*nlab-3;
56
57                % train and apply the feature mapping:
58                par = lessfx(ftype,x);
59                f = lessfx(par,x);
60
61                if (include_bias)
62                        f = [f ones(m,1)];
63                        k = k+1;
64                end
65
66                % In the LP formulation, we define the free parameter vector as:
67                %  [delta; w]
68                % setup the constraints:
69                yf = -repmat(y,1,k).*f;
70                % standard version when we have Ax<b  and  Aeq x = b;
71                A = [-eye(m)    -(+yf)];
72                b = -ones(m,1);
73                %Aeq = [zeros(1,m) ones(1,k)]; beq = 1;
74                %if (include_bias), Aeq(1,end)=0; end
75                Aeq = []; beq = [];
76                % function to optimize:
77                c = [repmat(C,1,m) ones(1,k)];
78                %c = [ones(1,m) repmat(C,1,k)];
79                if (include_bias), c(end) = 0; end
80                % upper and lower bounds:
81                lb = zeros(m+k,1);
82                if (include_bias), lb(end) = -inf; end
83                ub = repmat(inf,m+k,1);
84
85                % optimize
86                if (exist('glpkmex')==3)
87                        [out,dummy]=glpkmex(1,c',A,b,repmat('U',m,1),lb,[],repmat('C',m+k,1));
88                else
89                        out = linprog(c,A,b,Aeq,beq,lb,ub);
90                end
91                w = out((m+1):end);
92
93                % find out how many features are relevant:
94                if (include_bias)
[104]95                        I = find(abs(w(1:(end-1)))>1e-8);
[5]96                        nr = length(I);
97                else
[104]98                        I = find(abs(w)>1e-8);
[5]99                        nr = length(I);
100                end
101
102                % Store the classifier
103                W.extend = include_bias;
104                W.par = par;
105                W.w = w;
106                W.nr = nr;
107                w = mapping(mfilename,'trained',W,getlablist(x),size(x,2),2);
108                w = setname(w,'LESS classifier');
109       
110        else   % multi-class classifier:
111               
112                %error('Multiclass not implemented yet');
113                w = mclassc(x,mapping(mfilename,{C,ftype,include_bias}));
114                v = w.data{1}.data{1}.data.w;
115                for i=2:length(w.data)
116                        v = v + w.data{i}.data{1}.data.w;
117                end
118                I = find(abs(v)>0);
119               
120        end     
121else
122        % Evaluate the classifier on new data:
123        W = getdata(C);
124
125        % It is a simple linear classifier:
126        if (W.extend)
127                out = [lessfx(W.par,x) ones(size(x,1),1)]*W.w;
128        else
129                out = lessfx(W.par,x)*W.w;
130        end
131
132        % and put it nicely in a prtools dataset:
133        w = setdat(x,sigm([out -out]),C);
134
135end
136               
137return
Note: See TracBrowser for help on using the repository browser.