% source: prextra/lpsvm.m @ 116
%
% Last change to this file as of revision 116 was revision 5, checked in by bduin, 14 years ago.
% File size: 7.2 KB
% (original listing columns: Rev | Line)
function [w,gamma,trainCorr, testCorr, cpu_time, nu]=lpsvm(A,d,k,nu,output,delta)
% LPSVM  Linear-programming SVM solved by a Newton method, with optional
%        k-fold cross-validation.
% version 1.3
% last revision: 07/07/03
%===========================================================================
% Usage: [w, gamma,trainCorr,testCorr,cpu_time,nu]=lpsvm(A,d,k,nu,output,delta);
%
% A and d are both required, everything else has a default. Passing []
% for k, nu, output or delta also selects that argument's default.
% An example: [w gamma train test time nu] = lpsvm(A,d,10);
%
% Input parameters:
%    A: data points, one per row
%    d: labels, a column vector of 1's and -1's (one per row of A)
%    k: how to divide the data set into training and test sets
%       k = 0: simply run the algorithm without any correctness
%              calculation (default)
%       k = 1: run the algorithm and calculate correctness on the whole
%              data set (reported as trainCorr)
%       1 < k < # of rows: k-fold cross-validation on a random
%              permutation of the rows
%       k = # of rows: leave-one-out; a larger k is treated as # of rows
%    nu:     weight parameter
%            -1 - easy estimation (EstNuShort)
%             0 - hard estimation (EstNuLong), the default
%            any other value - used as nu by the algorithm
%    output: 0 - no output (default), 1 - print results to the console
%    delta:  term added to the diagonal of the Newton system,
%            default 1e-3
%===================================================================
% Output parameters:
%       w:          the normal vector of the classifier
%       gamma:      the threshold; a row x is labelled sign(x*w - gamma)
%       trainCorr:  training set correctness in percent (0 when k = 0;
%                   average over the folds when k > 1)
%       testCorr:   test set correctness in percent (0 when k <= 1;
%                   average over the folds when k > 1)
%       cpu_time:   elapsed time in seconds (k > 1: average training
%                   time per fold)
%       nu:         estimated value (or specified value) of nu
%
% When k > 1, w and gamma are the ones trained on the last fold.
%==========================================================================

if nargin < 6 || isempty(delta)
    delta = 1e-3;
end

if nargin < 5 || isempty(output)
    output = 0;
end

% nu is estimated once, from the whole data set, before any split.
if nargin < 4 || isempty(nu) || nu == 0
    nu = EstNuLong(A,d);    % default is hard estimation
elseif nu == -1             % easy estimation
    nu = EstNuShort(A,d);
end

if nargin < 3 || isempty(k)
    k = 0;
end

r = randperm(size(d,1)); d = d(r,:); A = A(r,:);    % random permutation

sm = size(A,1);
% More folds than rows would leave some test folds empty (0/0 = NaN
% correctness); leave-one-out is the finest possible split.
if k > sm
    k = sm;
end

tic;

trainCorr = 0;
testCorr = 0;

if k == 0
    [w, gamma, iter] = core(A,d,nu,delta);
    cpu_time = toc;
    if output == 1
        fprintf(1,'\nNumber of Iterations: %d',iter);
        fprintf(1,'\nElapse time: %10.2f\n\n',cpu_time);
    end
    return
end

% k == 1: train on everything, only training set correctness is calculated
if k == 1
    tic;
    [w, gamma, iter] = core(A,d,nu,delta);
    trainCorr = correctness(A,d,w,gamma);
    cpu_time = toc;
    if output == 1
        fprintf(1,'\nTraining set correctness: %3.2f%%',trainCorr);
        fprintf(1,'\nNumber of Iterations: %d',iter);
        fprintf(1,'\nElapse time: %10.2f\n\n',cpu_time);
    end
    return
end

% k-fold cross-validation
accuIter = 0;     % accumulated iterations over all folds
cpu_time = 0;
indx = floor(sm*(0:k)/k);    % last row numbers for all 'segments'
for i = 1:k
    % split training set from test set: segment i is the test set
    testRows  = indx(i)+1 : indx(i+1);
    trainRows = [1:indx(i), indx(i+1)+1:sm];
    Ctest  = A(testRows,:);  dtest  = d(testRows,:);
    Ctrain = A(trainRows,:); dtrain = d(trainRows,:);

    tic;
    [w, gamma, iter] = core(Ctrain,dtrain,nu,delta);
    thisToc = toc;

    tmpTrainCorr = correctness(Ctrain,dtrain,w,gamma);
    tmpTestCorr = correctness(Ctest,dtest,w,gamma);

    if output == 1
        fprintf(1,'________________________________________________\n');
        fprintf(1,'Fold %d\n',i);
        fprintf(1,'Training set correctness: %3.2f%%\n',tmpTrainCorr);
        fprintf(1,'Testing set correctness: %3.2f%%\n',tmpTestCorr);
        fprintf(1,'Number of iterations: %d\n',iter);
        fprintf(1,'Elapse time: %10.2f\n',thisToc);
    end

    trainCorr = trainCorr + tmpTrainCorr;
    testCorr = testCorr + tmpTestCorr;
    accuIter = accuIter + iter;
    cpu_time = cpu_time + thisToc;
end % end of for (looping through test sets)

trainCorr = trainCorr/k;
testCorr = testCorr/k;
cpu_time = cpu_time/k;

if output == 1
    fprintf(1,'==============================================');
    fprintf(1,'\nTraining set correctness: %3.2f%%',trainCorr);
    fprintf(1,'\nTesting set correctness: %3.2f%%',testCorr);
    % the average is usually not an integer; %d would fall back to %e
    fprintf(1,'\nAverage number of iterations: %.2f',accuIter/k);
    fprintf(1,'\nAverage cpu_time: %10.2f\n',cpu_time);
end

return;  % lpsvm function return
%%%%%%%%%%% core calculation function %%%%%%%%%%%%%%%%%%%%%
function [w, gamma, iter] = core(A,d,nu,delta)
% CORE  Train on (A, d) with the Newton solver suited to the shape of A.
%   Wide data (fewer rows than columns) goes to lpsvm_without_smw, which
%   solves an m-by-m system per step; all other data goes to
%   lpsvm_with_smw, which solves an (n+1)-by-(n+1) system per step.

if size(A,1) < size(A,2)
    [w, gamma, iter] = lpsvm_without_smw(A, d, nu, delta);
else
    [w, gamma, iter] = lpsvm_with_smw(A, d, nu, delta);
end

return
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%   lpsvm when m>=n                                   %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [w,gamma,iter]=lpsvm_with_smw(A,d,nu,delta)
% LPSVM_WITH_SMW  Newton iteration used when A has at least as many rows
%   as columns. Each step solves an (n+1)-by-(n+1) system
%   (I + H'*F*H) instead of an m-by-m one (Sherman-Morrison-Woodbury
%   form); full steps are taken, there is no Armijo line search.
%   The loop stops once the gradient norm evaluated at the start of an
%   iteration is <= 1e-5, or after 100 iterations. w and gamma are built
%   from the quantities evaluated at the start of that final iteration.

% solver constants
epsi    = 10^(-3);   % w and gamma are scaled by 1/epsi
alpha   = 10^(2);    % weight on the max(-u,0) term
tol     = 10^(-5);   % tolerance on the gradient norm
maxiter = 100;

[m, n] = size(A);
onesN  = ones(n, 1);
DA     = spdiags(d, 0, m, m) * A;   % rows of A multiplied by their labels
u      = ones(m, 1);                % starting point

for iter = 1:maxiter
    du    = d .* u;
    Adu   = A' * du;
    pos   = max(Adu - onesN, 0);
    neg   = max(-Adu - onesN, 0);
    dd    = sum(du) * d;
    over  = max(u - nu, 0);    % part of u above nu
    under = max(-u, 0);        % part of u below zero

    % gradient
    g = -epsi + (d .* (A*pos)) - (d .* (A*neg)) + dd + over - alpha*under;

    % Newton step through the small (n+1)-by-(n+1) system
    E    = spdiags(sqrt(sign(neg) + sign(pos)), 0, n, n);
    H    = [DA*E d];
    f    = 1 ./ (delta + sign(over) + alpha*sign(under));
    F    = spdiags(f, 0, m, m);
    gg   = f .* g;
    Ht   = H';
    step = (eye(n+1) + Ht*(F*H)) \ (Ht*gg);
    step = H * step;
    step = f .* step;
    u    = u + (-gg + step);

    % stop unless the gradient norm is still above tol (a NaN norm stops too)
    if ~(norm(g) > tol)
        break
    end
end

w     = 1/epsi * (pos - neg);
gamma = -(1/epsi) * sum(du);

return
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%   lpsvm when m<n                                    %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [w,gamma,iter]=lpsvm_without_smw(A,d,nu,delta)
% LPSVM_WITHOUT_SMW  Newton iteration used when A has fewer rows than
%   columns. Each step solves the m-by-m system (H*H' + F) directly,
%   without the Sherman-Morrison-Woodbury rewrite and without an Armijo
%   line search. The loop stops once the gradient norm evaluated at the
%   start of an iteration is <= 1e-3, or after 50 iterations; w and gamma
%   are then recomputed from the final iterate u.

epsi    = 10^(-1);   % w and gamma are scaled by 1/epsi
alpha   = 10^3;      % weight on the max(-u,0) term
tol     = 10^(-3);   % tolerance on the gradient norm
maxiter = 50;

[m, n] = size(A);
onesN  = ones(n, 1);
DA     = spdiags(d, 0, m, m) * A;   % rows of A multiplied by their labels

u        = ones(m, 1);   % starting point
iter     = 0;
gradNorm = 1;
while gradNorm > tol && iter < maxiter
    iter = iter + 1;

    du    = d .* u;
    Adu   = A' * du;
    pos   = max(Adu - onesN, 0);
    neg   = max(-Adu - onesN, 0);
    over  = max(u - nu, 0);    % part of u above nu
    under = max(-u, 0);        % part of u below zero

    % gradient
    g = -epsi + (d .* (A*pos)) - (d .* (A*neg)) + sum(du)*d + over - alpha*under;

    % Newton step through the full m-by-m system
    E = spdiags(sqrt(sign(neg) + sign(pos)), 0, n, n);
    H = [DA*E d];
    F = spdiags(delta + sign(over) + alpha*sign(under), 0, m, m);
    u = u - ((H*H' + F) \ g);

    gradNorm = norm(g);
end

% rebuild w and gamma from the final u
du    = d .* u;
Adu   = A' * du;
w     = 1/epsi * (max(Adu - onesN, 0) - max(-Adu - onesN, 0));
gamma = -(1/epsi) * sum(du);

return
%%%%%%%%%%%%%%%% correctness calculation %%%%%%%%%%%%%%%%

function corr = correctness(AA,dd,w,gamma)
% CORRECTNESS  Percentage (0-100) of rows of AA whose predicted label
%   sign(AA*w - gamma) equals the corresponding entry of dd.
%   A row lying exactly on the plane (sign 0) never matches a +/-1 label.

predicted = sign(AA*w - gamma);
nCorrect  = nnz(predicted == dd);
corr      = nCorrect / size(AA,1) * 100;

return
%%%%%%%%%%%%%%EstNuLong%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% hard way to estimate nu if not specified by the user
function value = EstNuLong(C,d)
% ESTNULONG  Iterative ("hard") estimate of nu.
%   Forms H = [C -e] (e a column of ones). It keeps all rows when there
%   are at most 200 of them and otherwise a random subset of 200 rows,
%   then eigen-decomposes H2*H2'. Starting from lamda = 1, it repeats
%       lamda <- nu1*ee/(pr*waw)
%   (sums over the eigenvalues, see the loop) until two successive
%   values differ by at most 10e-4, for at most 100 updates.
%   Returns 1 when the iteration does not stop within 100 updates, or
%   when the result is not a finite positive number (e.g. 0/0 = NaN when
%   every label is zero).
%
%   C: data points, one per row; d: column vector of labels.

maxSample = 200;    % largest number of rows passed to eig
maxIter   = 100;    % largest number of lamda updates
tolLam    = 10e-4;  % = 1e-3; NOTE(review): may have been meant as 1e-4

m = size(C,1);
H = [C -ones(m,1)];
if m <= maxSample
    H2 = H; d2 = d;
else
    % random row subset; kept as rand + sort so that a given random state
    % selects the same rows as before
    [sortedR, order] = sort(rand(m,1));
    H2 = H(order(1:maxSample),:);
    d2 = d(order(1:maxSample));
end

[vecs, evMat] = eig(H2*H2');
ev  = diag(evMat);          % eigenvalues of H2*H2' (column)
yt2 = (vecs' * d2).^2;      % squared coordinates of d2 in the eigenbasis

lamda = 1;
stopped = false;
for cnt = 1:maxIter
    lamdaOld = lamda;
    denom = ev + lamda;
    nu1 = sum(lamda ./ denom);
    pr  = sum(ev ./ denom.^2);
    ee  = sum(ev .* yt2 ./ denom.^3);
    waw = sum(lamda^2 * yt2 ./ denom.^2);
    lamda = nu1*ee/(pr*waw);
    % written as ~(x > tol) so that a NaN update also ends the loop
    if ~(abs(lamdaOld - lamda) > tolLam)
        stopped = true;
        break
    end
end

if stopped && isfinite(lamda) && lamda > 0
    value = lamda;
else
    value = 1;   % fallback when the iteration gives no usable estimate
end

return
%%%%%%%%%%%%%%%%%EstNuShort%%%%%%%%%%%%%%%%%%%%%%%

% easy way to estimate nu if not specified by the user
function value = EstNuShort(C,d)
% ESTNUSHORT  Closed-form ("easy") estimate of nu: the reciprocal of
%   (sum of squared entries of C) / (number of columns of C).
%   The label argument d is not used.

sqSum  = sum(sum(C.^2));     % sum of squared entries of C
nCols  = size(C, 2);
meanSq = sqSum / nCols;
value  = 1 / meanSq;

return
% Note: See TracBrowser for help on using the repository browser.