1 | /* Copyright (C) 1998 |
---|
2 | Berwin A Turlach <bturlach@stats.adelaide.edu.au> */ |
---|
3 | |
---|
4 | /* This library is free software; you can redistribute it and/or |
---|
5 | modify it under the terms of the GNU Library General Public License |
---|
6 | as published by the Free Software Foundation; either version 2 of |
---|
7 | the License, or (at your option) any later version. */ |
---|
8 | |
---|
9 | /* This library is distributed in the hope that it will be useful, but |
---|
10 | WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU |
---|
12 | Library General Public License for more details. */ |
---|
13 | |
---|
14 | /* You should have received a copy of the GNU Library General Public |
---|
15 | License along with this library; if not, write to the Free Software |
---|
16 | Foundation, Inc., 59 Temple Place, Suite 330, Boston, |
---|
17 | MA 02111-1307, USA. */ |
---|
18 | |
---|
19 | #include "lasso.h" |
---|
20 | //#include "fortify.h" |
---|
21 | |
---|
/* Base tolerance of the solver; lasso() works with sqrt(PRECISION) as its
   threshold for "zero" coefficients and for the L1-bound comparisons. */
#define PRECISION FLT_EPSILON

static void lasso_alloc(long n, long m);
static void lasso_free(void);
static void qr_init(int n);
static void qr_incr(void);
static void qr_free(void);
static void qr_del(int l, int aug);
static void qr_add(double *x, int swap);
#if defined (S_Plus)
static void errmsg(char* string);
#else
static void errmsg(char* where, char* string);
#endif

/* Number of columns by which the QR storage grows (reset by lasso_alloc). */
static int QR_CHUNK = 10;

/* Solver work space, allocated by lasso_alloc() and released by
   lasso_free().  ytyd2 holds y'y/2 once y has been factorised. */
static double *xtr=NULL, *btmp=NULL, *qtr=NULL,
  *rinvt_theta=NULL, *step=NULL, ytyd2=0.0;
/* nz_x: indices of the active predictors; theta: their signs (+1/-1);
   num_nz_x: size of the active set. */
static int *theta=NULL, *nz_x=NULL, num_nz_x=0;
/* QR factorisation: Q stored column by column (q_nrow entries per column),
   R stored as a packed upper triangle, column after column. */
static double *qmat=NULL, *rmat=NULL;
static int qr_max_size=0, r_ncol=0, q_nrow=0, q_use_row=0;
/* An array (not a char* to a string literal): it can be handed to
   errmsg(char*, ...) without the literal-to-char* conversion that is
   ill-formed in C++11 and later. */
static char no_dyn_mem_message[]="Cannot allocate dynamic memory";
---|
45 | |
---|
/* Solve the L1-constrained least-squares (lasso) problem
 *
 *     minimise ||y - X beta||^2 / 2   subject to   ||beta||_1 <= t
 *
 * with an active-set method.  The predictors currently allowed to vary
 * are nz_x[0..num_nz_x-1] (signs in theta[]), and a QR factorisation of
 * [X_active, y] is kept up to date with qr_add()/qr_del().
 * Each pass of the outer loop solves the problem restricted to the active
 * set (inner do/while) and then looks at max |X^T r|: if that predictor is
 * not active it is added, otherwise the solver stops.
 *
 * x          input: n-by-m design matrix, stored column by column
 * pn, pm     input: n (rows) and m (columns)
 * pt         input: bound t on the L1 norm of beta
 * beta       input/output: starting value / solution (length m)
 * y          input: response vector (length n)
 * yhat1      output: fitted values X beta (length n)
 * r          output: residuals y - X beta (length n)
 * lagrangian output: max |X^T r| at termination (the quantity multiplying t
 *            in the dual objective printed in verbose mode); not written
 *            on failure
 * psuc       output: 0 on success, -1 when no step length could be found
 * pverb      input: non-zero for progress output through mexPrintf
 * pas_sub    input: non-zero when called from mult_lasso(); the work space,
 *            QR factorisation, active set and theta are then inherited
 *            from the caller and are neither set up nor freed here.
 *
 * NOTE(review): assumes m >= 1 (xtr[0] is read unconditionally).
 */
void lasso(double *x,             /* input: data, n x m, column by column */
           long *pn,              /* input: rows */
           long *pm,              /* input: columns */
           double *pt,            /* input: bound on the L1 norm of beta */
           double *beta,          /* input/output: initial/final beta */
           double *y,             /* input: response ("labels") */

           double *yhat1,         /* output: fitted values */
           double *r,             /* output: residuals */
           double *lagrangian,    /* output: max |X^T r| at the solution */
           long *psuc,            /* output: succeed: 0/-1 */
           long *pverb,           /* input: verbose: yes/no */
           long *pas_sub)         /* input: called from mult_lasso: yes/no */
{

  double t = *pt, prec;
  long n = *pn, m = *pm, verb = *pverb, as_sub = *pas_sub;
  int not_solved;
  double *x_elem=NULL, tmp, max_val, b_1norm;
  int i, j, max_ind;
  double p_obj, d_obj;
  int num_iter=0, max_num;
  double b_1norm_old, mu;
  double rho_up, rho_low, rho;
  int to_del;
  double *q_elem, wtc, wtw;
  double *r_elem;
  int k;
  int add;
  if( !as_sub)
    lasso_alloc(n,m);

  prec = sqrt(PRECISION);

  if( as_sub ){
    /* active set inherited from the caller: only recompute its L1 norm */
    b_1norm = 0.0;
    for(j=0;j<num_nz_x;j++)
      b_1norm += fabs(beta[nz_x[j]]);
  }else{
    /* build the active set from the non-negligible starting coefficients
       and set the negligible ones to exactly zero */
    b_1norm = 0.0;
    num_nz_x = 0;
    for(j=0; j<m; j++){
      if(fabs(beta[j]) > prec){
        b_1norm += fabs(beta[j]);
        nz_x[num_nz_x] = j;
        num_nz_x++;
      }else
        beta[j] = 0.0;
    }
  }

  /* infeasible start: scale beta onto the surface of the L1 ball */
  if( b_1norm > t){
    if(verb){
      mexPrintf("******************************\n");
      mexPrintf("Rescaling beta from L1-norm %f to %f\n",b_1norm,t);
      mexEvalString("drawnow;");

    }
    for(j=0; j<num_nz_x; j++)
      beta[nz_x[j]] = beta[nz_x[j]] * t/b_1norm;
    b_1norm = t;
  }

  /* fitted values yhat1 = X beta, summing over the active columns only */
  for(i=0; i< n; i++)
    yhat1[i] = 0.0;
  for(j=0; j < num_nz_x; j++){
    /* x_elem points to the first element in the column of X to which
       the j-th entry in nz_x points */
    x_elem = x + nz_x[j]*n;
    tmp = beta[nz_x[j]];
    for(i=0; i < n; i++){
      yhat1[i] += *x_elem*tmp;
      x_elem++; /* now we point to the next element in X */
    }
  }

  /* calculate the residual vector */
  for(i=0; i< n; i++)
    r[i] = y[i]-yhat1[i];

  /* multiply X^T with the residual vector */
  x_elem = x;
  for(j=0; j < m; j++){
    tmp = 0.0;
    for(i=0; i<n; i++){
      tmp += *x_elem*r[i];
      x_elem++;
    }
    xtr[j] = tmp;
  }

  /* predictor with the largest |X^T r| (the first one on ties) */
  max_val = fabs(xtr[0]);
  max_ind = 0;
  for(j=1; j<m; j++)
    if( fabs(xtr[j]) > max_val ){
      max_val = fabs(xtr[j]);
      max_ind = j;
    }

  if( !as_sub ){
    /* factorise y first: right after this call *rmat is ||y||, hence
       ytyd2 = y'y/2.  The active columns are then added with swap=TRUE,
       which keeps y as the last column of the factorisation. */
    qr_add(y,TRUE);
    ytyd2 = *rmat * *rmat/2.0;
    for(j=0;j<num_nz_x;j++){
      qr_add(x+nz_x[j]*n, TRUE);
      /* sign of each active coefficient; for a (rescaled) coefficient
         that is numerically zero use the sign of X^T r instead */
      if(fabs(beta[nz_x[j]])<prec)
        theta[j] = xtr[nz_x[j]] < 0 ? -1 : 1;
      else
        theta[j] = beta[nz_x[j]] < 0 ? -1 : 1;

    }
  }
  /* empty active set: start with the predictor most correlated with r */
  if( num_nz_x==0 ){
    nz_x[0] = max_ind;
    num_nz_x = 1;
    if(verb){
      mexPrintf("******************************\n");
      mexPrintf(" -->\tAdding variable: %d\n",max_ind+1);
      mexEvalString("drawnow;");
    }
    qr_add(x+max_ind*n, TRUE);
    theta[0] = xtr[max_ind] < 0 ? -1 : 1;
  }
  *psuc=0;
  if(verb){


    /* Find out how many times [[max_val]] is attained */
    tmp = (1.0-prec)*max_val;
    if( tmp < prec ) tmp = 0.0;
    max_num = 1;
    for(j=0; j<m; j++){
      if( tmp <= fabs(xtr[j]) && j!=max_ind){
        /*we found another element equal to the (current)
          maximal absolute value */
        max_num++;
      }
    }

    /* for the value of the dual objective function we need
       to calculate the L2 norm of the vector of fitted values */
    p_obj = 0.0;
    d_obj = 0.0;
    for(i=0;i<n;i++){
      p_obj += r[i]*r[i];
      d_obj += yhat1[i]*yhat1[i];
    }
    p_obj /= 2.0;
    d_obj = ytyd2 - d_obj/2.0 - t*max_val ;
    mexPrintf("******************************\n");
    mexPrintf("\nIteration number: %d\n", num_iter);
    mexPrintf("Value of primal object function : %f\n", p_obj);
    mexPrintf("Value of dual object function : %f\n", d_obj);
    mexPrintf("L1 norm of current beta : %f", b_1norm);
    mexPrintf(" <= %f\n", t);
    mexPrintf("Maximal absolute value in t(X)%%*%%r : %e", max_val);
    mexPrintf(" attained %d time(s)\n", max_num);
    mexPrintf("Number of parameters allowed to vary : %d\n", num_nz_x);
    mexEvalString("drawnow;");
    num_iter++;
  }
  /* outer loop: solve on the active set, then try to enlarge it */
  while(1){
    /* inner loop: least-squares step on the active set, shortened when it
       would leave the L1 ball or drive a coefficient through zero */
    do{

      /* qtr = Q^T r for the first num_nz_x columns of Q */
      q_elem = qmat;

      for(j=0;j<num_nz_x;j++){
        tmp = 0.0;
        for(i=0;i<n;i++,q_elem++) {
          tmp += *q_elem*r[i];
        }
        qtr[j] = tmp;
      }

      /* strictly inside the ball the bound is inactive (mu = 0); on the
         boundary mu is chosen so that theta^T step = t - ||beta||_1 */
      if( b_1norm < (1.0-prec)*t )
        mu = 0.0;
      else{

        /* z=R^{-T}\theta can be calculated by solving R^Tz=\theta */
        r_elem = rmat;
        for(j=0;j<num_nz_x;j++,r_elem++){
          tmp = theta[j];
          for(k=0;k<j;k++,r_elem++)
            tmp -= *r_elem * rinvt_theta[k];
          rinvt_theta[j] = tmp / *r_elem;
        }

        wtc = 0.0;
        wtw = 0.0;
        for(j=0;j<num_nz_x;j++){
          wtc += rinvt_theta[j]*qtr[j];
          wtw += rinvt_theta[j]*rinvt_theta[j];
        }
        mu = (wtc-(t-b_1norm))/wtw;
      }

      /* a non-positive mu means the bound does not restrict the step */
      if(mu<=0.0)
        for(j=0;j<num_nz_x;j++)
          step[j] = qtr[j];
      else
        for(j=0;j<num_nz_x;j++)
          step[j] = qtr[j] - mu*rinvt_theta[j];

      /* h=R^{-1}z can be calculated by solving Rh=z */
      for(j=num_nz_x-1;j>=0;j--){
        tmp = step[j];
        for(k=num_nz_x-1;k>j;k--)
          tmp -= RMAT(j,k) * step[k];
        step[j] = tmp / RMAT(j,j);
      }
      /* take the full step; the previous coefficients are kept in btmp */
      for(j=0;j<num_nz_x;j++){
        btmp[j] = beta[nz_x[j]];
        beta[nz_x[j]] += step[j];
      }
      b_1norm_old=b_1norm;

      b_1norm = 0.0;
      for(j=0;j<num_nz_x;j++) {
        b_1norm += fabs(beta[nz_x[j]]);
      }

      not_solved=FALSE;
      if( b_1norm > (1+prec)*t){
        /* the full step leaves the L1 ball */
        not_solved=TRUE;
        if(b_1norm_old < (1.0-prec)*t){
          /* we started in the interior: bisect on the step length rho
             until beta lies on the boundary, then refresh the signs */

          if(verb) {
            mexPrintf(" -->\tStepping onto the border of the L1 ball.\n");
            mexEvalString("drawnow;");
          }
          rho_up = rho = 1.0;
          rho_low = 0.0;
          while( fabs(t-b_1norm) > prec*t ){
            if( b_1norm > t){
              rho_up = rho;
              rho = (rho+rho_low)/2.0;
            }
            if( b_1norm < t){
              rho_low = rho;
              rho = (rho+rho_up)/2.0;
            }
            if(rho < prec) break;
            for(j=0; j<num_nz_x; j++) {
              beta[nz_x[j]] = btmp[j]+rho*step[j];
            }

            b_1norm = 0.0;
            for(j=0;j<num_nz_x;j++)
              b_1norm += fabs(beta[nz_x[j]]);
          }
          for(j=0;j<num_nz_x;j++){
            /* a coefficient that just reached zero gets the sign
               opposite to its previous value */
            if(fabs(beta[nz_x[j]])<prec)
              theta[j] = btmp[j] > 0 ? -1 : 1;
            else
              theta[j] = beta[nz_x[j]] < 0 ? -1 : 1;
          }
        }else{
          /* we started on the boundary: go along the step only until the
             first coefficient reaches zero (rho in (0,1)) and drop that
             predictor from the active set */


          rho = 1.0;
          to_del = -1;
          for(j=0; j<num_nz_x; j++){
            if(fabs(step[j]) > prec){
              tmp = -btmp[j]/(step[j]);
              if( 0.0 < tmp && tmp < rho ){
                rho = tmp;
                to_del = j;
              }
            }
          }
          /* no coefficient reaches zero within the step: give up */
          if(to_del < 0 ){
            *psuc= -1;
            goto EXIT_HERE;
          }
          for(j=0; j<num_nz_x; j++) {
            beta[nz_x[j]] = btmp[j]+rho*step[j];
          }
          if(verb) {
            mexPrintf(" -->\tRemoving variable: %d",nz_x[to_del]+1);
            mexEvalString("drawnow;");
          }
          beta[nz_x[to_del]] = 0.0;
          qr_del(to_del,TRUE);
          for(j=to_del+1; j< num_nz_x; j++){
            nz_x[j-1] = nz_x[j];
            theta[j-1] = theta[j];
          }
          num_nz_x--;

          b_1norm = 0.0;
          for(j=0;j<num_nz_x;j++)
            b_1norm += fabs(beta[nz_x[j]]);
          /* the else below pairs with the inner if (b_1norm test) */
          if(verb)
            if(b_1norm < (1-prec)*t)
              mexPrintf(", and stepping into the interior of the L1 ball\n");
            else
              mexPrintf("\n");
        }
        /* NOTE(review): executed even when verb is 0 */
        mexEvalString("drawnow;");
      }

      /* refresh fitted values and residuals for the updated beta */
      for(i=0; i< n; i++)
        yhat1[i] = 0.0;
      for(j=0; j < num_nz_x; j++){
        /* x_elem points to the first element in the column of X to which
           the j-th entry in nz_x points */
        x_elem = x + nz_x[j]*n;
        tmp = beta[nz_x[j]];
        for(i=0; i < n; i++){
          yhat1[i] += *x_elem*tmp;
          x_elem++; /* now we point to the next element in X */
        }
      }

      /* calculate the residual vector */
      for(i=0; i< n; i++) {
        r[i] = y[i]-yhat1[i];
      }
    }while(not_solved);


    /* solution on the current active set: recompute fit, residuals and
       X^T r over all m predictors */
    for(i=0; i< n; i++)
      yhat1[i] = 0.0;
    for(j=0; j < num_nz_x; j++){
      /* x_elem points to the first element in the column of X to which
         the j-th entry in nz_x points */
      x_elem = x + nz_x[j]*n;
      tmp = beta[nz_x[j]];
      for(i=0; i < n; i++){
        yhat1[i] += *x_elem*tmp;
        x_elem++; /* now we point to the next element in X */
      }
    }

    /* calculate the residual vector */
    for(i=0; i< n; i++)
      r[i] = y[i]-yhat1[i];


    /* multiply X^T with the residual vector */
    x_elem = x;
    for(j=0; j < m; j++){
      tmp = 0.0;
      for(i=0; i<n; i++){
        tmp += *x_elem*r[i];
        x_elem++;
      }
      xtr[j] = tmp;
    }
    max_val = fabs(xtr[0]);
    max_ind = 0;
    for(j=1; j<m; j++)
      if( fabs(xtr[j]) > max_val ){
        max_val = fabs(xtr[j]);
        max_ind = j;
      }

    if(verb){


      /* Find out how many times [[max_val]] is attained */
      tmp = (1.0-prec)*max_val;
      if( tmp < prec ) tmp = 0.0;
      max_num = 1;
      for(j=0; j<m; j++){
        if( tmp <= fabs(xtr[j]) && j!=max_ind){
          /*we found another element equal to the (current)
            maximal absolute value */
          max_num++;
        }
      }

      /* for the value of the dual objective function we need
         to calculate the L2 norm of the vector of fitted values */
      p_obj = 0.0;
      d_obj = 0.0;
      for(i=0;i<n;i++){
        p_obj += r[i]*r[i];
        d_obj += yhat1[i]*yhat1[i];
      }
      p_obj /= 2.0;
      d_obj = ytyd2 - d_obj/2.0 - t*max_val ;
      mexPrintf("******************************\n");
      mexPrintf("\nIteration number: %d\n", num_iter);
      mexPrintf("Value of primal object function : %f\n", p_obj);
      mexPrintf("Value of dual object function : %f\n", d_obj);
      mexPrintf("L1 norm of current beta : %f", b_1norm);
      mexPrintf(" <= %f\n", t);
      mexPrintf("Maximal absolute value in t(X)%%*%%r : %e", max_val);
      mexPrintf(" attained %d time(s)\n", max_num);
      mexPrintf("Number of parameters allowed to vary : %d\n", num_nz_x);
      mexEvalString("drawnow;");
      num_iter++;
    }

    /* stop when the predictor with the largest |X^T r| is already
       active; otherwise add it to the active set and continue */
    add = TRUE;
    for(i=0; i<num_nz_x; i++)
      if(nz_x[i]==max_ind){
        add = FALSE;
        break;
      }
    if(add){
      qr_add(x+max_ind*n,TRUE);
      nz_x[num_nz_x] = max_ind;
      theta[num_nz_x] = xtr[max_ind]<0 ? -1 : 1;
      num_nz_x++;
    }else{
      *lagrangian = max_val;
      break;
    }

    b_1norm = 0.0;
    for(j=0;j<num_nz_x;j++)
      b_1norm += fabs(beta[nz_x[j]]);
    if(verb) {
      mexPrintf(" -->\tAdding variable: %d\n",max_ind+1);
      mexEvalString("drawnow;");
    }
  }
 EXIT_HERE:
  /* the work space belongs to mult_lasso() when called as a subroutine */
  if( !as_sub)
    lasso_free();
}
---|
471 | void mult_lasso(double *x, long *pn, long *pm, double *pt, long *pl, |
---|
472 | double *beta, double *y, double *yhat1, double *r, |
---|
473 | double *lagrangian, long *psuc, long *pverb) |
---|
474 | { |
---|
475 | |
---|
476 | double prec; |
---|
477 | long n = *pn, m = *pm, l = *pl, verb = *pverb, as_sub = TRUE, i, j; |
---|
478 | lasso_alloc(n,m); |
---|
479 | |
---|
480 | qr_add(y,TRUE); |
---|
481 | ytyd2 = *rmat * *rmat/2.0; |
---|
482 | prec = sqrt(PRECISION); |
---|
483 | num_nz_x = 0; |
---|
484 | for(j=0; j<m; j++){ |
---|
485 | if(fabs(beta[j]) > prec){ |
---|
486 | qr_add(x+j*n, TRUE); |
---|
487 | nz_x[num_nz_x] = j; |
---|
488 | num_nz_x++; |
---|
489 | }else |
---|
490 | beta[j] = 0.0; |
---|
491 | } |
---|
492 | *psuc = 0; |
---|
493 | for(i=0; i<l; i++){ |
---|
494 | |
---|
495 | if(verb){ |
---|
496 | mexPrintf("\n\n++++++++++++++++++++++++++++++\n"); |
---|
497 | mexPrintf("Solving problem number %ld with bound %f\n", i+1, pt[i]); |
---|
498 | mexPrintf("++++++++++++++++++++++++++++++\n"); |
---|
499 | mexEvalString("drawnow;"); |
---|
500 | } |
---|
501 | if(i>0) Memcpy(beta,beta-m,m); |
---|
502 | lasso(x, pn, pm, pt+i, beta, y, yhat1, r, lagrangian, psuc, pverb, &as_sub); |
---|
503 | if( *psuc < 0 ){ |
---|
504 | goto EXIT_HERE; |
---|
505 | } |
---|
506 | |
---|
507 | beta += m; |
---|
508 | yhat1 += n; |
---|
509 | r += n; |
---|
510 | lagrangian++; |
---|
511 | } |
---|
512 | EXIT_HERE: |
---|
513 | lasso_free(); |
---|
514 | } |
---|
515 | static void lasso_alloc(long n, long m){ |
---|
516 | |
---|
517 | #if defined(S_Plus) |
---|
518 | if( nz_x != NULL || theta != NULL || xtr != NULL || btmp != NULL || |
---|
519 | qtr != NULL || rinvt_theta != NULL || step != NULL || |
---|
520 | num_nz_x != 0 || ytyd2 != 0.0){ |
---|
521 | MESSAGE "Possible memory corruption or memory leak.\n We" |
---|
522 | "advise to restart your S+ session" WARNING(NULL_ENTRY); |
---|
523 | lasso_free(); |
---|
524 | } |
---|
525 | #endif |
---|
526 | |
---|
527 | QR_CHUNK = m; // Added: CJV |
---|
528 | |
---|
529 | nz_x = Calloc(m,int); |
---|
530 | if( nz_x == NULL ) |
---|
531 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
532 | theta = Calloc(m,int); |
---|
533 | if( theta==NULL ) |
---|
534 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
535 | xtr = Calloc(m,double); |
---|
536 | if( xtr==NULL ) |
---|
537 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
538 | btmp = Calloc(m,double); |
---|
539 | if( btmp==NULL ) |
---|
540 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
541 | qtr = Calloc(m,double); |
---|
542 | if( qtr==NULL ) |
---|
543 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
544 | rinvt_theta = Calloc(m,double); |
---|
545 | if( rinvt_theta==NULL ) |
---|
546 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
547 | step = Calloc(m,double); |
---|
548 | if( step==NULL ) |
---|
549 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
550 | qr_init(n); |
---|
551 | } |
---|
552 | static void lasso_free(void){ |
---|
553 | num_nz_x=0; |
---|
554 | ytyd2 = 0.0; |
---|
555 | Free(nz_x); |
---|
556 | Free(theta); |
---|
557 | Free(xtr); |
---|
558 | Free(btmp); |
---|
559 | Free(qtr); |
---|
560 | Free(rinvt_theta); |
---|
561 | Free(step); |
---|
562 | qr_free(); |
---|
563 | } |
---|
564 | static void qr_init(int n) { |
---|
565 | #if defined (S_Plus) |
---|
566 | if(qr_max_size!=0 || r_ncol!=0 || q_nrow!=0 || q_use_row!=0 || |
---|
567 | qmat!=NULL || rmat!=NULL){ |
---|
568 | MESSAGE "Possible memory corruption or memory leak.\n We" |
---|
569 | "advise to restart your S+ session" WARNING(NULL_ENTRY); |
---|
570 | qr_free(); |
---|
571 | } |
---|
572 | #endif |
---|
573 | qr_max_size = QR_CHUNK; |
---|
574 | r_ncol = 0; |
---|
575 | q_nrow = n; |
---|
576 | qmat = Calloc(n*qr_max_size,double); |
---|
577 | if(qmat==NULL) |
---|
578 | ERRMSG("qr_init", no_dyn_mem_message); |
---|
579 | rmat = Calloc(qr_max_size*(qr_max_size+1)/2,double); |
---|
580 | if(rmat==NULL) |
---|
581 | ERRMSG("qr_init", no_dyn_mem_message); |
---|
582 | } |
---|
583 | static void qr_incr(void) { |
---|
584 | qr_max_size += QR_CHUNK; |
---|
585 | |
---|
586 | /* reallocate R always */ |
---|
587 | rmat = Realloc(rmat,qr_max_size*(qr_max_size+1)/2,double); |
---|
588 | if(rmat==NULL) |
---|
589 | ERRMSG("qr_incr", no_dyn_mem_message); |
---|
590 | |
---|
591 | /* reallocate Q only if necessary and only to maximal necessary size */ |
---|
592 | if( qr_max_size >= q_nrow){ |
---|
593 | if( qr_max_size-QR_CHUNK < q_nrow){ |
---|
594 | qmat = Realloc(qmat,q_nrow*q_nrow,double); |
---|
595 | if(qmat==NULL) |
---|
596 | ERRMSG("qr_incr", no_dyn_mem_message); |
---|
597 | } |
---|
598 | }else{ |
---|
599 | qmat = Realloc(qmat,q_nrow*qr_max_size,double); |
---|
600 | if(qmat==NULL) |
---|
601 | ERRMSG("qr_incr", no_dyn_mem_message); |
---|
602 | } |
---|
603 | } |
---|
604 | static void qr_free(void) |
---|
605 | { |
---|
606 | qr_max_size = 0; |
---|
607 | r_ncol = 0; |
---|
608 | q_nrow = 0; |
---|
609 | q_use_row = 0; |
---|
610 | Free(qmat); |
---|
611 | Free(rmat); |
---|
612 | } |
---|
/* Delete column l (0-based) from the QR factorisation and restore the upper
 * triangular form of R with Givens rotations, applying the same rotations
 * to the columns of Q.
 *
 * l    column to delete; ERRMSG if it is not in [0, r_ncol)
 * aug  non-zero when the factorised matrix is augmented (its last column is
 *      the response); deleting that last column is then an error
 *
 * Storage: R is packed by columns, column k holding k+1 entries and
 * starting right after column k-1 (RCOL(k)); Q is stored column by column
 * with q_nrow entries each (QCOL(k)).
 */
static void qr_del(int l, int aug) {

  double c, s, tau, nu, *col_k, *col_kp1, *a, *b, tmp;
  int i, j, k, l0;

  if( l<0 || l>=r_ncol )
    ERRMSG("qr_del", "Invalid column number");
  r_ncol--;
  /* removing the last column needs no update; the else pairs with the
     inner if (aug) */
  if( l==r_ncol)
    if( aug )
      ERRMSG("qr_del", "Trying to delete last column of augmented matrix");
    else
      return;

  /* this is TRUE if $m\le n$ ($m-1\le n$ if the matrix is augmented) */
  if (r_ncol < q_nrow ){
    for(k=l; k<r_ncol; k++) /* Update the factorisation and be done */
      {

        /* shift column k+1 of R into column k; its entry below the
           diagonal (*b) is then rotated away against the diagonal (*a) */
        col_k = RCOL(k); /* first element in column $k$ in R */
        Memcpy(col_k,col_k+k+1,k+1);
        a = col_k+k;
        b = a+k+2;
        tau = fabs(*a)+fabs(*b);
        if( tau == 0.0 ) continue; /* both elements are zero
                                      nothing to update */
        /* sqrt(a^2+b^2), scaled by tau to avoid overflow/underflow */
        nu = tau*sqrt((*a/tau)*(*a/tau)+(*b/tau)*(*b/tau));
        c = *a/nu;
        s = *b/nu;
        *a = nu;
        /* apply the rotation to rows k and k+1 of the later columns */
        b += k+2;
        a = b-1;
        for(j=k+2;j<=r_ncol;j++, a+=j, b+=j){
          tmp = c * *a + s * *b;
          *b = c * *b - s * *a;
          *a = tmp;
        }
        /* ... and to columns k and k+1 of Q */
        col_k = QCOL(k);
        col_kp1 = col_k+q_nrow;
        for(j=0;j<q_nrow;j++){
          tmp = c*col_k[j]+s*col_kp1[j];
          col_kp1[j] = c*col_kp1[j]-s*col_k[j];
          col_k[j] = tmp;
        }
      }
  }else
    if( l < q_nrow ){
      /* Update columns upto $m$ and than shift remaining columns */
      for(k=l; k<q_nrow-1; k++){

        /* same rotation step as in the branch above */
        col_k = RCOL(k); /* first element in column $k$ in R */
        Memcpy(col_k,col_k+k+1,k+1);
        a = col_k+k;
        b = a+k+2;
        tau = fabs(*a)+fabs(*b);
        if( tau == 0.0 ) continue; /* both elements are zero
                                      nothing to update */
        nu = tau*sqrt((*a/tau)*(*a/tau)+(*b/tau)*(*b/tau));
        c = *a/nu;
        s = *b/nu;
        *a = nu;
        b += k+2;
        a = b-1;
        for(j=k+2;j<=r_ncol;j++, a+=j, b+=j){
          tmp = c * *a + s * *b;
          *b = c * *b - s * *a;
          *a = tmp;
        }
        col_k = QCOL(k);
        col_kp1 = col_k+q_nrow;
        for(j=0;j<q_nrow;j++){
          tmp = c*col_k[j]+s*col_kp1[j];
          col_kp1[j] = c*col_kp1[j]-s*col_k[j];
          col_k[j] = tmp;
        }
      }
      l0 = q_nrow-1;

      /* from column l0 on only the first q_nrow entries matter: shift
         them one column to the left */
      col_k = RCOL(l0); /* first element in column $l_0$ in R */
      for(i=l0; i<r_ncol; i++, col_k +=i)
        Memcpy(col_k,col_k+i+1,q_nrow);

      /* [[rmat + (q_nrow-1)*q_nrow/2+q_nrow-1]] is last element in column
         [[q_nrow]] in R */
      a = rmat + (q_nrow-1)*(q_nrow+2)/2;
      if( *a < 0 ){
        /* keep R[n,n] non-negative: flip row n of R (columns l0..) and
           the last column of Q */
        /* [[j<r_ncol]] sufficient since we shifted the columns already */
        for(j=l0;j<r_ncol;j++,a+=j)
          *a = - *a;
        col_k = qmat+(q_nrow-1)*q_nrow;
        for(i=0;i<q_nrow;i++,col_k++)
          *col_k = -*col_k;
      }
    }else{ /* just shift last columns to the left */
      l0 = l;

      col_k = RCOL(l0); /* first element in column $l_0$ in R */
      for(i=l0; i<r_ncol; i++, col_k +=i)
        Memcpy(col_k,col_k+i+1,q_nrow);

      /* [[rmat + (q_nrow-1)*q_nrow/2+q_nrow-1]] is last element in column
         [[q_nrow]] in R */
      a = rmat + (q_nrow-1)*(q_nrow+2)/2;
      if( *a < 0 ){
        /* keep R[n,n] non-negative (see above) */
        /* [[j<r_ncol]] sufficient since we shifted the columns already */
        for(j=l0;j<r_ncol;j++,a+=j)
          *a = - *a;
        col_k = qmat+(q_nrow-1)*q_nrow;
        for(i=0;i<q_nrow;i++,col_k++)
          *col_k = -*col_k;
      }
    }
}
---|
/* Append the column x (length q_nrow) to the QR factorisation.
 *
 * While R has fewer columns than Q has rows, a new column of Q is obtained
 * from x/||x|| by repeated Gram-Schmidt passes against the existing
 * columns.  Otherwise Q is already square and only the new column of R,
 * Q^T x, is formed.
 *
 * swap  if non-zero, the new column is placed before the previous last
 *       column and a Givens rotation restores triangularity.  lasso() and
 *       mult_lasso() always pass TRUE, which keeps the first column added
 *       (the response y) as the last one.
 */
static void qr_add(double *x, int swap){

  double tmp, norm_orig, norm_init, norm_last, *q_new_col;
  int i;
  double *col_l, *col_lm1, *q_elem, *r_new_col;
  double norm_new;
  int j;
  double c, s, tau, nu, *a, *b;

  if( r_ncol == qr_max_size )
    qr_incr();

  norm_orig = 0.0;
  tmp = 0.0;
  for(i=0; i<q_nrow; i++) tmp += x[i]*x[i];
  norm_orig = sqrt(tmp);
  if(r_ncol<q_nrow){
    /* the new Q column starts as x/||x|| */
    q_new_col = QCOL(r_ncol); /* points to the first element of new column */
    tmp = 0.0;
    for(i=0; i<q_nrow; i++){
      q_new_col[i] = x[i]/norm_orig;
      tmp += q_new_col[i]*q_new_col[i];
    }
    if(r_ncol==0){
      /* first column: R = ||x||, nothing to orthogonalise against */
      *rmat = norm_orig;
      r_ncol++;
      return;
    }
    norm_init = norm_last = sqrt(tmp);
  }
  r_new_col = RCOL(r_ncol);
  for(i=0; i<=r_ncol; i++) r_new_col[i] = 0.0;
  if( r_ncol >= q_nrow ){
    /* Multiply new column with Q-transpose in [[qr_add]] */
    q_elem = qmat;
    for(i=0;i<q_nrow;i++){
      tmp = 0.0;
      for(j=0;j<q_nrow;j++,q_elem++)
        tmp += *q_elem * x[j];
      r_new_col[i] = tmp;
    }
    if( swap ){ /* swap last two columns */
      col_l = r_new_col;
      col_lm1 = r_new_col-r_ncol;
      for(i=0; i<q_nrow; i++,col_l++,col_lm1++){
        tmp = *col_l;
        *col_l = *col_lm1;
        *col_lm1 = tmp;
      }
      col_lm1--;
      col_l--;
      if(q_nrow==r_ncol && *col_lm1<0){
        /* if new R[$n$,$n$] is negative, than change signs */
        *col_lm1 = - *col_lm1;
        *col_l = - *col_l;
        q_new_col = qmat+(q_nrow-1)*q_nrow;
        for(j=0;j<q_nrow;j++)
          q_new_col[j] = -q_new_col[j];
      }
    }
    r_ncol++;
    return;
  }
  /* orthogonalise the new Q column against the existing ones, repeating
     the pass while it shrinks the column's norm by more than half; the
     projections are accumulated in r_new_col */
  while(1){

    q_elem = qmat;
    for(j=0; j<r_ncol;j++){
      tmp = 0.0;
      for(i=0;i<q_nrow;i++,q_elem++)
        tmp += *q_elem * q_new_col[i];
      r_new_col[j] += tmp;
      q_elem -= q_nrow;
      for(i=0;i<q_nrow;i++,q_elem++)
        q_new_col[i] -= *q_elem*tmp;
    }
    tmp = 0.0;
    for(i=0;i<q_nrow;i++) tmp += q_new_col[i]*q_new_col[i];
    norm_new = sqrt(tmp);

    if( norm_new >= norm_last/2.0) break;
    if( norm_new > 0.1*norm_init*PRECISION )
      norm_last = norm_new;
    else{
      /* x is numerically in the span of Q: continue instead with a tiny
         multiple of the next unused unit vector */
      norm_init = norm_last = 0.1*norm_last*PRECISION;
      for(i=0;i<q_nrow;i++)
        q_new_col[i] = 0.0;
      if(q_use_row == q_nrow)
        ERRMSG("qr_add","Cannot orthogonalise new column\n");

      q_new_col[q_use_row] = norm_new = norm_last;
      q_use_row++;
    }
  }

  /* normalise the Q column and undo the 1/||x|| scaling in R */
  for(i=0;i<q_nrow;i++)
    q_new_col[i] /= norm_new;
  for(i=0;i<r_ncol;i++)
    r_new_col[i] *= norm_orig;
  r_new_col[r_ncol] = norm_new*norm_orig;
  if(swap){

    col_l = r_new_col; /* first element in last column in R */
    col_lm1 = r_new_col-r_ncol; /* first element in column before last in R */
    for(j=0;j<r_ncol;j++,col_l++,col_lm1++){
      tmp = *col_l;
      *col_l = *col_lm1;
      *col_lm1 = tmp;
    }
    /* the swapped-in column now has an entry below its diagonal, still
       stored in the last slot of the last column (*b); rotate it away
       against the diagonal entry *a */
    a = col_lm1-1;
    b = col_l;
    tau = fabs(*a)+fabs(*b);
    if( tau > 0.0 ){
      /*Calculate Givens rotation*/
      nu = tau*sqrt((*a/tau)*(*a/tau)+(*b/tau)*(*b/tau));
      c = *a/nu;
      s = *b/nu;

      *a = nu; /* Fix column before last in R */
      *b = -s* *(b-1); /* Fix last element of last column in R */
      *(b-1) = c * *(b-1); /* Fix second last element of last column in R */
      if( *b < 0){
        /* keep the last diagonal entry positive: flip it and rotate the
           last two Q columns with the sign of the last one reversed */
        *b = -*b;

        col_lm1 = QCOL(r_ncol-1);
        col_l = col_lm1+q_nrow;
        for(j=0;j<q_nrow;j++){
          tmp = c*col_lm1[j]+s*col_l[j];
          col_l[j] = -c*col_l[j]+s*col_lm1[j];
          col_lm1[j] = tmp;
        }

      }else{

        col_lm1 = QCOL(r_ncol-1);
        col_l = col_lm1+q_nrow;
        for(j=0;j<q_nrow;j++){
          tmp = c*col_lm1[j]+s*col_l[j];
          col_l[j] = c*col_l[j]-s*col_lm1[j];
          col_lm1[j] = tmp;
        }
      }
    }

  }
  r_ncol++;

}
---|
/* Report a fatal error from the solver.
 *
 * S_Plus: raise an S error carrying the message.
 * Matlab: raise a MATLAB error "Error in <where>: <string>" via
 *         mexErrMsgTxt (which, per the MEX API, does not return).
 * other:  print the same text to stderr and terminate the process.
 */
#if defined (S_Plus)
static void errmsg(char *string){
  PROBLEM "%s\n", string RECOVER(NULL_ENTRY);
}
#elif defined(Matlab)
static void errmsg(char *where, char *string){
  char str[1024];
  /* snprintf truncates long messages instead of overflowing str */
  snprintf(str, sizeof str, "Error in %s: %s\n", where, string);
  mexErrMsgTxt(str);
}
#else
static void errmsg(char *where, char *string){
  fprintf(stderr, "Error in %s: %s\n", where, string);
  exit(EXIT_FAILURE);
}
#endif
---|
889 | |
---|
890 | int assertFail(const char *ex, const char *file, const char *func, const int line) |
---|
891 | { |
---|
892 | static bool ignore = false; |
---|
893 | |
---|
894 | if (!ignore) { |
---|
895 | char str[1024]; |
---|
896 | sprintf(str, "%s:%u: %s(): failed assertion `%s'\n", |
---|
897 | file, line, func, ex); |
---|
898 | mexErrMsgTxt(str); |
---|
899 | } /* if */ |
---|
900 | return(1); |
---|
901 | } |
---|
902 | |
---|
903 | static void fortifyPrintf(const char *s) |
---|
904 | { |
---|
905 | mexPrintf(s); |
---|
906 | } |
---|
907 | |
---|
908 | void |
---|
909 | mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[]) |
---|
910 | { |
---|
911 | mxArray *array_ptr; |
---|
912 | int i; |
---|
913 | long dataSize; |
---|
914 | long dimensions; |
---|
915 | long succeed; |
---|
916 | long verbose = 0; |
---|
917 | long as_sub = 0; |
---|
918 | double lagrangian; |
---|
919 | double threshold; |
---|
920 | double *mx; |
---|
921 | double *y; |
---|
922 | double *yhat1; |
---|
923 | double *r; |
---|
924 | double *beta; |
---|
925 | int arg; |
---|
926 | |
---|
927 | #ifdef FORTIFY |
---|
928 | Fortify_SetOutputFunc(fortifyPrintf); |
---|
929 | Fortify_EnterScope(); |
---|
930 | #endif |
---|
931 | |
---|
932 | if (nrhs > 3) { |
---|
933 | mexErrMsgTxt("Syntax error: too many arguments\n"); |
---|
934 | return; |
---|
935 | } /* if */ |
---|
936 | |
---|
937 | if (nrhs < 3) { |
---|
938 | mexErrMsgTxt("Syntax error: arguments missing\n"); |
---|
939 | return; |
---|
940 | } /* if */ |
---|
941 | |
---|
942 | if (nlhs > 1) { |
---|
943 | mexErrMsgTxt("Syntax error: too many return arguments\n"); |
---|
944 | return; |
---|
945 | } /* if */ |
---|
946 | |
---|
947 | arg = 2; |
---|
948 | array_ptr = (mxArray *) prhs[arg]; |
---|
949 | if (mxGetClassID(array_ptr) != mxDOUBLE_CLASS) { |
---|
950 | mexErrMsgTxt("Threshold argument is not a number\n"); |
---|
951 | return; |
---|
952 | } /* if */ |
---|
953 | |
---|
954 | threshold = mxGetScalar(array_ptr); |
---|
955 | |
---|
956 | arg = 1; |
---|
957 | array_ptr = (mxArray *) prhs[arg]; |
---|
958 | if (mxGetClassID(array_ptr) != mxDOUBLE_CLASS) { |
---|
959 | mexErrMsgTxt("Label argument is not a number\n"); |
---|
960 | return; |
---|
961 | } /* if */ |
---|
962 | |
---|
963 | y = mxGetPr(array_ptr); |
---|
964 | |
---|
965 | arg = 0; |
---|
966 | array_ptr = (mxArray *) prhs[arg]; |
---|
967 | if (mxGetClassID(array_ptr) != mxDOUBLE_CLASS) { |
---|
968 | mexErrMsgTxt("Data set argument is not a double matrix\n"); |
---|
969 | return; |
---|
970 | } /* if */ |
---|
971 | |
---|
972 | dataSize = mxGetM(array_ptr); |
---|
973 | dimensions = mxGetN(array_ptr); |
---|
974 | |
---|
975 | mx = mxGetPr(array_ptr); |
---|
976 | |
---|
977 | arg = 0; |
---|
978 | plhs[arg] = mxCreateDoubleMatrix(dimensions, 1, mxREAL); |
---|
979 | |
---|
980 | yhat1 = new double[dataSize]; |
---|
981 | r = new double[dataSize]; |
---|
982 | beta = mxGetPr(plhs[arg]); |
---|
983 | |
---|
984 | for (i=0; i<dimensions; i++) { |
---|
985 | beta[i] = 0.0; |
---|
986 | } /* for */ |
---|
987 | |
---|
988 | if (dimensions > 0) { |
---|
989 | lasso(mx, &dataSize, &dimensions, &threshold, beta, y, |
---|
990 | yhat1, r, &lagrangian, &succeed, &verbose, &as_sub); |
---|
991 | } else { |
---|
992 | mexPrintf("Dataset with 0 dimensions\n"); |
---|
993 | } /* if */ |
---|
994 | |
---|
995 | delete [] yhat1; |
---|
996 | delete [] r; |
---|
997 | |
---|
998 | #ifdef FORTIFY |
---|
999 | Fortify_LeaveScope(); |
---|
1000 | #endif |
---|
1001 | |
---|
1002 | } |
---|
1003 | |
---|