function [w,gamma,trainCorr, testCorr, cpu_time, nu]=lpsvm(A,d,k,nu,output,delta)
% version 1.3
% last revision: 07/07/03
%===========================================================================
% Usage: [w, gamma,trainCorr,testCorr,time,nu]=lpsvm(A,d,k,nu,output,delta);
%
% A and d are both required; every other argument has a default and may
% also be passed as [] to request that default.
% An example: [w gamma train test time nu] = lpsvm(A,d,10);
%
% Input parameters:
% A: data points, one per row (m x n)
% d: class labels, a column vector of 1's and -1's (m x 1)
% k: how to divide the data set into training and test sets
%    if k = 0: simply run the algorithm without any correctness
%              calculation; this is the default
%    if k = 1: run the algorithm and calculate correctness on
%              the whole data set
%    if 1 < k < m (m = # of rows in the data set): k-fold
%              cross-validation
%    if k = m: leave-one-out
%    Any other k (negative, non-integer, or larger than m) is an error.
%    The rows are randomly permuted before training in every mode.
% nu: weight parameter
%     -1 - easy estimation (EstNuShort)
%      0 - hard estimation (EstNuLong); this is the default
%     any other value - used as nu by the algorithm
%     An estimate is computed once from all of A and d, before the data
%     are split into folds.
% output: 0 - no output (default), 1 - print results
% delta: constant added to the diagonal of the Newton system,
%        default is 10^-3
%===================================================================
% Output parameters:
%
% w: the normal vector of the classifier (from the last fold when k > 1)
% gamma: the threshold (from the last fold when k > 1); a point x
%        is classified as sign(x*w - gamma)
% trainCorr: training set correctness in percent
%            (0 when k = 0, average over folds when k > 1)
% testCorr: test set correctness in percent
%           (0 when k <= 1, average over folds when k > 1)
% cpu_time: elapsed time in seconds (when k > 1: average training
%           time per fold)
% nu: estimated value (or specified value) of nu
%==========================================================================

if nargin < 6 || isempty(delta)
    delta = 1e-3;
end

if nargin < 5 || isempty(output)
    output = 0;
end

if nargin < 3 || isempty(k)
    k = 0;
end

m = size(A,1);
% k > m would create empty test folds (correctness 0/0 = NaN), and a
% negative or non-integer k is not a fold count: reject them up front.
if ~isscalar(k) || k < 0 || k ~= fix(k) || k > m
    error('lpsvm:invalidK', ...
          'k must be an integer between 0 and the number of rows of A (%d).', m);
end

% Estimate nu from the whole data set unless the caller supplied a value.
% || short-circuits, so nu is never read when it was not passed.
if nargin < 4 || isempty(nu) || nu == 0
    nu = EstNuLong(A,d);      % default is hard estimation
elseif nu == -1
    nu = EstNuShort(A,d);     % easy estimation
end

r = randperm(size(d,1)); d = d(r,:); A = A(r,:);   % random permutation of the rows

tic;

trainCorr = 0;
testCorr = 0;

% k == 0: train on everything, no correctness calculation
if k == 0
    [w, gamma, iter] = core(A,d,nu,delta);
    cpu_time = toc;
    if output == 1
        fprintf(1,'\nNumber of Iterations: %d',iter);
        fprintf(1,'\nElapse time: %10.2f\n\n',cpu_time);
    end
    return
end

% k == 1: train on everything and report training-set correctness only
if k == 1
    tic;
    [w, gamma, iter] = core(A,d,nu,delta);
    trainCorr = correctness(A,d,w,gamma);
    cpu_time = toc;
    if output == 1
        fprintf(1,'\nTraining set correctness: %3.2f%%',trainCorr);
        fprintf(1,'\nNumber of Iterations: %d',iter);
        fprintf(1,'\nElapse time: %10.2f\n\n',cpu_time);
    end
    return
end

% k-fold cross-validation (k == m gives leave-one-out)
accuIter = 0;
cpu_time = 0;
% last row number of each fold: fold i is rows indx(i)+1 .. indx(i+1)
indx = floor(m*(0:k)/k);
for i = 1:k
    % split the training set from the test set
    testRows  = indx(i)+1 : indx(i+1);
    trainRows = [1:indx(i), indx(i+1)+1:m];
    Ctest  = A(testRows,:);   dtest  = d(testRows);
    Ctrain = A(trainRows,:);  dtrain = d(trainRows);

    tic;
    [w, gamma, iter] = core(Ctrain,dtrain,nu,delta);
    thisToc = toc;   % only the training step is timed

    tmpTrainCorr = correctness(Ctrain,dtrain,w,gamma);
    tmpTestCorr  = correctness(Ctest,dtest,w,gamma);

    if output == 1
        fprintf(1,'________________________________________________\n');
        fprintf(1,'Fold %d\n',i);
        fprintf(1,'Training set correctness: %3.2f%%\n',tmpTrainCorr);
        fprintf(1,'Testing set correctness: %3.2f%%\n',tmpTestCorr);
        fprintf(1,'Number of iterations: %d\n',iter);
        fprintf(1,'Elapse time: %10.2f\n',thisToc);
    end

    trainCorr = trainCorr + tmpTrainCorr;
    testCorr  = testCorr + tmpTestCorr;
    accuIter  = accuIter + iter;      % accumulated iterations
    cpu_time  = cpu_time + thisToc;
end % end of for (looping through test sets)

trainCorr = trainCorr/k;
testCorr = testCorr/k;
cpu_time = cpu_time/k;

if output == 1
    fprintf(1,'==============================================');
    fprintf(1,'\nTraining set correctness: %3.2f%%',trainCorr);
    fprintf(1,'\nTesting set correctness: %3.2f%%',testCorr);
    % accuIter/k is generally not an integer; %d would print it in e-notation
    fprintf(1,'\nAverage number of iterations: %.2f',accuIter/k);
    fprintf(1,'\nAverage cpu_time: %10.2f\n',cpu_time);
end

return; % lpsvm function return

%%%%%%%%%%% core calculation function %%%%%%%%%%%%%%%%%%%%%
function [w, gamma, iter] = core(A,d,nu,delta)
% Train a single classifier on (A,d) by choosing a Newton solver from the
% shape of A:
%   fewer rows than columns -> lpsvm_without_smw (solves an m x m system)
%   otherwise               -> lpsvm_with_smw (Sherman-Morrison-Woodbury,
%                              solves an (n+1) x (n+1) system)
if size(A,1) < size(A,2)
    [w,gamma,iter] = lpsvm_without_smw(A,d,nu,delta);
else
    [w,gamma,iter] = lpsvm_with_smw(A,d,nu,delta);
end
return

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% lpsvm when m>=n %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [w,gamma,iter]=lpsvm_with_smw(A,d,nu,delta)
% Newton iteration for the LP-SVM in the variable u (one entry per row of
% A). The Sherman-Morrison-Woodbury identity is used so that each step
% solves an (n+1) x (n+1) system instead of an m x m one, which is the
% cheaper choice when m >= n. No Armijo line search: the full Newton step
% is always taken.
%
% Inputs:  A (m x n) data, d (m x 1) labels +-1, nu scalar weight,
%          delta constant added to the diagonal of the generalized Hessian.
% Outputs: w (n x 1) normal vector, gamma scalar threshold,
%          iter number of Newton iterations performed (at most maxiter).

% parameters
epsi=10^(-3);alpha=10^(2);tol=10^(-5);maxiter=100;
[m,n] = size(A);
en = ones(n,1);
em = ones(m,1);

u = ones(m,1);     % initial u
iter = 0;
epsi = epsi*em; nu = nu*em;
gradNorm = 1;      % named to avoid shadowing the builtin diff
DA = spdiags(d,0,m,m)*A;
while (gradNorm > tol) && (iter < maxiter)
    iter = iter + 1;
    du = d.*u; Adu = A'*du;
    pp = max(Adu-en,0); np = max(-Adu-en,0);
    dd = sum(du)*d; unu = max(u-nu,0); uu = max(-u,0);
    % gradient
    g = -epsi + (d.*(A*pp)) - (d.*(A*np)) + dd + unu - alpha*uu;
    % generalized Hessian is H*H' + diag(1./f); its inverse is applied via
    % SMW: inv(H*H' + inv(F)) = F - F*H*inv(I + H'*F*H)*H'*F
    E = spdiags(sqrt(sign(np)+sign(pp)),0,n,n);
    H = [DA*E d];
    f = 1./(delta+sign(unu)+alpha*sign(uu));
    F = spdiags(f,0,m,m); gg = f.*g; HT = H';
    di = (eye(n+1)+HT*(F*H))\(HT*gg);
    di = H*di; di = f.*di; di = -gg+di;
    u = u + di;
    gradNorm = norm(g);   % gradient norm at the iterate before this step
end

% Recover the classifier from the final iterate u. Inside the loop pp/np/du
% still describe the iterate before the last step; recomputing here matches
% lpsvm_without_smw and is also safe if the loop body never ran.
du = d.*u; Adu = A'*du;
pp = max(Adu-en,0); np = max(-Adu-en,0);
w = 1/epsi(1)*(pp-np);
gamma = -(1/epsi(1))*sum(du);
return

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% lpsvm when m<n %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [w,gamma,iter]=lpsvm_without_smw(A,d,nu,delta)
% Newton iteration for the LP-SVM in the variable u (one entry per row of
% A) that solves the m x m Newton system directly, which is the smaller
% system when m < n. No Sherman-Morrison-Woodbury and no Armijo line
% search: the full Newton step is always taken.
%
% Inputs:  A (m x n) data, d (m x 1) labels +-1, nu scalar weight,
%          delta constant added to the diagonal of the generalized Hessian.
% Outputs: w (n x 1) normal vector, gamma scalar threshold,
%          iter number of Newton iterations performed (at most maxiter).

% parameters
epsi=10^(-1);alpha=10^3;tol=10^(-3);maxiter=50;
[m,n] = size(A);
en = ones(n,1);
em = ones(m,1);

u = ones(m,1);     % initial u
iter = 0;
epsi = epsi*em; nu = nu*em;
gradNorm = 1;      % named to avoid shadowing the builtin diff
DA = spdiags(d,0,m,m)*A;
while (gradNorm > tol) && (iter < maxiter)
    iter = iter + 1;
    du = d.*u; Adu = A'*du;
    pp = max(Adu-en,0); np = max(-Adu-en,0);
    dd = sum(du)*d; unu = max(u-nu,0); uu = max(-u,0);
    % gradient
    g = -epsi + (d.*(A*pp)) - (d.*(A*np)) + dd + unu - alpha*uu;
    % generalized Hessian: H*H' + diag(delta + sign(unu) + alpha*sign(uu))
    E = spdiags(sqrt(sign(np)+sign(pp)),0,n,n);
    H = [DA*E d];
    F = spdiags(delta+sign(unu)+alpha*sign(uu),0,m,m);
    di = -((H*H'+F)\g);
    u = u + di;
    gradNorm = norm(g);   % gradient norm at the iterate before this step
end

% Recover the classifier from the final iterate u.
du = d.*u; Adu = A'*du;
pp = max(Adu-en,0); np = max(-Adu-en,0);
w = 1/epsi(1)*(pp-np); gamma = -(1/epsi(1))*sum(du);
return

%%%%%%%%%%%%%%%% correctness calculation %%%%%%%%%%%%%%%%

function corr = correctness(AA,dd,w,gamma)
% Percentage of rows of AA whose predicted label sign(AA*w - gamma)
% equals the corresponding entry of dd.
% A row lying exactly on the plane (AA*w == gamma) gets sign 0, which never
% matches a +-1 label, so it counts as misclassified. An empty AA gives
% 0/0 = NaN.
p = sign(AA*w-gamma);
corr = nnz(p==dd)/size(AA,1)*100;
return

%%%%%%%%%%%%%%EstNuLong%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% hard way to estimate nu if not specified by the user
function value = EstNuLong(C,d)
% Fixed-point estimate of nu computed from the eigen-decomposition of
% H*H', where H = [C -e]. When C has more than 200 rows, a random sample
% of 200 rows (and the matching labels) is used instead.
% Returns 1 if the iteration does not converge within maxcnt steps, or if
% it ends on a non-finite or non-positive value.

m = size(C,1); e = ones(m,1);
H = [C -e];
if m < 201
    H2 = H; d2 = d;
else
    r = rand(m,1);
    [sorted, perm] = sort(r);     % perm is a random permutation of 1:m
    H2 = H(perm(1:200),:);
    d2 = d(perm(1:200));
end

tol = 10e-4;      % note: 10e-4 == 1e-3
maxcnt = 100;

lamda = 1;
[vu,u] = eig(H2*H2'); u = diag(u);
yt = (d2'*vu)';   % d2 expressed in the eigenvector basis, as a column
lamdaO = lamda + 1;

cnt = 0;
while (abs(lamdaO-lamda) > tol) && (cnt < maxcnt)
    cnt = cnt + 1;
    lamdaO = lamda;
    s = u + lamda;
    nu1 = sum(lamda./s);
    pr  = sum(u./s.^2);
    ee  = sum(u.*yt.^2./s.^3);
    waw = sum(lamda^2*yt.^2./s.^2);
    lamda = nu1*ee/(pr*waw);
end

% Decide on convergence itself rather than on cnt == maxcnt, so a value that
% converges exactly on the last allowed step is kept. A NaN lamda (e.g. 0/0
% on degenerate data) also fails this test.
converged = abs(lamdaO-lamda) <= tol;
if converged && isfinite(lamda) && lamda > 0
    value = lamda;
else
    value = 1;    % fall back to the default weight
end

return

%%%%%%%%%%%%%%%%%EstNuShort%%%%%%%%%%%%%%%%%%%%%%%

% easy way to estimate nu if not specified by the user
function value = EstNuShort(C,d)
% Cheap estimate of nu: the reciprocal of the mean squared column norm of
% C (sum of all squared entries divided by the number of columns).
% The labels d are accepted for interface symmetry but not used.
numCols = size(C,2);
meanSquaredColumnNorm = sum(sum(C.^2)) / numCols;
value = 1 / meanSquaredColumnNorm;
return