source: prextra/mexlasso.cc

Last change on this file was 5, checked in by bduin, 14 years ago
File size: 24.1 KB
Line 
1/* Copyright (C) 1998
2   Berwin A Turlach <bturlach@stats.adelaide.edu.au> */
3
4/* This library is free software; you can redistribute it and/or
5   modify it under the terms of the GNU Library General Public License
6   as published by the Free Software Foundation; either version 2 of
7   the License, or (at your option) any later version. */
8
9/* This library is distributed in the hope that it will be useful, but
10   WITHOUT ANY WARRANTY; without even the implied warranty of
11   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12   Library General Public License for more details. */
13
14/* You should have received a copy of the GNU Library General Public
15   License along with this library; if not, write to the Free Software
16   Foundation, Inc., 59 Temple Place, Suite 330, Boston,
17   MA 02111-1307, USA. */
18
19#include "lasso.h"
20//#include "fortify.h"
21
22#define PRECISION FLT_EPSILON
23
24static void lasso_alloc(long n, long m);
25static void lasso_free(void);
26static void qr_init(int n);
27static void qr_incr(void);
28static void qr_free(void);
29static void qr_del(int l, int aug);
30static void qr_add(double *x, int swap);
31#if defined (S_Plus)
32static void errmsg(char* string);
33#else
34static void errmsg(char* where, char* string);
35#endif
36
37  static int QR_CHUNK = 10;
38
39  static double *xtr=NULL, *btmp=NULL, *qtr=NULL,
40                *rinvt_theta=NULL, *step=NULL, ytyd2=0.0;
41  static int *theta=NULL, *nz_x=NULL, num_nz_x=0;
42  static double *qmat=NULL, *rmat=NULL;
43  static int qr_max_size=0, r_ncol=0, q_nrow=0, q_use_row=0;
44  static char *no_dyn_mem_message="Cannot allocate dynamic memory";
45
46void lasso(double *x,           /* input: data */
47           long *pn,            /* input: rows */
48           long *pm,            /* input: columns */
49           double *pt,          /* input: threshold */
50           double *beta,        /* input/ouput: intial/final beta */
51           double *y,           /* input: labels */
52
53           double *yhat1,       /* output */
54           double *r,           /* output */
55           double *lagrangian,  /* output */
56           long *psuc,          /* output: succeed: 0/-1 */
57           long *pverb,         /* input: verbose: yes/no */
58           long *pas_sub)       /* input: no!? */
59{
60
61  double t = *pt, prec;
62  long n = *pn, m = *pm, verb = *pverb, as_sub = *pas_sub;
63  int not_solved;
64  double *x_elem=NULL, tmp, max_val, b_1norm;
65  int i, j, max_ind;
66  double p_obj, d_obj;
67  int num_iter=0, max_num;
68  double b_1norm_old, mu;
69  double rho_up, rho_low, rho;
70  int to_del;
71  double *q_elem, wtc, wtw;
72  double *r_elem;
73  int k;
74  int add;
75  if( !as_sub)
76    lasso_alloc(n,m);
77
78  prec = sqrt(PRECISION);
79
80  if( as_sub ){
81
82  b_1norm = 0.0;
83  for(j=0;j<num_nz_x;j++)
84    b_1norm += fabs(beta[nz_x[j]]);
85  }else{
86
87  b_1norm = 0.0;
88  num_nz_x = 0;
89  for(j=0; j<m; j++){
90    if(fabs(beta[j]) > prec){
91      b_1norm += fabs(beta[j]);
92      nz_x[num_nz_x] = j;
93      num_nz_x++;
94    }else
95      beta[j] = 0.0;
96  }
97  }
98
99  if( b_1norm > t){
100    if(verb){
101      mexPrintf("******************************\n");
102      mexPrintf("Rescaling beta from L1-norm %f to %f\n",b_1norm,t);
103      mexEvalString("drawnow;");
104
105    }
106    for(j=0; j<num_nz_x; j++)
107      beta[nz_x[j]] = beta[nz_x[j]] * t/b_1norm;
108    b_1norm = t;
109  }
110
111  for(i=0; i< n; i++)
112    yhat1[i] = 0.0;
113  for(j=0; j < num_nz_x; j++){
114    /* x_elem points to the first element in the column of X to which
115       the j-th entry in nz_x points  */
116    x_elem = x + nz_x[j]*n;
117    tmp = beta[nz_x[j]];
118    for(i=0; i < n; i++){
119      yhat1[i] += *x_elem*tmp;
120      x_elem++;           /* now we point to the next element in X */
121    }
122  }
123
124  /* calculate the residual vector */
125  for(i=0; i< n; i++)
126    r[i] = y[i]-yhat1[i];
127
128  /* multiply X^T with the residual vector */
129  x_elem = x;
130  for(j=0; j < m; j++){
131    tmp = 0.0;
132    for(i=0; i<n; i++){
133      tmp += *x_elem*r[i];
134      x_elem++;
135    }
136    xtr[j] = tmp;
137  }
138
139  max_val = fabs(xtr[0]);
140  max_ind = 0;
141  for(j=1; j<m; j++)
142    if( fabs(xtr[j]) > max_val ){
143      max_val = fabs(xtr[j]);
144      max_ind = j;
145    }
146
147  if( !as_sub ){
148
149  qr_add(y,TRUE);
150  ytyd2 = *rmat * *rmat/2.0;
151  for(j=0;j<num_nz_x;j++){
152    qr_add(x+nz_x[j]*n, TRUE);
153    if(fabs(beta[nz_x[j]])<prec)
154      theta[j] = xtr[nz_x[j]] < 0 ? -1 : 1;
155    else
156      theta[j] = beta[nz_x[j]] < 0 ? -1 : 1;
157
158  }
159  }
160  if( num_nz_x==0 ){
161    nz_x[0] = max_ind;
162    num_nz_x = 1;
163    if(verb){
164      mexPrintf("******************************\n");
165      mexPrintf("  -->\tAdding variable: %d\n",max_ind+1);
166      mexEvalString("drawnow;");
167    }
168    qr_add(x+max_ind*n, TRUE);
169    theta[0] = xtr[max_ind] < 0 ? -1 : 1;
170  }
171  *psuc=0;
172  if(verb){
173
174
175  /* Find out how many times [[max_val]] is attained */
176  tmp = (1.0-prec)*max_val;
177  if( tmp < prec ) tmp = 0.0;
178  max_num = 1;
179  for(j=0; j<m; j++){
180    if( tmp <= fabs(xtr[j]) && j!=max_ind){
181      /*we found another element equal to the (current)
182        maximal absolute value */
183      max_num++;
184    }
185  }
186
187  /* for the value of the dual objective function we need
188     to calculate the L2 norm of the vector of fitted values */
189  p_obj = 0.0;
190  d_obj = 0.0;
191  for(i=0;i<n;i++){
192    p_obj += r[i]*r[i];
193    d_obj += yhat1[i]*yhat1[i];
194  }
195  p_obj /= 2.0;
196  d_obj = ytyd2 - d_obj/2.0 - t*max_val ;
197  mexPrintf("******************************\n");
198  mexPrintf("\nIteration number: %d\n", num_iter);
199  mexPrintf("Value of primal object function      : %f\n", p_obj);
200  mexPrintf("Value of dual   object function      : %f\n", d_obj);
201  mexPrintf("L1 norm of current beta              : %f", b_1norm);
202  mexPrintf(" <= %f\n", t);
203  mexPrintf("Maximal absolute value in t(X)%%*%%r   : %e", max_val);
204  mexPrintf(" attained %d time(s)\n", max_num);
205  mexPrintf("Number of parameters allowed to vary : %d\n", num_nz_x);
206  mexEvalString("drawnow;");
207  num_iter++;
208  }
209  while(1){
210    do{
211
212  q_elem = qmat;
213
214  for(j=0;j<num_nz_x;j++){
215    tmp = 0.0;
216    for(i=0;i<n;i++,q_elem++) {
217      tmp += *q_elem*r[i];
218    }
219    qtr[j] = tmp;
220  }
221
222  if( b_1norm < (1.0-prec)*t )
223    mu = 0.0;
224  else{
225
226  /* z=R^{-T}\theta can be calculated by solving R^Tz=\theta */
227  r_elem = rmat;
228  for(j=0;j<num_nz_x;j++,r_elem++){
229    tmp = theta[j];
230    for(k=0;k<j;k++,r_elem++)
231      tmp -= *r_elem * rinvt_theta[k];
232    rinvt_theta[j] = tmp /  *r_elem;
233  }
234
235  wtc = 0.0;
236  wtw = 0.0;
237  for(j=0;j<num_nz_x;j++){
238    wtc += rinvt_theta[j]*qtr[j];
239    wtw += rinvt_theta[j]*rinvt_theta[j];
240  }
241  mu = (wtc-(t-b_1norm))/wtw;
242  }
243
244  if(mu<=0.0)
245    for(j=0;j<num_nz_x;j++)
246      step[j] = qtr[j];
247  else
248    for(j=0;j<num_nz_x;j++)
249      step[j] = qtr[j] - mu*rinvt_theta[j];
250
251  /* h=R^{-1}z can be calculated by solving Rh=z */
252  for(j=num_nz_x-1;j>=0;j--){
253    tmp = step[j];
254    for(k=num_nz_x-1;k>j;k--)
255      tmp -= RMAT(j,k) * step[k];
256    step[j] = tmp /  RMAT(j,j);
257  }
258  for(j=0;j<num_nz_x;j++){
259    btmp[j] = beta[nz_x[j]];
260    beta[nz_x[j]] += step[j];
261  }
262  b_1norm_old=b_1norm;
263
264  b_1norm = 0.0;
265  for(j=0;j<num_nz_x;j++) {
266    b_1norm += fabs(beta[nz_x[j]]);
267  }
268
269  not_solved=FALSE;
270  if( b_1norm > (1+prec)*t){
271    not_solved=TRUE;
272    if(b_1norm_old < (1.0-prec)*t){
273
274    if(verb) {
275      mexPrintf("  -->\tStepping onto the border of the L1 ball.\n");
276      mexEvalString("drawnow;");
277    }
278    rho_up  = rho = 1.0;
279    rho_low = 0.0;
280    while( fabs(t-b_1norm) > prec*t ){
281      if( b_1norm > t){
282        rho_up = rho;
283        rho = (rho+rho_low)/2.0;
284      }
285      if( b_1norm < t){
286        rho_low = rho;
287        rho = (rho+rho_up)/2.0;
288      }
289      if(rho < prec) break;
290      for(j=0; j<num_nz_x; j++) {
291        beta[nz_x[j]] = btmp[j]+rho*step[j];
292      }
293
294  b_1norm = 0.0;
295  for(j=0;j<num_nz_x;j++)
296    b_1norm += fabs(beta[nz_x[j]]);
297    }
298    for(j=0;j<num_nz_x;j++){
299      if(fabs(beta[nz_x[j]])<prec)
300        theta[j] = btmp[j] > 0 ? -1 : 1;
301      else
302        theta[j] = beta[nz_x[j]] < 0 ? -1 : 1;
303    }
304    }else{
305
306
307  rho = 1.0;
308  to_del = -1;
309  for(j=0; j<num_nz_x; j++){
310    if(fabs(step[j]) > prec){
311      tmp = -btmp[j]/(step[j]);
312      if( 0.0 < tmp && tmp < rho ){
313        rho = tmp;
314        to_del = j;
315      }
316    }
317  }
318  if(to_del < 0 ){
319    *psuc= -1;
320    goto EXIT_HERE;
321  }
322  for(j=0; j<num_nz_x; j++) {
323    beta[nz_x[j]] = btmp[j]+rho*step[j];
324  }
325  if(verb) {
326    mexPrintf("  -->\tRemoving variable: %d",nz_x[to_del]+1);
327    mexEvalString("drawnow;");
328  }
329  beta[nz_x[to_del]] = 0.0;
330  qr_del(to_del,TRUE);
331  for(j=to_del+1; j< num_nz_x; j++){
332    nz_x[j-1] = nz_x[j];
333    theta[j-1] = theta[j];
334  }
335  num_nz_x--;
336 
337  b_1norm = 0.0;
338  for(j=0;j<num_nz_x;j++)
339    b_1norm += fabs(beta[nz_x[j]]);
340  if(verb)
341    if(b_1norm < (1-prec)*t)
342      mexPrintf(", and stepping into the interior of the L1 ball\n");
343    else
344      mexPrintf("\n");
345    }
346    mexEvalString("drawnow;");
347  }
348
349  for(i=0; i< n; i++)
350    yhat1[i] = 0.0;
351  for(j=0; j < num_nz_x; j++){
352    /* x_elem points to the first element in the column of X to which
353       the j-th entry in nz_x points  */
354    x_elem = x + nz_x[j]*n;
355    tmp = beta[nz_x[j]];
356    for(i=0; i < n; i++){
357      yhat1[i] += *x_elem*tmp;
358      x_elem++;           /* now we point to the next element in X */
359    }
360  }
361
362  /* calculate the residual vector */
363  for(i=0; i< n; i++) {
364    r[i] = y[i]-yhat1[i];
365  }
366    }while(not_solved);
367   
368 
369  for(i=0; i< n; i++)
370    yhat1[i] = 0.0;
371  for(j=0; j < num_nz_x; j++){
372    /* x_elem points to the first element in the column of X to which
373       the j-th entry in nz_x points  */
374    x_elem = x + nz_x[j]*n;
375    tmp = beta[nz_x[j]];
376    for(i=0; i < n; i++){
377      yhat1[i] += *x_elem*tmp;
378      x_elem++;           /* now we point to the next element in X */
379    }
380  }
381
382  /* calculate the residual vector */
383  for(i=0; i< n; i++)
384    r[i] = y[i]-yhat1[i];
385
386
387  /* multiply X^T with the residual vector */
388  x_elem = x;
389  for(j=0; j < m; j++){
390    tmp = 0.0;
391    for(i=0; i<n; i++){
392      tmp += *x_elem*r[i];
393      x_elem++;
394    }
395    xtr[j] = tmp;
396  }
397  max_val = fabs(xtr[0]);
398  max_ind = 0;
399  for(j=1; j<m; j++)
400    if( fabs(xtr[j]) > max_val ){
401      max_val = fabs(xtr[j]);
402      max_ind = j;
403    }
404
405    if(verb){
406
407
408  /* Find out how many times [[max_val]] is attained */
409  tmp = (1.0-prec)*max_val;
410  if( tmp < prec ) tmp = 0.0;
411  max_num = 1;
412  for(j=0; j<m; j++){
413    if( tmp <= fabs(xtr[j]) && j!=max_ind){
414      /*we found another element equal to the (current)
415        maximal absolute value */
416      max_num++;
417    }
418  }
419
420  /* for the value of the dual objective function we need
421     to calculate the L2 norm of the vector of fitted values */
422  p_obj = 0.0;
423  d_obj = 0.0;
424  for(i=0;i<n;i++){
425    p_obj += r[i]*r[i];
426    d_obj += yhat1[i]*yhat1[i];
427  }
428  p_obj /= 2.0;
429  d_obj = ytyd2 - d_obj/2.0 - t*max_val ;
430  mexPrintf("******************************\n");
431  mexPrintf("\nIteration number: %d\n", num_iter);
432  mexPrintf("Value of primal object function      : %f\n", p_obj);
433  mexPrintf("Value of dual   object function      : %f\n", d_obj);
434  mexPrintf("L1 norm of current beta              : %f", b_1norm);
435  mexPrintf(" <= %f\n", t);
436  mexPrintf("Maximal absolute value in t(X)%%*%%r   : %e", max_val);
437  mexPrintf(" attained %d time(s)\n", max_num);
438  mexPrintf("Number of parameters allowed to vary : %d\n", num_nz_x);
439  mexEvalString("drawnow;");
440  num_iter++;
441    }
442   
443  add = TRUE;
444  for(i=0; i<num_nz_x; i++)
445    if(nz_x[i]==max_ind){
446      add = FALSE;
447      break;
448    }
449  if(add){
450    qr_add(x+max_ind*n,TRUE);
451    nz_x[num_nz_x] = max_ind;
452    theta[num_nz_x] = xtr[max_ind]<0 ? -1 : 1;
453    num_nz_x++;
454  }else{
455    *lagrangian = max_val;
456    break;
457  }
458
459  b_1norm = 0.0;
460  for(j=0;j<num_nz_x;j++)
461    b_1norm += fabs(beta[nz_x[j]]);
462  if(verb) {
463    mexPrintf("  -->\tAdding variable: %d\n",max_ind+1);
464    mexEvalString("drawnow;");
465  }
466  }
467  EXIT_HERE:
468  if( !as_sub)
469    lasso_free();
470}
471void mult_lasso(double *x, long *pn, long *pm, double *pt, long *pl,
472                double *beta, double *y, double *yhat1, double *r,
473                double *lagrangian, long *psuc, long *pverb)
474{
475
476  double prec;
477  long n = *pn, m = *pm, l = *pl, verb = *pverb, as_sub = TRUE, i, j;
478  lasso_alloc(n,m);
479
480qr_add(y,TRUE);
481ytyd2 = *rmat * *rmat/2.0;
482prec = sqrt(PRECISION);
483num_nz_x = 0;
484for(j=0; j<m; j++){
485  if(fabs(beta[j]) > prec){
486    qr_add(x+j*n, TRUE);
487    nz_x[num_nz_x] = j;
488    num_nz_x++;
489  }else
490    beta[j] = 0.0;
491}
492*psuc = 0;
493  for(i=0; i<l; i++){
494
495if(verb){
496  mexPrintf("\n\n++++++++++++++++++++++++++++++\n");
497  mexPrintf("Solving problem number %ld with bound %f\n", i+1, pt[i]);
498  mexPrintf("++++++++++++++++++++++++++++++\n");
499  mexEvalString("drawnow;");
500}
501if(i>0) Memcpy(beta,beta-m,m);
502lasso(x, pn, pm, pt+i, beta, y, yhat1, r, lagrangian, psuc, pverb, &as_sub);
503if( *psuc < 0 ){
504  goto EXIT_HERE;
505}
506
507beta += m;
508yhat1 += n;
509r += n;
510lagrangian++;
511  }
512  EXIT_HERE:
513  lasso_free();
514}
515static void lasso_alloc(long n, long m){
516
517#if defined(S_Plus)
518  if( nz_x != NULL || theta != NULL || xtr != NULL || btmp != NULL ||
519      qtr != NULL || rinvt_theta != NULL || step != NULL ||
520      num_nz_x != 0 || ytyd2 != 0.0){
521    MESSAGE "Possible memory corruption or memory leak.\n  We"
522      "advise to restart your S+ session"  WARNING(NULL_ENTRY);
523     lasso_free();
524  }
525#endif
526
527  QR_CHUNK = m;  // Added: CJV
528
529  nz_x = Calloc(m,int);
530  if( nz_x == NULL )
531    ERRMSG("lasso_alloc", no_dyn_mem_message);
532  theta = Calloc(m,int);
533  if( theta==NULL )
534    ERRMSG("lasso_alloc", no_dyn_mem_message);
535  xtr = Calloc(m,double);
536  if( xtr==NULL )
537    ERRMSG("lasso_alloc", no_dyn_mem_message);
538  btmp = Calloc(m,double);
539  if( btmp==NULL )
540    ERRMSG("lasso_alloc", no_dyn_mem_message);
541  qtr = Calloc(m,double);
542  if( qtr==NULL )
543    ERRMSG("lasso_alloc", no_dyn_mem_message);
544  rinvt_theta = Calloc(m,double);
545  if( rinvt_theta==NULL )
546    ERRMSG("lasso_alloc", no_dyn_mem_message);
547  step = Calloc(m,double);
548  if( step==NULL )
549    ERRMSG("lasso_alloc", no_dyn_mem_message);
550  qr_init(n);
551}
552static void lasso_free(void){
553  num_nz_x=0;
554  ytyd2 = 0.0;
555  Free(nz_x);
556  Free(theta);
557  Free(xtr);
558  Free(btmp);
559  Free(qtr);
560  Free(rinvt_theta);
561  Free(step);
562  qr_free();
563}
564static void qr_init(int n) {
565#if defined (S_Plus)
566  if(qr_max_size!=0 || r_ncol!=0 || q_nrow!=0 || q_use_row!=0 ||
567     qmat!=NULL || rmat!=NULL){
568    MESSAGE "Possible memory corruption or memory leak.\n  We"
569      "advise to restart your S+ session"  WARNING(NULL_ENTRY);
570    qr_free();
571  }
572#endif
573  qr_max_size = QR_CHUNK;
574  r_ncol = 0;
575  q_nrow = n;
576  qmat = Calloc(n*qr_max_size,double);
577  if(qmat==NULL)
578    ERRMSG("qr_init", no_dyn_mem_message);
579  rmat = Calloc(qr_max_size*(qr_max_size+1)/2,double);
580  if(rmat==NULL)
581    ERRMSG("qr_init", no_dyn_mem_message);
582}
583static void qr_incr(void) {
584  qr_max_size += QR_CHUNK;
585
586  /* reallocate R always */
587  rmat = Realloc(rmat,qr_max_size*(qr_max_size+1)/2,double);
588  if(rmat==NULL)
589    ERRMSG("qr_incr", no_dyn_mem_message);
590
591  /* reallocate Q only if necessary and only to maximal necessary size */
592  if( qr_max_size >= q_nrow){
593    if( qr_max_size-QR_CHUNK < q_nrow){
594      qmat = Realloc(qmat,q_nrow*q_nrow,double);
595      if(qmat==NULL)
596        ERRMSG("qr_incr", no_dyn_mem_message);
597    }
598  }else{
599    qmat = Realloc(qmat,q_nrow*qr_max_size,double);
600    if(qmat==NULL)
601      ERRMSG("qr_incr", no_dyn_mem_message);
602  }
603}
604static void qr_free(void)
605{
606  qr_max_size = 0;
607  r_ncol = 0;
608  q_nrow = 0;
609  q_use_row = 0;
610  Free(qmat);
611  Free(rmat);
612}
613static void qr_del(int l, int aug) {
614
615  double c, s, tau, nu, *col_k, *col_kp1, *a, *b, tmp;
616  int i, j, k, l0;
617
618  if( l<0 || l>=r_ncol )
619    ERRMSG("qr_del", "Invalid column number");
620  r_ncol--;
621  if( l==r_ncol)
622    if( aug )
623      ERRMSG("qr_del", "Trying to delete last column of augmented matrix");
624    else
625      return;
626
627  /* this is TRUE if $m\le n$ ($m-1\le n$ if the matrix is augmented) */
628  if (r_ncol < q_nrow ){
629    for(k=l; k<r_ncol; k++) /* Update the factorisation and be done */
630    {
631
632col_k = RCOL(k);  /* first element in column $k$ in R */
633Memcpy(col_k,col_k+k+1,k+1);
634a = col_k+k;
635b = a+k+2;
636tau = fabs(*a)+fabs(*b);
637if( tau == 0.0 ) continue; /* both elements are zero
638                              nothing to update */
639nu = tau*sqrt((*a/tau)*(*a/tau)+(*b/tau)*(*b/tau));
640c = *a/nu;
641s = *b/nu;
642*a = nu;
643  b += k+2;
644  a = b-1;
645  for(j=k+2;j<=r_ncol;j++, a+=j, b+=j){
646    tmp = c * *a + s * *b;
647    *b  = c * *b - s * *a;
648    *a  = tmp;
649  }
650  col_k = QCOL(k);
651  col_kp1 = col_k+q_nrow;
652  for(j=0;j<q_nrow;j++){
653    tmp = c*col_k[j]+s*col_kp1[j];
654    col_kp1[j] = c*col_kp1[j]-s*col_k[j];
655    col_k[j] = tmp;
656  }
657    }
658  }else
659    if( l < q_nrow ){
660    /* Update columns upto $m$ and than shift remaining columns */
661      for(k=l; k<q_nrow-1; k++){
662
663col_k = RCOL(k);  /* first element in column $k$ in R */
664Memcpy(col_k,col_k+k+1,k+1);
665a = col_k+k;
666b = a+k+2;
667tau = fabs(*a)+fabs(*b);
668if( tau == 0.0 ) continue; /* both elements are zero
669                              nothing to update */
670nu = tau*sqrt((*a/tau)*(*a/tau)+(*b/tau)*(*b/tau));
671c = *a/nu;
672s = *b/nu;
673*a = nu;
674  b += k+2;
675  a = b-1;
676  for(j=k+2;j<=r_ncol;j++, a+=j, b+=j){
677    tmp = c * *a + s * *b;
678    *b  = c * *b - s * *a;
679    *a  = tmp;
680  }
681  col_k = QCOL(k);
682  col_kp1 = col_k+q_nrow;
683  for(j=0;j<q_nrow;j++){
684    tmp = c*col_k[j]+s*col_kp1[j];
685    col_kp1[j] = c*col_kp1[j]-s*col_k[j];
686    col_k[j] = tmp;
687  }
688      }
689      l0 = q_nrow-1;
690     
691  col_k   = RCOL(l0); /* first element in column $l_0$ in R */
692  for(i=l0; i<r_ncol; i++, col_k +=i)
693     Memcpy(col_k,col_k+i+1,q_nrow);
694
695  /* [[rmat + (q_nrow-1)*q_nrow/2+q_nrow-1]] is last element in column
696     [[q_nrow]] in R */
697  a = rmat + (q_nrow-1)*(q_nrow+2)/2;
698  if( *a < 0 ){
699    /* [[j<r_ncol]] sufficient since we shifted the columns already */
700    for(j=l0;j<r_ncol;j++,a+=j)
701      *a = - *a;
702    col_k = qmat+(q_nrow-1)*q_nrow;
703    for(i=0;i<q_nrow;i++,col_k++)
704      *col_k = -*col_k;
705  }
706    }else{ /* just shift last columns to the left */
707      l0 = l;
708     
709  col_k   = RCOL(l0); /* first element in column $l_0$ in R */
710  for(i=l0; i<r_ncol; i++, col_k +=i)
711     Memcpy(col_k,col_k+i+1,q_nrow);
712 
713  /* [[rmat + (q_nrow-1)*q_nrow/2+q_nrow-1]] is last element in column
714     [[q_nrow]] in R */
715  a = rmat + (q_nrow-1)*(q_nrow+2)/2;
716  if( *a < 0 ){
717    /* [[j<r_ncol]] sufficient since we shifted the columns already */
718    for(j=l0;j<r_ncol;j++,a+=j)
719      *a = - *a;
720    col_k = qmat+(q_nrow-1)*q_nrow;
721    for(i=0;i<q_nrow;i++,col_k++)
722      *col_k = -*col_k;
723  }
724    }
725}
726static void qr_add(double *x, int swap){
727
728  double tmp, norm_orig, norm_init, norm_last, *q_new_col;
729  int i;
730  double *col_l, *col_lm1, *q_elem, *r_new_col;
731  double norm_new;
732  int j;
733  double c, s, tau, nu, *a, *b;
734
735  if( r_ncol == qr_max_size )
736    qr_incr();
737
738norm_orig = 0.0;
739tmp = 0.0;
740for(i=0; i<q_nrow; i++) tmp += x[i]*x[i];
741norm_orig = sqrt(tmp);
742if(r_ncol<q_nrow){
743  q_new_col = QCOL(r_ncol); /* points to the first element of new column */
744  tmp = 0.0;
745  for(i=0; i<q_nrow; i++){
746    q_new_col[i] = x[i]/norm_orig;
747    tmp += q_new_col[i]*q_new_col[i];
748  }
749  if(r_ncol==0){
750    *rmat = norm_orig;
751    r_ncol++;
752    return;
753  }
754  norm_init = norm_last = sqrt(tmp);
755}
756r_new_col = RCOL(r_ncol);
757for(i=0; i<=r_ncol; i++) r_new_col[i] = 0.0;
758if( r_ncol >= q_nrow ){
759  /* Multiply new column with Q-transpose in [[qr_add]] */
760  q_elem = qmat;
761  for(i=0;i<q_nrow;i++){
762    tmp = 0.0;
763    for(j=0;j<q_nrow;j++,q_elem++)
764      tmp += *q_elem * x[j];
765    r_new_col[i] = tmp;
766  }
767  if( swap ){ /* swap last two columns */
768    col_l   = r_new_col;
769    col_lm1 = r_new_col-r_ncol;
770    for(i=0; i<q_nrow; i++,col_l++,col_lm1++){
771      tmp = *col_l;
772      *col_l = *col_lm1;
773      *col_lm1 = tmp;
774    }
775    col_lm1--;
776    col_l--;
777    if(q_nrow==r_ncol && *col_lm1<0){
778      /* if new R[$n$,$n$] is negative, than change signs */
779      *col_lm1 = - *col_lm1;
780      *col_l   = - *col_l;
781      q_new_col = qmat+(q_nrow-1)*q_nrow;
782      for(j=0;j<q_nrow;j++)
783        q_new_col[j] = -q_new_col[j];
784    }
785  }
786  r_ncol++;
787  return;
788}
789  while(1){
790
791q_elem = qmat;
792for(j=0; j<r_ncol;j++){
793  tmp = 0.0;
794  for(i=0;i<q_nrow;i++,q_elem++)
795    tmp += *q_elem * q_new_col[i];
796  r_new_col[j] += tmp;
797  q_elem -= q_nrow;
798  for(i=0;i<q_nrow;i++,q_elem++)
799    q_new_col[i] -= *q_elem*tmp;
800}
801tmp = 0.0;
802for(i=0;i<q_nrow;i++) tmp += q_new_col[i]*q_new_col[i];
803norm_new = sqrt(tmp);
804
805  if( norm_new >= norm_last/2.0) break;
806  if( norm_new > 0.1*norm_init*PRECISION )
807    norm_last = norm_new;
808  else{
809    norm_init = norm_last = 0.1*norm_last*PRECISION;
810    for(i=0;i<q_nrow;i++)
811      q_new_col[i] = 0.0;
812    if(q_use_row == q_nrow)
813      ERRMSG("qr_add","Cannot orthogonalise new column\n");
814
815    q_new_col[q_use_row] = norm_new = norm_last;
816    q_use_row++;
817   }
818  }
819 
820  for(i=0;i<q_nrow;i++)
821    q_new_col[i] /= norm_new;
822  for(i=0;i<r_ncol;i++)
823    r_new_col[i] *= norm_orig;
824  r_new_col[r_ncol] = norm_new*norm_orig;
825  if(swap){
826
827col_l = r_new_col; /* first element in last column in R */
828col_lm1 = r_new_col-r_ncol; /* first element in column before last in R */
829for(j=0;j<r_ncol;j++,col_l++,col_lm1++){
830  tmp = *col_l;
831  *col_l = *col_lm1;
832  *col_lm1 = tmp;
833}
834a = col_lm1-1;
835b = col_l;
836tau = fabs(*a)+fabs(*b);
837if( tau > 0.0 ){
838  /*Calculate Givens rotation*/
839  nu = tau*sqrt((*a/tau)*(*a/tau)+(*b/tau)*(*b/tau));
840  c = *a/nu;
841  s = *b/nu;
842
843  *a = nu;   /* Fix column before last in R */
844  *b = -s* *(b-1);  /* Fix last element of last column in R */
845  *(b-1) = c * *(b-1); /* Fix second last element of last column in R */
846  if( *b < 0){
847    *b = -*b;
848   
849  col_lm1 = QCOL(r_ncol-1);
850  col_l   = col_lm1+q_nrow;
851  for(j=0;j<q_nrow;j++){
852    tmp = c*col_lm1[j]+s*col_l[j];
853    col_l[j] = -c*col_l[j]+s*col_lm1[j];
854    col_lm1[j] = tmp;
855  }
856 
857  }else{
858   
859  col_lm1 = QCOL(r_ncol-1);
860  col_l   = col_lm1+q_nrow;
861  for(j=0;j<q_nrow;j++){
862    tmp = c*col_lm1[j]+s*col_l[j];
863    col_l[j] = c*col_l[j]-s*col_lm1[j];
864    col_lm1[j] = tmp;
865  }
866  }
867}
868
869  }
870  r_ncol++;
871
872}
873#if defined (S_Plus)
874static void errmsg(char *string){
875  PROBLEM "%s\n", string RECOVER(NULL_ENTRY);
876}
877#elif defined(Matlab)
878static void errmsg(char *where, char *string){
879  char str[1024];
880  sprintf(str, "Error in %s: %s\n", where, string);
881  mexErrMsgTxt(str);
882}
883#else
884static void errmsg(char *where, char *string){
885  fprintf(stderr, "Error in %s: %s\n", where, string);
886  exit(EXIT_FAILURE);
887}
888#endif
889
890int assertFail(const char *ex, const char *file, const char *func, const int line)
891{
892static bool ignore = false;
893
894    if (!ignore) {
895        char str[1024];
896        sprintf(str, "%s:%u: %s(): failed assertion `%s'\n",
897                     file, line, func, ex);
898        mexErrMsgTxt(str);
899    } /* if */
900    return(1);
901}
902
903static void fortifyPrintf(const char *s)
904{
905   mexPrintf(s);
906}
907
908void
909mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[])
910{
911   mxArray *array_ptr;
912   int i;
913   long dataSize;
914   long dimensions;
915   long succeed;
916   long verbose = 0;
917   long as_sub = 0;
918   double lagrangian;
919   double threshold;
920   double *mx;
921   double *y;
922   double *yhat1;
923   double *r;
924   double *beta;
925   int arg;
926
927#ifdef FORTIFY
928   Fortify_SetOutputFunc(fortifyPrintf);
929   Fortify_EnterScope();
930#endif
931
932   if (nrhs > 3) {
933      mexErrMsgTxt("Syntax error: too many arguments\n");
934      return;
935   } /* if */
936
937   if (nrhs < 3) {
938      mexErrMsgTxt("Syntax error: arguments missing\n");
939      return;
940   } /* if */
941
942   if (nlhs > 1) {
943      mexErrMsgTxt("Syntax error: too many return arguments\n");
944      return;
945   } /* if */
946
947   arg = 2;
948   array_ptr = (mxArray *) prhs[arg];
949   if (mxGetClassID(array_ptr) != mxDOUBLE_CLASS) {
950      mexErrMsgTxt("Threshold argument is not a number\n");
951      return;
952   } /* if */
953
954   threshold = mxGetScalar(array_ptr);
955
956   arg = 1;
957   array_ptr = (mxArray *) prhs[arg];
958   if (mxGetClassID(array_ptr) != mxDOUBLE_CLASS) {
959      mexErrMsgTxt("Label argument is not a number\n");
960      return;
961   } /* if */
962
963   y = mxGetPr(array_ptr);
964
965   arg = 0;
966   array_ptr = (mxArray *) prhs[arg];
967   if (mxGetClassID(array_ptr) != mxDOUBLE_CLASS) {
968      mexErrMsgTxt("Data set argument is not a double matrix\n");
969      return;
970   } /* if */
971
972   dataSize = mxGetM(array_ptr);
973   dimensions = mxGetN(array_ptr);
974
975   mx = mxGetPr(array_ptr);
976
977   arg = 0;
978   plhs[arg] = mxCreateDoubleMatrix(dimensions, 1, mxREAL);
979
980   yhat1 = new double[dataSize];
981   r = new double[dataSize];
982   beta = mxGetPr(plhs[arg]);
983
984   for (i=0; i<dimensions; i++) {
985       beta[i] = 0.0;
986   } /* for */
987
988   if (dimensions > 0) {
989       lasso(mx, &dataSize, &dimensions, &threshold, beta, y,
990             yhat1, r, &lagrangian, &succeed, &verbose, &as_sub);
991   } else {
992       mexPrintf("Dataset with 0 dimensions\n");
993   } /* if */
994
995   delete [] yhat1;
996   delete [] r;
997
998#ifdef  FORTIFY
999   Fortify_LeaveScope();
1000#endif
1001
1002}
1003
Note: See TracBrowser for help on using the repository browser.