source: distools/private/auclpm.m @ 21

Last change on this file since 21 was 13, checked in by bduin, 14 years ago
File size: 6.4 KB
RevLine 
[13]1function [w,returnA,r2] = auclpm(x, C, rtype, par, unitnorm, usematlab)
2%AUCLPM Find linear mapping with optimized AUC
3%
4%    W = AUCLPM(X, C, RTYPE, PAR)
5%
6% Optimize the AUC on dataset X and reg. param. C. This is done by
7% finding the weights W for which the ordering of the objects mapped
8% onto the line defined by W, is optimal. That means that objects from
9% class +1 is always mapped above objects from the -1 class. This
10% results in a constraint for each (+1,-1) pair of objects. The number
11% of constraints therefore become very large.  The AUC constraints can
12% be subsampled in different ways:
13%
14%               RTYPE     PAR
15%         'full',     -   use all constraints
16%         'subs',     N   subsample just N constraints
17%         'subk',     k   subsample just k*#trainobj. constraints
18%         'knn'       k   use only the k nearest neighbors
19%         'xval'      N   subsample just N constraints and use the rest to
20%                 optimize C (this version can be very slow)
21%    'xvalk'     k   subsample k*#trainobj and use remaining constraints
22%                 to optimize C
23%    'kmeans'    k   use k-means clustering with k=PAR
24%    'randk'     subsample objects to get PAR*(Npos+Nneg) constraints
25%
26%    W = AUCLPM(X, C, RTYPE, PAR, UNITNORM)
27%
28% Finally, per default the difference vectors are normalized to unit
29% length. If you don't like that, set UNITNORM to 0.
30%
31% Default: RTYPE='subk'
32%          PAR  = 1.0;
33prtrace(mfilename);
34
35if (nargin < 6)
36        usematlab = 0;
37end
38if (nargin < 5)
39        unitnorm = 0;
40end
41if (nargin < 4)
42        par = 1.00;
43end
44if (nargin < 3)
45        rtype = 'subk';
46end
47if (nargin < 2)
48        prwarning(3,'Lambda set to ten');
49        C = 10;
50end
51
52if (nargin < 1) | (isempty(x))
53        % just a check...
54        if ~isa(C,'double')
55                error('Please check your input parameters: C should be a double.');
56        end
57        w = mapping(mfilename,{C,rtype,par,unitnorm,usematlab});
58        w = setname(w,defauclpmname(C,rtype,par,unitnorm));
59        return
60end
61
62if ~ismapping(C)   % train the mapping
63
64        % just a check...
65        if ~isa(C,'double')
66                error('Please check your input parameters: C should be a double.');
67        end
68        % Unpack the dataset.
69        islabtype(x,'crisp');
70        isvaldset(x,1,2); % at least 1 object per class, 2 classes
71        [n,k,c] = getsize(x);
72        % Check some values:
73        if par<=0
74                error('Parameter ''par'' should be larger than zero');
75        end
76
77        if c == 2  % two-class classifier:
78
79                labl = getlablist(x); dim = size(x,2);
80                % first create the target values (+1 and -1):
81                % make an exception for a one-class or mil dataset:
82                tnr = strmatch('target',labl);
83                if isempty(tnr) % no target class defined.
84                        tnr = strmatch('positive',labl);
85                        if isempty(tnr) % no positive class defined.
86                                % we just take the first class as target class:
87                                tnr = 1;
88                        end
89                end
90                y = 2*(getnlab(x)==tnr) - 1;
91                tlab = labl(tnr,:);
92
93                % makes the mapping much faster:
94                X = +x; clear x;
95
96                %---create A for optauc
97        rstate = rand('state');
98                seed = 0;
99                [A,Nxi,Aval] = createA(X,y,rtype,par,seed);
100        rand('state',rstate);
101                if unitnorm
102                        % normalize the length of A:
103                        lA = sqrt(sum(A.*A,2));
104                        lenn0 = find(lA~=0);  % when labels are flipped, terrible
105                                              % things can happen
106                        A(lenn0,:) = A(lenn0,:)./repmat(lA(lenn0,:),1,size(A,2));
107                        if ~isempty(Aval)
108                                % also normalize the length of Aval:
109                                lA = sqrt(sum(Aval.*Aval,2));
110                                lenn0 = find(lA~=0);
111                                Aval(lenn0,:) = Aval(lenn0,:)./repmat(lA(lenn0,:),1,size(Aval,2));
112                        end
113                end
114orgA = A;
115                % negative should be present for the constraints:
116                A = [A -A];
117                % take also care for the xi:
118                A = [A -speye(Nxi)];
119                %A = [A -eye(Nxi)];
120                %---create f
121                % NO, do this later, maybe we want to optimize it!
122                %f = [ones(2*k,1); repmat(C,Nxi,1)];
123                %--- generate b
124                b = -ones(Nxi,1);   % no zeros, otherwise we get w=0
125                    % the constraint is changed here to <=-1
126                %---lower bound constraints
127                lb = zeros(2*k+Nxi,1);
128
129                % should we run over a range of Cs?
130                if ~isempty(Aval)
131                        M = 25;
132                        xtr = zeros(M,1);
133                        xval = zeros(M,1);
134                        C = logspace(-3,3,M);
135         % run over all the Cs
136                        for i=1:length(C)
137                                %---create f again:
138                                f = [ones(2*k,1); repmat(C(i),Nxi,1)];
139                                %---solve linear program
140                                if (exist('glpk')>0) & ~usematlab
141                                        [z,fmin,status]=glpk(f,A,b,lb,[],repmat('U',Nxi,1),...
142                                                repmat('C',size(f,1),1),1);
143                                elseif (exist('glpkmex')>0) & ~usematlab
144                                        [z,fmin,status]=glpkmex(1,f,A,b,repmat('U',Nxi,1),...
145                                                lb,[],repmat('C',size(f,1),1));
146                                else
147                                        opts = optimset('Display','off','LargeScale','on','Diagnostics','off');
148                                        z = linprog(f,A,b,[],[],lb,[],[],opts);
149                                end
150                                constr = Aval*(z(1:k)-z(k+1:2*k));
151                                % the number of satisfied constraints (=AUC:)
152                                I = find(constr<-0);
153                                if ~isempty(I)
154                                        xval(i) = length(I)/size(constr,1);
155                                end
156                                % the number of satisfied constraints (=AUC:)
157                                constr = orgA*(z(1:k)-z(k+1:2*k));
158                                I = find(constr<-0);
159                                if ~isempty(I)
160                                        xtr(i) = length(I)/size(constr,1);
161                                end
162                        end
163if nargout>1
164        returnA = xval;
165        if nargout>2
166                r2 = xtr;
167        end
168end
169                        [minxval,mini] = max(xval);
170                        C = C(mini);
171                end
172                %---create f
173                f = [ones(2*k,1); repmat(C,Nxi,1)];
174                %---solve linear program
175                if (exist('glpkmex')>0) & ~usematlab
176                        prwarning(7,'Use glpkmex');
177                        param.msglev=0;
178                        [z,fmin,status,xtra]=glpkmex(1,f,A,b,repmat('U',Nxi,1),...
179                                lb,[],repmat('C',size(f,1),1),param);
180                        alpha = []; %xtra.lambda;
181                else
182                        [z,fmin,exitflag,outp,alpha] = linprog(f,A,b,[],[],lb);
183                end
184
185                %---extract parameters
186                u = z(1:k); u = u(:);
187                v = z(k+1:2*k); v = v(:);
188                zeta = z(2*k+1:2*k+Nxi); zeta = zeta(:);
189        else
190                error('Only a two-class classifier is implemented');
191        end
192        % now find out how sparse the result is:
193        rel = (abs(u-v)>1e-6);
194        nr = sum(rel);
195        if (nr==0)
196                error('None of the features is selected. Please make the C a bit larger.');
197        end
198       
199        % and store the results:
200        W.u = u-v; %the ultimate weights
201        W.alpha = alpha;
202        W.zeta = zeta;
203        W.nr = nr;
204        W.rel = rel;
205        W.C = C;
206        w = mapping(mfilename,'trained',W,tlab,dim,1);
207        w = setname(w,defauclpmname(C,rtype,par,unitnorm));
208       
209else
210        % Evaluate the classifier on new data:
211        W = getdata(C);
212        n = size(x,1);
213
214        % linear classifier:
215        out = x*W.u;
216
217        % and put it nicely in a prtools dataset:
218        % (I am not really sure what I should output, I decided to give a 1D
219        % output:)
220        w = setdat(x,out,C);
221
222end
223               
224return
225
226function cl_name = defauclpmname(C,rtype,par,unitnorm)
227% define the correct name:
228if unitnorm
229        cl_name = sprintf('AUClpm (C=%s, %s, k=%s) 1norm',num2str(C),rtype,num2str(par));
230else
231        cl_name = sprintf('AUClpm (C=%s, %s, k=%s)',num2str(C),rtype,num2str(par));
232end
233
234return
Note: See TracBrowser for help on using the repository browser.