source: prextra/auclpm.m @ 70

Last change on this file since 70 was 5, checked in by bduin, 14 years ago
File size: 6.1 KB
RevLine 
[5]1function [w,returnA,r2] = auclpm(x, C, rtype, par, unitnorm, usematlab)
2%AUCLPM Find linear mapping with optimized AUC
3%
4%    W = AUCLPM(X, C, RTYPE, PAR)
5%
6% Optimize the AUC on dataset X and reg. param. C. This is done by
7% finding the weights W for which the ordering of the objects mapped
8% onto the line defined by W, is optimal. That means that objects from
9% class +1 is always mapped above objects from the -1 class. This
10% results in a constraint for each (+1,-1) pair of objects. The number
11% of constraints therefore become very large.  The AUC constraints can
12% be subsampled in different ways:
13%
14%               RTYPE     PAR
15%         'full',     -   use all constraints
16%         'subs',     N   subsample just N constraints
17%         'subk',     k   subsample just k*#trainobj. constraints
18%         'knn'       k   use only the k nearest neighbors
19%         'xval'      N   subsample just N constraints and use the rest to
20%                 optimize C (this version can be very slow)
21%    'xvalk'     k   subsample k*#trainobj and use remaining constraints
22%                 to optimize C
23%    'kmeans'    k   use k-means clustering with k=PAR
24%    'randk'     subsample objects to get PAR*(Npos+Nneg) constraints
25%
26%    W = AUCLPM(X, C, RTYPE, PAR, UNITNORM)
27%
28% Finally, per default the difference vectors are normalized to unit
29% length. If you don't like that, set UNITNORM to 0.
30%
31% Default: RTYPE='subk'
32%          PAR  = 1.0;
33prtrace(mfilename);
34
35if (nargin < 6)
36        usematlab = 0;
37end
38if (nargin < 5)
39        unitnorm = 1;
40end
41if (nargin < 4)
42        par = 1.00;
43end
44if (nargin < 3)
45        rtype = 'subk';
46end
47if (nargin < 2)
48        prwarning(3,'Lambda set to ten');
49        C = 10;
50end
51
52% define the correct name:
53if unitnorm
54        cl_name = sprintf('AUC-LP %s (%s) 1norm',rtype,num2str(par));
55else
56        cl_name = sprintf('AUC-LP %s (%s)',rtype,num2str(par));
57end
58if (nargin < 1) | (isempty(x))
59        w = mapping(mfilename,{C,rtype,par,unitnorm,usematlab});
60        w = setname(w,cl_name);
61        return
62end
63
64if ~ismapping(C)   % train the mapping
65
66        % Unpack the dataset.
67        islabtype(x,'crisp');
68        isvaldset(x,1,2); % at least 1 object per class, 2 classes
69        [n,k,c] = getsize(x);
70        % Check some values:
71        if par<=0
72                error('Parameter ''par'' should be larger than zero');
73        end
74
75        if c == 2  % two-class classifier:
76
77                labl = getlablist(x); dim = size(x,2);
78                % first create the target values (+1 and -1):
79                % make an exception for a one-class or mil dataset:
80                tnr = strmatch('target',labl);
81                if isempty(tnr) % no target class defined.
82                        tnr = strmatch('positive',labl);
83                        if isempty(tnr) % no positive class defined.
84                                % we just take the first class as target class:
85                                tnr = 1;
86                        end
87                end
88                y = 2*(getnlab(x)==tnr) - 1;
89                tlab = labl(tnr,:);
90
91                % makes the mapping much faster:
92                X = +x; clear x;
93
94                %---create A for optauc
95        rstate = rand('state');
96                seed = 0;
97                [A,Nxi,Aval] = createA(X,y,rtype,par,seed);
98        rand('state',rstate);
99                if unitnorm
100                        % normalize the length of A:
101                        lA = sqrt(sum(A.*A,2));
102                        lenn0 = find(lA~=0);  % when labels are flipped, terrible
103                                              % things can happen
104                        A(lenn0,:) = A(lenn0,:)./repmat(lA(lenn0,:),1,size(A,2));
105                        if ~isempty(Aval)
106                                % also normalize the length of Aval:
107                                lA = sqrt(sum(Aval.*Aval,2));
108                                lenn0 = find(lA~=0);
109                                Aval(lenn0,:) = Aval(lenn0,:)./repmat(lA(lenn0,:),1,size(Aval,2));
110                        end
111                end
112%if nargout>1
113%       returnA = A;
114%end
115orgA = A;
116                % negative should be present for the constraints:
117                A = [A -A];
118                % take also care for the xi:
119                A = [A -speye(Nxi)];
120                %A = [A -eye(Nxi)];
121                %---create f
122                % NO, do this later, maybe we want to optimize it!
123                %f = [ones(2*k,1); repmat(C,Nxi,1)];
124                %--- generate b
125                b = -ones(Nxi,1);   % no zeros, otherwise we get w=0
126                    % the constraint is changed here to <=-1
127                %---lower bound constraints
128                lb = zeros(2*k+Nxi,1);
129
130                % should we run over a range of Cs?
131                if ~isempty(Aval)
132                        M = 25;
133xtr = zeros(M,1);
134                        xval = zeros(M,1);
135                        C = logspace(-3,3,M);
136         % run over all the Cs
137                        for i=1:length(C)
138                                %---create f again:
139                                f = [ones(2*k,1); repmat(C(i),Nxi,1)];
140                                %---solve linear program
141                                if (exist('glpk')>0) & ~usematlab
142                                        [z,fmin,status]=glpk(f,A,b,lb,[],repmat('U',Nxi,1),...
143                                                repmat('C',size(f,1),1),1);
144                                elseif (exist('glpkmex')>0) & ~usematlab
145                                        [z,fmin,status]=glpkmex(1,f,A,b,repmat('U',Nxi,1),...
146                                                lb,[],repmat('C',size(f,1),1));
147                                else
148                                        opts = optimset('Display','off','LargeScale','on','Diagnostics','off');
149                                        z = linprog(f,A,b,[],[],lb,[],[],opts);
150                                end
151                                constr = Aval*(z(1:k)-z(k+1:2*k));
152                                % the number of satisfied constraints (=AUC:)
153                                I = find(constr<-0);
154                                if ~isempty(I)
155                                        xval(i) = length(I)/size(constr,1);
156                                end
157constr = orgA*(z(1:k)-z(k+1:2*k));
158% the number of satisfied constraints (=AUC:)
159I = find(constr<-0);
160if ~isempty(I)
161        xtr(i) = length(I)/size(constr,1);
162end
163                        end
164if nargout>1
165        returnA = xval;
166        if nargout>2
167                r2 = xtr;
168        end
169end
170                        [minxval,mini] = max(xval);
171                        C = C(mini);
172                        message(4,'Optimum C = %f\n',C);
173                end
174                %---create f
175                f = [ones(2*k,1); repmat(C,Nxi,1)];
176                %---solve linear program
177                if (exist('glpkmex')>0) & ~usematlab
178                        prwarning(7,'Use glpkmex');
179                        param.msglev=0;
180                        [z,fmin,status,xtra]=glpkmex(1,f,A,b,repmat('U',Nxi,1),...
181                                lb,[],repmat('C',size(f,1),1),param);
182                        alpha = xtra.lambda;
183                else
184                        [z,fmin,exitflag,outp,alpha] = linprog(f,A,b,[],[],lb);
185                end
186
187                %---extract parameters
188                u = z(1:k); u = u(:);
189                v = z(k+1:2*k); v = v(:);
190                zeta = z(2*k+1:2*k+Nxi); zeta = zeta(:);
191%if nargout>1
192%       returnA = zeta;
193%end
194        else
195                error('Only a two-class classifier is implemented');
196        end
197        % now find out how sparse the result is:
198        %nr = sum(beta>1e-6);
199        rel = (abs(u-v)>0);
200        nr = sum(rel);
201       
202        % and store the results:
203        %W.wsc = wsc;
204        W.u = u; %the ultimate weights
205        W.v = v;
206        W.alpha = alpha;
207        W.zeta = zeta;
208        W.nr = nr;
209        W.rel = rel;
210        W.C = C;
211        w = mapping(mfilename,'trained',W,tlab,dim,1);
212        w = setname(w,sprintf('%s %s',cl_name,rtype));
213       
214else
215        % Evaluate the classifier on new data:
216        W = getdata(C);
217        n = size(x,1);
218
219        % linear classifier:
220        out = x*(W.u-W.v);
221
222        % and put it nicely in a prtools dataset:
223        % (I am not really sure what I should output, I decided to give a 1D
224        % output:)
225        w = setdat(x,out,C);
226        %w = setdat(x,[-out out],C);
227        %w = setdat(x,sigm([-out out]),C);
228
229end
230               
231return
Note: See TracBrowser for help on using the repository browser.