[5] | 1 | /* Copyright (C) 1998 |
---|
| 2 | Berwin A Turlach <bturlach@stats.adelaide.edu.au> */ |
---|
| 3 | |
---|
| 4 | /* This library is free software; you can redistribute it and/or |
---|
| 5 | modify it under the terms of the GNU Library General Public License |
---|
| 6 | as published by the Free Software Foundation; either version 2 of |
---|
| 7 | the License, or (at your option) any later version. */ |
---|
| 8 | |
---|
| 9 | /* This library is distributed in the hope that it will be useful, but |
---|
| 10 | WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU |
---|
| 12 | Library General Public License for more details. */ |
---|
| 13 | |
---|
| 14 | /* You should have received a copy of the GNU Library General Public |
---|
| 15 | License along with this library; if not, write to the Free Software |
---|
| 16 | Foundation, Inc., 59 Temple Place, Suite 330, Boston, |
---|
| 17 | MA 02111-1307, USA. */ |
---|
| 18 | |
---|
| 19 | #include "lasso.h" |
---|
| 20 | //#include "fortify.h" |
---|
| 21 | |
---|
/* Base tolerance; callers take sqrt(PRECISION) as their working precision. */
#define PRECISION FLT_EPSILON

/* Allocation and release of the solver work arrays. */
static void lasso_alloc(long n, long m);
static void lasso_free(void);

/* Maintenance of the incrementally updated QR factorisation. */
static void qr_init(int n);
static void qr_incr(void);
static void qr_free(void);
static void qr_del(int l, int aug);
static void qr_add(double *x, int swap);

/* Error reporting; the S_Plus build has no "where" argument. */
#if defined (S_Plus)
static void errmsg(char* string);
#else
static void errmsg(char* where, char* string);
#endif

/* Number of columns by which the QR buffers grow (reset in lasso_alloc). */
static int QR_CHUNK = 10;

/* Work arrays of the solver (allocated in lasso_alloc). */
static double *xtr = NULL;          /* X^T r for every column of X */
static double *btmp = NULL;         /* active coefficients saved before a step */
static double *qtr = NULL;          /* Q^T r restricted to the active columns */
static double *rinvt_theta = NULL;  /* solution z of R^T z = theta */
static double *step = NULL;         /* step applied to the active coefficients */
static double ytyd2 = 0.0;          /* squared first entry of R, halved */

static int *theta = NULL;           /* sign (+1/-1) attached to each active column */
static int *nz_x = NULL;            /* column indices of the active set */
static int num_nz_x = 0;            /* number of entries used in nz_x */

/* Storage of the QR factorisation. */
static double *qmat = NULL;         /* Q, stored column after column */
static double *rmat = NULL;         /* R, packed upper triangle */
static int qr_max_size = 0;         /* number of columns currently allocated */
static int r_ncol = 0;              /* number of columns currently held in R */
static int q_nrow = 0;              /* number of rows of Q */
static int q_use_row = 0;           /* next unit row tried when orthogonalisation fails */

static char *no_dyn_mem_message = "Cannot allocate dynamic memory";
---|
| 45 | |
---|
| 46 | void lasso(double *x, /* input: data */ |
---|
| 47 | long *pn, /* input: rows */ |
---|
| 48 | long *pm, /* input: columns */ |
---|
| 49 | double *pt, /* input: threshold */ |
---|
| 50 | double *beta, /* input/ouput: intial/final beta */ |
---|
| 51 | double *y, /* input: labels */ |
---|
| 52 | |
---|
| 53 | double *yhat1, /* output */ |
---|
| 54 | double *r, /* output */ |
---|
| 55 | double *lagrangian, /* output */ |
---|
| 56 | long *psuc, /* output: succeed: 0/-1 */ |
---|
| 57 | long *pverb, /* input: verbose: yes/no */ |
---|
| 58 | long *pas_sub) /* input: no!? */ |
---|
| 59 | { |
---|
| 60 | |
---|
| 61 | double t = *pt, prec; |
---|
| 62 | long n = *pn, m = *pm, verb = *pverb, as_sub = *pas_sub; |
---|
| 63 | int not_solved; |
---|
| 64 | double *x_elem=NULL, tmp, max_val, b_1norm; |
---|
| 65 | int i, j, max_ind; |
---|
| 66 | double p_obj, d_obj; |
---|
| 67 | int num_iter=0, max_num; |
---|
| 68 | double b_1norm_old, mu; |
---|
| 69 | double rho_up, rho_low, rho; |
---|
| 70 | int to_del; |
---|
| 71 | double *q_elem, wtc, wtw; |
---|
| 72 | double *r_elem; |
---|
| 73 | int k; |
---|
| 74 | int add; |
---|
| 75 | if( !as_sub) |
---|
| 76 | lasso_alloc(n,m); |
---|
| 77 | |
---|
| 78 | prec = sqrt(PRECISION); |
---|
| 79 | |
---|
| 80 | if( as_sub ){ |
---|
| 81 | |
---|
| 82 | b_1norm = 0.0; |
---|
| 83 | for(j=0;j<num_nz_x;j++) |
---|
| 84 | b_1norm += fabs(beta[nz_x[j]]); |
---|
| 85 | }else{ |
---|
| 86 | |
---|
| 87 | b_1norm = 0.0; |
---|
| 88 | num_nz_x = 0; |
---|
| 89 | for(j=0; j<m; j++){ |
---|
| 90 | if(fabs(beta[j]) > prec){ |
---|
| 91 | b_1norm += fabs(beta[j]); |
---|
| 92 | nz_x[num_nz_x] = j; |
---|
| 93 | num_nz_x++; |
---|
| 94 | }else |
---|
| 95 | beta[j] = 0.0; |
---|
| 96 | } |
---|
| 97 | } |
---|
| 98 | |
---|
| 99 | if( b_1norm > t){ |
---|
| 100 | if(verb){ |
---|
| 101 | mexPrintf("******************************\n"); |
---|
| 102 | mexPrintf("Rescaling beta from L1-norm %f to %f\n",b_1norm,t); |
---|
| 103 | mexEvalString("drawnow;"); |
---|
| 104 | |
---|
| 105 | } |
---|
| 106 | for(j=0; j<num_nz_x; j++) |
---|
| 107 | beta[nz_x[j]] = beta[nz_x[j]] * t/b_1norm; |
---|
| 108 | b_1norm = t; |
---|
| 109 | } |
---|
| 110 | |
---|
| 111 | for(i=0; i< n; i++) |
---|
| 112 | yhat1[i] = 0.0; |
---|
| 113 | for(j=0; j < num_nz_x; j++){ |
---|
| 114 | /* x_elem points to the first element in the column of X to which |
---|
| 115 | the j-th entry in nz_x points */ |
---|
| 116 | x_elem = x + nz_x[j]*n; |
---|
| 117 | tmp = beta[nz_x[j]]; |
---|
| 118 | for(i=0; i < n; i++){ |
---|
| 119 | yhat1[i] += *x_elem*tmp; |
---|
| 120 | x_elem++; /* now we point to the next element in X */ |
---|
| 121 | } |
---|
| 122 | } |
---|
| 123 | |
---|
| 124 | /* calculate the residual vector */ |
---|
| 125 | for(i=0; i< n; i++) |
---|
| 126 | r[i] = y[i]-yhat1[i]; |
---|
| 127 | |
---|
| 128 | /* multiply X^T with the residual vector */ |
---|
| 129 | x_elem = x; |
---|
| 130 | for(j=0; j < m; j++){ |
---|
| 131 | tmp = 0.0; |
---|
| 132 | for(i=0; i<n; i++){ |
---|
| 133 | tmp += *x_elem*r[i]; |
---|
| 134 | x_elem++; |
---|
| 135 | } |
---|
| 136 | xtr[j] = tmp; |
---|
| 137 | } |
---|
| 138 | |
---|
| 139 | max_val = fabs(xtr[0]); |
---|
| 140 | max_ind = 0; |
---|
| 141 | for(j=1; j<m; j++) |
---|
| 142 | if( fabs(xtr[j]) > max_val ){ |
---|
| 143 | max_val = fabs(xtr[j]); |
---|
| 144 | max_ind = j; |
---|
| 145 | } |
---|
| 146 | |
---|
| 147 | if( !as_sub ){ |
---|
| 148 | |
---|
| 149 | qr_add(y,TRUE); |
---|
| 150 | ytyd2 = *rmat * *rmat/2.0; |
---|
| 151 | for(j=0;j<num_nz_x;j++){ |
---|
| 152 | qr_add(x+nz_x[j]*n, TRUE); |
---|
| 153 | if(fabs(beta[nz_x[j]])<prec) |
---|
| 154 | theta[j] = xtr[nz_x[j]] < 0 ? -1 : 1; |
---|
| 155 | else |
---|
| 156 | theta[j] = beta[nz_x[j]] < 0 ? -1 : 1; |
---|
| 157 | |
---|
| 158 | } |
---|
| 159 | } |
---|
| 160 | if( num_nz_x==0 ){ |
---|
| 161 | nz_x[0] = max_ind; |
---|
| 162 | num_nz_x = 1; |
---|
| 163 | if(verb){ |
---|
| 164 | mexPrintf("******************************\n"); |
---|
| 165 | mexPrintf(" -->\tAdding variable: %d\n",max_ind+1); |
---|
| 166 | mexEvalString("drawnow;"); |
---|
| 167 | } |
---|
| 168 | qr_add(x+max_ind*n, TRUE); |
---|
| 169 | theta[0] = xtr[max_ind] < 0 ? -1 : 1; |
---|
| 170 | } |
---|
| 171 | *psuc=0; |
---|
| 172 | if(verb){ |
---|
| 173 | |
---|
| 174 | |
---|
| 175 | /* Find out how many times [[max_val]] is attained */ |
---|
| 176 | tmp = (1.0-prec)*max_val; |
---|
| 177 | if( tmp < prec ) tmp = 0.0; |
---|
| 178 | max_num = 1; |
---|
| 179 | for(j=0; j<m; j++){ |
---|
| 180 | if( tmp <= fabs(xtr[j]) && j!=max_ind){ |
---|
| 181 | /*we found another element equal to the (current) |
---|
| 182 | maximal absolute value */ |
---|
| 183 | max_num++; |
---|
| 184 | } |
---|
| 185 | } |
---|
| 186 | |
---|
| 187 | /* for the value of the dual objective function we need |
---|
| 188 | to calculate the L2 norm of the vector of fitted values */ |
---|
| 189 | p_obj = 0.0; |
---|
| 190 | d_obj = 0.0; |
---|
| 191 | for(i=0;i<n;i++){ |
---|
| 192 | p_obj += r[i]*r[i]; |
---|
| 193 | d_obj += yhat1[i]*yhat1[i]; |
---|
| 194 | } |
---|
| 195 | p_obj /= 2.0; |
---|
| 196 | d_obj = ytyd2 - d_obj/2.0 - t*max_val ; |
---|
| 197 | mexPrintf("******************************\n"); |
---|
| 198 | mexPrintf("\nIteration number: %d\n", num_iter); |
---|
| 199 | mexPrintf("Value of primal object function : %f\n", p_obj); |
---|
| 200 | mexPrintf("Value of dual object function : %f\n", d_obj); |
---|
| 201 | mexPrintf("L1 norm of current beta : %f", b_1norm); |
---|
| 202 | mexPrintf(" <= %f\n", t); |
---|
| 203 | mexPrintf("Maximal absolute value in t(X)%%*%%r : %e", max_val); |
---|
| 204 | mexPrintf(" attained %d time(s)\n", max_num); |
---|
| 205 | mexPrintf("Number of parameters allowed to vary : %d\n", num_nz_x); |
---|
| 206 | mexEvalString("drawnow;"); |
---|
| 207 | num_iter++; |
---|
| 208 | } |
---|
| 209 | while(1){ |
---|
| 210 | do{ |
---|
| 211 | |
---|
| 212 | q_elem = qmat; |
---|
| 213 | |
---|
| 214 | for(j=0;j<num_nz_x;j++){ |
---|
| 215 | tmp = 0.0; |
---|
| 216 | for(i=0;i<n;i++,q_elem++) { |
---|
| 217 | tmp += *q_elem*r[i]; |
---|
| 218 | } |
---|
| 219 | qtr[j] = tmp; |
---|
| 220 | } |
---|
| 221 | |
---|
| 222 | if( b_1norm < (1.0-prec)*t ) |
---|
| 223 | mu = 0.0; |
---|
| 224 | else{ |
---|
| 225 | |
---|
| 226 | /* z=R^{-T}\theta can be calculated by solving R^Tz=\theta */ |
---|
| 227 | r_elem = rmat; |
---|
| 228 | for(j=0;j<num_nz_x;j++,r_elem++){ |
---|
| 229 | tmp = theta[j]; |
---|
| 230 | for(k=0;k<j;k++,r_elem++) |
---|
| 231 | tmp -= *r_elem * rinvt_theta[k]; |
---|
| 232 | rinvt_theta[j] = tmp / *r_elem; |
---|
| 233 | } |
---|
| 234 | |
---|
| 235 | wtc = 0.0; |
---|
| 236 | wtw = 0.0; |
---|
| 237 | for(j=0;j<num_nz_x;j++){ |
---|
| 238 | wtc += rinvt_theta[j]*qtr[j]; |
---|
| 239 | wtw += rinvt_theta[j]*rinvt_theta[j]; |
---|
| 240 | } |
---|
| 241 | mu = (wtc-(t-b_1norm))/wtw; |
---|
| 242 | } |
---|
| 243 | |
---|
| 244 | if(mu<=0.0) |
---|
| 245 | for(j=0;j<num_nz_x;j++) |
---|
| 246 | step[j] = qtr[j]; |
---|
| 247 | else |
---|
| 248 | for(j=0;j<num_nz_x;j++) |
---|
| 249 | step[j] = qtr[j] - mu*rinvt_theta[j]; |
---|
| 250 | |
---|
| 251 | /* h=R^{-1}z can be calculated by solving Rh=z */ |
---|
| 252 | for(j=num_nz_x-1;j>=0;j--){ |
---|
| 253 | tmp = step[j]; |
---|
| 254 | for(k=num_nz_x-1;k>j;k--) |
---|
| 255 | tmp -= RMAT(j,k) * step[k]; |
---|
| 256 | step[j] = tmp / RMAT(j,j); |
---|
| 257 | } |
---|
| 258 | for(j=0;j<num_nz_x;j++){ |
---|
| 259 | btmp[j] = beta[nz_x[j]]; |
---|
| 260 | beta[nz_x[j]] += step[j]; |
---|
| 261 | } |
---|
| 262 | b_1norm_old=b_1norm; |
---|
| 263 | |
---|
| 264 | b_1norm = 0.0; |
---|
| 265 | for(j=0;j<num_nz_x;j++) { |
---|
| 266 | b_1norm += fabs(beta[nz_x[j]]); |
---|
| 267 | } |
---|
| 268 | |
---|
| 269 | not_solved=FALSE; |
---|
| 270 | if( b_1norm > (1+prec)*t){ |
---|
| 271 | not_solved=TRUE; |
---|
| 272 | if(b_1norm_old < (1.0-prec)*t){ |
---|
| 273 | |
---|
| 274 | if(verb) { |
---|
| 275 | mexPrintf(" -->\tStepping onto the border of the L1 ball.\n"); |
---|
| 276 | mexEvalString("drawnow;"); |
---|
| 277 | } |
---|
| 278 | rho_up = rho = 1.0; |
---|
| 279 | rho_low = 0.0; |
---|
| 280 | while( fabs(t-b_1norm) > prec*t ){ |
---|
| 281 | if( b_1norm > t){ |
---|
| 282 | rho_up = rho; |
---|
| 283 | rho = (rho+rho_low)/2.0; |
---|
| 284 | } |
---|
| 285 | if( b_1norm < t){ |
---|
| 286 | rho_low = rho; |
---|
| 287 | rho = (rho+rho_up)/2.0; |
---|
| 288 | } |
---|
| 289 | if(rho < prec) break; |
---|
| 290 | for(j=0; j<num_nz_x; j++) { |
---|
| 291 | beta[nz_x[j]] = btmp[j]+rho*step[j]; |
---|
| 292 | } |
---|
| 293 | |
---|
| 294 | b_1norm = 0.0; |
---|
| 295 | for(j=0;j<num_nz_x;j++) |
---|
| 296 | b_1norm += fabs(beta[nz_x[j]]); |
---|
| 297 | } |
---|
| 298 | for(j=0;j<num_nz_x;j++){ |
---|
| 299 | if(fabs(beta[nz_x[j]])<prec) |
---|
| 300 | theta[j] = btmp[j] > 0 ? -1 : 1; |
---|
| 301 | else |
---|
| 302 | theta[j] = beta[nz_x[j]] < 0 ? -1 : 1; |
---|
| 303 | } |
---|
| 304 | }else{ |
---|
| 305 | |
---|
| 306 | |
---|
| 307 | rho = 1.0; |
---|
| 308 | to_del = -1; |
---|
| 309 | for(j=0; j<num_nz_x; j++){ |
---|
| 310 | if(fabs(step[j]) > prec){ |
---|
| 311 | tmp = -btmp[j]/(step[j]); |
---|
| 312 | if( 0.0 < tmp && tmp < rho ){ |
---|
| 313 | rho = tmp; |
---|
| 314 | to_del = j; |
---|
| 315 | } |
---|
| 316 | } |
---|
| 317 | } |
---|
| 318 | if(to_del < 0 ){ |
---|
| 319 | *psuc= -1; |
---|
| 320 | goto EXIT_HERE; |
---|
| 321 | } |
---|
| 322 | for(j=0; j<num_nz_x; j++) { |
---|
| 323 | beta[nz_x[j]] = btmp[j]+rho*step[j]; |
---|
| 324 | } |
---|
| 325 | if(verb) { |
---|
| 326 | mexPrintf(" -->\tRemoving variable: %d",nz_x[to_del]+1); |
---|
| 327 | mexEvalString("drawnow;"); |
---|
| 328 | } |
---|
| 329 | beta[nz_x[to_del]] = 0.0; |
---|
| 330 | qr_del(to_del,TRUE); |
---|
| 331 | for(j=to_del+1; j< num_nz_x; j++){ |
---|
| 332 | nz_x[j-1] = nz_x[j]; |
---|
| 333 | theta[j-1] = theta[j]; |
---|
| 334 | } |
---|
| 335 | num_nz_x--; |
---|
| 336 | |
---|
| 337 | b_1norm = 0.0; |
---|
| 338 | for(j=0;j<num_nz_x;j++) |
---|
| 339 | b_1norm += fabs(beta[nz_x[j]]); |
---|
| 340 | if(verb) |
---|
| 341 | if(b_1norm < (1-prec)*t) |
---|
| 342 | mexPrintf(", and stepping into the interior of the L1 ball\n"); |
---|
| 343 | else |
---|
| 344 | mexPrintf("\n"); |
---|
| 345 | } |
---|
| 346 | mexEvalString("drawnow;"); |
---|
| 347 | } |
---|
| 348 | |
---|
| 349 | for(i=0; i< n; i++) |
---|
| 350 | yhat1[i] = 0.0; |
---|
| 351 | for(j=0; j < num_nz_x; j++){ |
---|
| 352 | /* x_elem points to the first element in the column of X to which |
---|
| 353 | the j-th entry in nz_x points */ |
---|
| 354 | x_elem = x + nz_x[j]*n; |
---|
| 355 | tmp = beta[nz_x[j]]; |
---|
| 356 | for(i=0; i < n; i++){ |
---|
| 357 | yhat1[i] += *x_elem*tmp; |
---|
| 358 | x_elem++; /* now we point to the next element in X */ |
---|
| 359 | } |
---|
| 360 | } |
---|
| 361 | |
---|
| 362 | /* calculate the residual vector */ |
---|
| 363 | for(i=0; i< n; i++) { |
---|
| 364 | r[i] = y[i]-yhat1[i]; |
---|
| 365 | } |
---|
| 366 | }while(not_solved); |
---|
| 367 | |
---|
| 368 | |
---|
| 369 | for(i=0; i< n; i++) |
---|
| 370 | yhat1[i] = 0.0; |
---|
| 371 | for(j=0; j < num_nz_x; j++){ |
---|
| 372 | /* x_elem points to the first element in the column of X to which |
---|
| 373 | the j-th entry in nz_x points */ |
---|
| 374 | x_elem = x + nz_x[j]*n; |
---|
| 375 | tmp = beta[nz_x[j]]; |
---|
| 376 | for(i=0; i < n; i++){ |
---|
| 377 | yhat1[i] += *x_elem*tmp; |
---|
| 378 | x_elem++; /* now we point to the next element in X */ |
---|
| 379 | } |
---|
| 380 | } |
---|
| 381 | |
---|
| 382 | /* calculate the residual vector */ |
---|
| 383 | for(i=0; i< n; i++) |
---|
| 384 | r[i] = y[i]-yhat1[i]; |
---|
| 385 | |
---|
| 386 | |
---|
| 387 | /* multiply X^T with the residual vector */ |
---|
| 388 | x_elem = x; |
---|
| 389 | for(j=0; j < m; j++){ |
---|
| 390 | tmp = 0.0; |
---|
| 391 | for(i=0; i<n; i++){ |
---|
| 392 | tmp += *x_elem*r[i]; |
---|
| 393 | x_elem++; |
---|
| 394 | } |
---|
| 395 | xtr[j] = tmp; |
---|
| 396 | } |
---|
| 397 | max_val = fabs(xtr[0]); |
---|
| 398 | max_ind = 0; |
---|
| 399 | for(j=1; j<m; j++) |
---|
| 400 | if( fabs(xtr[j]) > max_val ){ |
---|
| 401 | max_val = fabs(xtr[j]); |
---|
| 402 | max_ind = j; |
---|
| 403 | } |
---|
| 404 | |
---|
| 405 | if(verb){ |
---|
| 406 | |
---|
| 407 | |
---|
| 408 | /* Find out how many times [[max_val]] is attained */ |
---|
| 409 | tmp = (1.0-prec)*max_val; |
---|
| 410 | if( tmp < prec ) tmp = 0.0; |
---|
| 411 | max_num = 1; |
---|
| 412 | for(j=0; j<m; j++){ |
---|
| 413 | if( tmp <= fabs(xtr[j]) && j!=max_ind){ |
---|
| 414 | /*we found another element equal to the (current) |
---|
| 415 | maximal absolute value */ |
---|
| 416 | max_num++; |
---|
| 417 | } |
---|
| 418 | } |
---|
| 419 | |
---|
| 420 | /* for the value of the dual objective function we need |
---|
| 421 | to calculate the L2 norm of the vector of fitted values */ |
---|
| 422 | p_obj = 0.0; |
---|
| 423 | d_obj = 0.0; |
---|
| 424 | for(i=0;i<n;i++){ |
---|
| 425 | p_obj += r[i]*r[i]; |
---|
| 426 | d_obj += yhat1[i]*yhat1[i]; |
---|
| 427 | } |
---|
| 428 | p_obj /= 2.0; |
---|
| 429 | d_obj = ytyd2 - d_obj/2.0 - t*max_val ; |
---|
| 430 | mexPrintf("******************************\n"); |
---|
| 431 | mexPrintf("\nIteration number: %d\n", num_iter); |
---|
| 432 | mexPrintf("Value of primal object function : %f\n", p_obj); |
---|
| 433 | mexPrintf("Value of dual object function : %f\n", d_obj); |
---|
| 434 | mexPrintf("L1 norm of current beta : %f", b_1norm); |
---|
| 435 | mexPrintf(" <= %f\n", t); |
---|
| 436 | mexPrintf("Maximal absolute value in t(X)%%*%%r : %e", max_val); |
---|
| 437 | mexPrintf(" attained %d time(s)\n", max_num); |
---|
| 438 | mexPrintf("Number of parameters allowed to vary : %d\n", num_nz_x); |
---|
| 439 | mexEvalString("drawnow;"); |
---|
| 440 | num_iter++; |
---|
| 441 | } |
---|
| 442 | |
---|
| 443 | add = TRUE; |
---|
| 444 | for(i=0; i<num_nz_x; i++) |
---|
| 445 | if(nz_x[i]==max_ind){ |
---|
| 446 | add = FALSE; |
---|
| 447 | break; |
---|
| 448 | } |
---|
| 449 | if(add){ |
---|
| 450 | qr_add(x+max_ind*n,TRUE); |
---|
| 451 | nz_x[num_nz_x] = max_ind; |
---|
| 452 | theta[num_nz_x] = xtr[max_ind]<0 ? -1 : 1; |
---|
| 453 | num_nz_x++; |
---|
| 454 | }else{ |
---|
| 455 | *lagrangian = max_val; |
---|
| 456 | break; |
---|
| 457 | } |
---|
| 458 | |
---|
| 459 | b_1norm = 0.0; |
---|
| 460 | for(j=0;j<num_nz_x;j++) |
---|
| 461 | b_1norm += fabs(beta[nz_x[j]]); |
---|
| 462 | if(verb) { |
---|
| 463 | mexPrintf(" -->\tAdding variable: %d\n",max_ind+1); |
---|
| 464 | mexEvalString("drawnow;"); |
---|
| 465 | } |
---|
| 466 | } |
---|
| 467 | EXIT_HERE: |
---|
| 468 | if( !as_sub) |
---|
| 469 | lasso_free(); |
---|
| 470 | } |
---|
| 471 | void mult_lasso(double *x, long *pn, long *pm, double *pt, long *pl, |
---|
| 472 | double *beta, double *y, double *yhat1, double *r, |
---|
| 473 | double *lagrangian, long *psuc, long *pverb) |
---|
| 474 | { |
---|
| 475 | |
---|
| 476 | double prec; |
---|
| 477 | long n = *pn, m = *pm, l = *pl, verb = *pverb, as_sub = TRUE, i, j; |
---|
| 478 | lasso_alloc(n,m); |
---|
| 479 | |
---|
| 480 | qr_add(y,TRUE); |
---|
| 481 | ytyd2 = *rmat * *rmat/2.0; |
---|
| 482 | prec = sqrt(PRECISION); |
---|
| 483 | num_nz_x = 0; |
---|
| 484 | for(j=0; j<m; j++){ |
---|
| 485 | if(fabs(beta[j]) > prec){ |
---|
| 486 | qr_add(x+j*n, TRUE); |
---|
| 487 | nz_x[num_nz_x] = j; |
---|
| 488 | num_nz_x++; |
---|
| 489 | }else |
---|
| 490 | beta[j] = 0.0; |
---|
| 491 | } |
---|
| 492 | *psuc = 0; |
---|
| 493 | for(i=0; i<l; i++){ |
---|
| 494 | |
---|
| 495 | if(verb){ |
---|
| 496 | mexPrintf("\n\n++++++++++++++++++++++++++++++\n"); |
---|
| 497 | mexPrintf("Solving problem number %ld with bound %f\n", i+1, pt[i]); |
---|
| 498 | mexPrintf("++++++++++++++++++++++++++++++\n"); |
---|
| 499 | mexEvalString("drawnow;"); |
---|
| 500 | } |
---|
| 501 | if(i>0) Memcpy(beta,beta-m,m); |
---|
| 502 | lasso(x, pn, pm, pt+i, beta, y, yhat1, r, lagrangian, psuc, pverb, &as_sub); |
---|
| 503 | if( *psuc < 0 ){ |
---|
| 504 | goto EXIT_HERE; |
---|
| 505 | } |
---|
| 506 | |
---|
| 507 | beta += m; |
---|
| 508 | yhat1 += n; |
---|
| 509 | r += n; |
---|
| 510 | lagrangian++; |
---|
| 511 | } |
---|
| 512 | EXIT_HERE: |
---|
| 513 | lasso_free(); |
---|
| 514 | } |
---|
| 515 | static void lasso_alloc(long n, long m){ |
---|
| 516 | |
---|
| 517 | #if defined(S_Plus) |
---|
| 518 | if( nz_x != NULL || theta != NULL || xtr != NULL || btmp != NULL || |
---|
| 519 | qtr != NULL || rinvt_theta != NULL || step != NULL || |
---|
| 520 | num_nz_x != 0 || ytyd2 != 0.0){ |
---|
| 521 | MESSAGE "Possible memory corruption or memory leak.\n We" |
---|
| 522 | "advise to restart your S+ session" WARNING(NULL_ENTRY); |
---|
| 523 | lasso_free(); |
---|
| 524 | } |
---|
| 525 | #endif |
---|
| 526 | |
---|
| 527 | QR_CHUNK = m; // Added: CJV |
---|
| 528 | |
---|
| 529 | nz_x = Calloc(m,int); |
---|
| 530 | if( nz_x == NULL ) |
---|
| 531 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
| 532 | theta = Calloc(m,int); |
---|
| 533 | if( theta==NULL ) |
---|
| 534 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
| 535 | xtr = Calloc(m,double); |
---|
| 536 | if( xtr==NULL ) |
---|
| 537 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
| 538 | btmp = Calloc(m,double); |
---|
| 539 | if( btmp==NULL ) |
---|
| 540 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
| 541 | qtr = Calloc(m,double); |
---|
| 542 | if( qtr==NULL ) |
---|
| 543 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
| 544 | rinvt_theta = Calloc(m,double); |
---|
| 545 | if( rinvt_theta==NULL ) |
---|
| 546 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
| 547 | step = Calloc(m,double); |
---|
| 548 | if( step==NULL ) |
---|
| 549 | ERRMSG("lasso_alloc", no_dyn_mem_message); |
---|
| 550 | qr_init(n); |
---|
| 551 | } |
---|
| 552 | static void lasso_free(void){ |
---|
| 553 | num_nz_x=0; |
---|
| 554 | ytyd2 = 0.0; |
---|
| 555 | Free(nz_x); |
---|
| 556 | Free(theta); |
---|
| 557 | Free(xtr); |
---|
| 558 | Free(btmp); |
---|
| 559 | Free(qtr); |
---|
| 560 | Free(rinvt_theta); |
---|
| 561 | Free(step); |
---|
| 562 | qr_free(); |
---|
| 563 | } |
---|
| 564 | static void qr_init(int n) { |
---|
| 565 | #if defined (S_Plus) |
---|
| 566 | if(qr_max_size!=0 || r_ncol!=0 || q_nrow!=0 || q_use_row!=0 || |
---|
| 567 | qmat!=NULL || rmat!=NULL){ |
---|
| 568 | MESSAGE "Possible memory corruption or memory leak.\n We" |
---|
| 569 | "advise to restart your S+ session" WARNING(NULL_ENTRY); |
---|
| 570 | qr_free(); |
---|
| 571 | } |
---|
| 572 | #endif |
---|
| 573 | qr_max_size = QR_CHUNK; |
---|
| 574 | r_ncol = 0; |
---|
| 575 | q_nrow = n; |
---|
| 576 | qmat = Calloc(n*qr_max_size,double); |
---|
| 577 | if(qmat==NULL) |
---|
| 578 | ERRMSG("qr_init", no_dyn_mem_message); |
---|
| 579 | rmat = Calloc(qr_max_size*(qr_max_size+1)/2,double); |
---|
| 580 | if(rmat==NULL) |
---|
| 581 | ERRMSG("qr_init", no_dyn_mem_message); |
---|
| 582 | } |
---|
| 583 | static void qr_incr(void) { |
---|
| 584 | qr_max_size += QR_CHUNK; |
---|
| 585 | |
---|
| 586 | /* reallocate R always */ |
---|
| 587 | rmat = Realloc(rmat,qr_max_size*(qr_max_size+1)/2,double); |
---|
| 588 | if(rmat==NULL) |
---|
| 589 | ERRMSG("qr_incr", no_dyn_mem_message); |
---|
| 590 | |
---|
| 591 | /* reallocate Q only if necessary and only to maximal necessary size */ |
---|
| 592 | if( qr_max_size >= q_nrow){ |
---|
| 593 | if( qr_max_size-QR_CHUNK < q_nrow){ |
---|
| 594 | qmat = Realloc(qmat,q_nrow*q_nrow,double); |
---|
| 595 | if(qmat==NULL) |
---|
| 596 | ERRMSG("qr_incr", no_dyn_mem_message); |
---|
| 597 | } |
---|
| 598 | }else{ |
---|
| 599 | qmat = Realloc(qmat,q_nrow*qr_max_size,double); |
---|
| 600 | if(qmat==NULL) |
---|
| 601 | ERRMSG("qr_incr", no_dyn_mem_message); |
---|
| 602 | } |
---|
| 603 | } |
---|
| 604 | static void qr_free(void) |
---|
| 605 | { |
---|
| 606 | qr_max_size = 0; |
---|
| 607 | r_ncol = 0; |
---|
| 608 | q_nrow = 0; |
---|
| 609 | q_use_row = 0; |
---|
| 610 | Free(qmat); |
---|
| 611 | Free(rmat); |
---|
| 612 | } |
---|
| 613 | static void qr_del(int l, int aug) { |
---|
| 614 | |
---|
| 615 | double c, s, tau, nu, *col_k, *col_kp1, *a, *b, tmp; |
---|
| 616 | int i, j, k, l0; |
---|
| 617 | |
---|
| 618 | if( l<0 || l>=r_ncol ) |
---|
| 619 | ERRMSG("qr_del", "Invalid column number"); |
---|
| 620 | r_ncol--; |
---|
| 621 | if( l==r_ncol) |
---|
| 622 | if( aug ) |
---|
| 623 | ERRMSG("qr_del", "Trying to delete last column of augmented matrix"); |
---|
| 624 | else |
---|
| 625 | return; |
---|
| 626 | |
---|
| 627 | /* this is TRUE if $m\le n$ ($m-1\le n$ if the matrix is augmented) */ |
---|
| 628 | if (r_ncol < q_nrow ){ |
---|
| 629 | for(k=l; k<r_ncol; k++) /* Update the factorisation and be done */ |
---|
| 630 | { |
---|
| 631 | |
---|
| 632 | col_k = RCOL(k); /* first element in column $k$ in R */ |
---|
| 633 | Memcpy(col_k,col_k+k+1,k+1); |
---|
| 634 | a = col_k+k; |
---|
| 635 | b = a+k+2; |
---|
| 636 | tau = fabs(*a)+fabs(*b); |
---|
| 637 | if( tau == 0.0 ) continue; /* both elements are zero |
---|
| 638 | nothing to update */ |
---|
| 639 | nu = tau*sqrt((*a/tau)*(*a/tau)+(*b/tau)*(*b/tau)); |
---|
| 640 | c = *a/nu; |
---|
| 641 | s = *b/nu; |
---|
| 642 | *a = nu; |
---|
| 643 | b += k+2; |
---|
| 644 | a = b-1; |
---|
| 645 | for(j=k+2;j<=r_ncol;j++, a+=j, b+=j){ |
---|
| 646 | tmp = c * *a + s * *b; |
---|
| 647 | *b = c * *b - s * *a; |
---|
| 648 | *a = tmp; |
---|
| 649 | } |
---|
| 650 | col_k = QCOL(k); |
---|
| 651 | col_kp1 = col_k+q_nrow; |
---|
| 652 | for(j=0;j<q_nrow;j++){ |
---|
| 653 | tmp = c*col_k[j]+s*col_kp1[j]; |
---|
| 654 | col_kp1[j] = c*col_kp1[j]-s*col_k[j]; |
---|
| 655 | col_k[j] = tmp; |
---|
| 656 | } |
---|
| 657 | } |
---|
| 658 | }else |
---|
| 659 | if( l < q_nrow ){ |
---|
| 660 | /* Update columns upto $m$ and than shift remaining columns */ |
---|
| 661 | for(k=l; k<q_nrow-1; k++){ |
---|
| 662 | |
---|
| 663 | col_k = RCOL(k); /* first element in column $k$ in R */ |
---|
| 664 | Memcpy(col_k,col_k+k+1,k+1); |
---|
| 665 | a = col_k+k; |
---|
| 666 | b = a+k+2; |
---|
| 667 | tau = fabs(*a)+fabs(*b); |
---|
| 668 | if( tau == 0.0 ) continue; /* both elements are zero |
---|
| 669 | nothing to update */ |
---|
| 670 | nu = tau*sqrt((*a/tau)*(*a/tau)+(*b/tau)*(*b/tau)); |
---|
| 671 | c = *a/nu; |
---|
| 672 | s = *b/nu; |
---|
| 673 | *a = nu; |
---|
| 674 | b += k+2; |
---|
| 675 | a = b-1; |
---|
| 676 | for(j=k+2;j<=r_ncol;j++, a+=j, b+=j){ |
---|
| 677 | tmp = c * *a + s * *b; |
---|
| 678 | *b = c * *b - s * *a; |
---|
| 679 | *a = tmp; |
---|
| 680 | } |
---|
| 681 | col_k = QCOL(k); |
---|
| 682 | col_kp1 = col_k+q_nrow; |
---|
| 683 | for(j=0;j<q_nrow;j++){ |
---|
| 684 | tmp = c*col_k[j]+s*col_kp1[j]; |
---|
| 685 | col_kp1[j] = c*col_kp1[j]-s*col_k[j]; |
---|
| 686 | col_k[j] = tmp; |
---|
| 687 | } |
---|
| 688 | } |
---|
| 689 | l0 = q_nrow-1; |
---|
| 690 | |
---|
| 691 | col_k = RCOL(l0); /* first element in column $l_0$ in R */ |
---|
| 692 | for(i=l0; i<r_ncol; i++, col_k +=i) |
---|
| 693 | Memcpy(col_k,col_k+i+1,q_nrow); |
---|
| 694 | |
---|
| 695 | /* [[rmat + (q_nrow-1)*q_nrow/2+q_nrow-1]] is last element in column |
---|
| 696 | [[q_nrow]] in R */ |
---|
| 697 | a = rmat + (q_nrow-1)*(q_nrow+2)/2; |
---|
| 698 | if( *a < 0 ){ |
---|
| 699 | /* [[j<r_ncol]] sufficient since we shifted the columns already */ |
---|
| 700 | for(j=l0;j<r_ncol;j++,a+=j) |
---|
| 701 | *a = - *a; |
---|
| 702 | col_k = qmat+(q_nrow-1)*q_nrow; |
---|
| 703 | for(i=0;i<q_nrow;i++,col_k++) |
---|
| 704 | *col_k = -*col_k; |
---|
| 705 | } |
---|
| 706 | }else{ /* just shift last columns to the left */ |
---|
| 707 | l0 = l; |
---|
| 708 | |
---|
| 709 | col_k = RCOL(l0); /* first element in column $l_0$ in R */ |
---|
| 710 | for(i=l0; i<r_ncol; i++, col_k +=i) |
---|
| 711 | Memcpy(col_k,col_k+i+1,q_nrow); |
---|
| 712 | |
---|
| 713 | /* [[rmat + (q_nrow-1)*q_nrow/2+q_nrow-1]] is last element in column |
---|
| 714 | [[q_nrow]] in R */ |
---|
| 715 | a = rmat + (q_nrow-1)*(q_nrow+2)/2; |
---|
| 716 | if( *a < 0 ){ |
---|
| 717 | /* [[j<r_ncol]] sufficient since we shifted the columns already */ |
---|
| 718 | for(j=l0;j<r_ncol;j++,a+=j) |
---|
| 719 | *a = - *a; |
---|
| 720 | col_k = qmat+(q_nrow-1)*q_nrow; |
---|
| 721 | for(i=0;i<q_nrow;i++,col_k++) |
---|
| 722 | *col_k = -*col_k; |
---|
| 723 | } |
---|
| 724 | } |
---|
| 725 | } |
---|
| 726 | static void qr_add(double *x, int swap){ |
---|
| 727 | |
---|
| 728 | double tmp, norm_orig, norm_init, norm_last, *q_new_col; |
---|
| 729 | int i; |
---|
| 730 | double *col_l, *col_lm1, *q_elem, *r_new_col; |
---|
| 731 | double norm_new; |
---|
| 732 | int j; |
---|
| 733 | double c, s, tau, nu, *a, *b; |
---|
| 734 | |
---|
| 735 | if( r_ncol == qr_max_size ) |
---|
| 736 | qr_incr(); |
---|
| 737 | |
---|
| 738 | norm_orig = 0.0; |
---|
| 739 | tmp = 0.0; |
---|
| 740 | for(i=0; i<q_nrow; i++) tmp += x[i]*x[i]; |
---|
| 741 | norm_orig = sqrt(tmp); |
---|
| 742 | if(r_ncol<q_nrow){ |
---|
| 743 | q_new_col = QCOL(r_ncol); /* points to the first element of new column */ |
---|
| 744 | tmp = 0.0; |
---|
| 745 | for(i=0; i<q_nrow; i++){ |
---|
| 746 | q_new_col[i] = x[i]/norm_orig; |
---|
| 747 | tmp += q_new_col[i]*q_new_col[i]; |
---|
| 748 | } |
---|
| 749 | if(r_ncol==0){ |
---|
| 750 | *rmat = norm_orig; |
---|
| 751 | r_ncol++; |
---|
| 752 | return; |
---|
| 753 | } |
---|
| 754 | norm_init = norm_last = sqrt(tmp); |
---|
| 755 | } |
---|
| 756 | r_new_col = RCOL(r_ncol); |
---|
| 757 | for(i=0; i<=r_ncol; i++) r_new_col[i] = 0.0; |
---|
| 758 | if( r_ncol >= q_nrow ){ |
---|
| 759 | /* Multiply new column with Q-transpose in [[qr_add]] */ |
---|
| 760 | q_elem = qmat; |
---|
| 761 | for(i=0;i<q_nrow;i++){ |
---|
| 762 | tmp = 0.0; |
---|
| 763 | for(j=0;j<q_nrow;j++,q_elem++) |
---|
| 764 | tmp += *q_elem * x[j]; |
---|
| 765 | r_new_col[i] = tmp; |
---|
| 766 | } |
---|
| 767 | if( swap ){ /* swap last two columns */ |
---|
| 768 | col_l = r_new_col; |
---|
| 769 | col_lm1 = r_new_col-r_ncol; |
---|
| 770 | for(i=0; i<q_nrow; i++,col_l++,col_lm1++){ |
---|
| 771 | tmp = *col_l; |
---|
| 772 | *col_l = *col_lm1; |
---|
| 773 | *col_lm1 = tmp; |
---|
| 774 | } |
---|
| 775 | col_lm1--; |
---|
| 776 | col_l--; |
---|
| 777 | if(q_nrow==r_ncol && *col_lm1<0){ |
---|
| 778 | /* if new R[$n$,$n$] is negative, than change signs */ |
---|
| 779 | *col_lm1 = - *col_lm1; |
---|
| 780 | *col_l = - *col_l; |
---|
| 781 | q_new_col = qmat+(q_nrow-1)*q_nrow; |
---|
| 782 | for(j=0;j<q_nrow;j++) |
---|
| 783 | q_new_col[j] = -q_new_col[j]; |
---|
| 784 | } |
---|
| 785 | } |
---|
| 786 | r_ncol++; |
---|
| 787 | return; |
---|
| 788 | } |
---|
| 789 | while(1){ |
---|
| 790 | |
---|
| 791 | q_elem = qmat; |
---|
| 792 | for(j=0; j<r_ncol;j++){ |
---|
| 793 | tmp = 0.0; |
---|
| 794 | for(i=0;i<q_nrow;i++,q_elem++) |
---|
| 795 | tmp += *q_elem * q_new_col[i]; |
---|
| 796 | r_new_col[j] += tmp; |
---|
| 797 | q_elem -= q_nrow; |
---|
| 798 | for(i=0;i<q_nrow;i++,q_elem++) |
---|
| 799 | q_new_col[i] -= *q_elem*tmp; |
---|
| 800 | } |
---|
| 801 | tmp = 0.0; |
---|
| 802 | for(i=0;i<q_nrow;i++) tmp += q_new_col[i]*q_new_col[i]; |
---|
| 803 | norm_new = sqrt(tmp); |
---|
| 804 | |
---|
| 805 | if( norm_new >= norm_last/2.0) break; |
---|
| 806 | if( norm_new > 0.1*norm_init*PRECISION ) |
---|
| 807 | norm_last = norm_new; |
---|
| 808 | else{ |
---|
| 809 | norm_init = norm_last = 0.1*norm_last*PRECISION; |
---|
| 810 | for(i=0;i<q_nrow;i++) |
---|
| 811 | q_new_col[i] = 0.0; |
---|
| 812 | if(q_use_row == q_nrow) |
---|
| 813 | ERRMSG("qr_add","Cannot orthogonalise new column\n"); |
---|
| 814 | |
---|
| 815 | q_new_col[q_use_row] = norm_new = norm_last; |
---|
| 816 | q_use_row++; |
---|
| 817 | } |
---|
| 818 | } |
---|
| 819 | |
---|
| 820 | for(i=0;i<q_nrow;i++) |
---|
| 821 | q_new_col[i] /= norm_new; |
---|
| 822 | for(i=0;i<r_ncol;i++) |
---|
| 823 | r_new_col[i] *= norm_orig; |
---|
| 824 | r_new_col[r_ncol] = norm_new*norm_orig; |
---|
| 825 | if(swap){ |
---|
| 826 | |
---|
| 827 | col_l = r_new_col; /* first element in last column in R */ |
---|
| 828 | col_lm1 = r_new_col-r_ncol; /* first element in column before last in R */ |
---|
| 829 | for(j=0;j<r_ncol;j++,col_l++,col_lm1++){ |
---|
| 830 | tmp = *col_l; |
---|
| 831 | *col_l = *col_lm1; |
---|
| 832 | *col_lm1 = tmp; |
---|
| 833 | } |
---|
| 834 | a = col_lm1-1; |
---|
| 835 | b = col_l; |
---|
| 836 | tau = fabs(*a)+fabs(*b); |
---|
| 837 | if( tau > 0.0 ){ |
---|
| 838 | /*Calculate Givens rotation*/ |
---|
| 839 | nu = tau*sqrt((*a/tau)*(*a/tau)+(*b/tau)*(*b/tau)); |
---|
| 840 | c = *a/nu; |
---|
| 841 | s = *b/nu; |
---|
| 842 | |
---|
| 843 | *a = nu; /* Fix column before last in R */ |
---|
| 844 | *b = -s* *(b-1); /* Fix last element of last column in R */ |
---|
| 845 | *(b-1) = c * *(b-1); /* Fix second last element of last column in R */ |
---|
| 846 | if( *b < 0){ |
---|
| 847 | *b = -*b; |
---|
| 848 | |
---|
| 849 | col_lm1 = QCOL(r_ncol-1); |
---|
| 850 | col_l = col_lm1+q_nrow; |
---|
| 851 | for(j=0;j<q_nrow;j++){ |
---|
| 852 | tmp = c*col_lm1[j]+s*col_l[j]; |
---|
| 853 | col_l[j] = -c*col_l[j]+s*col_lm1[j]; |
---|
| 854 | col_lm1[j] = tmp; |
---|
| 855 | } |
---|
| 856 | |
---|
| 857 | }else{ |
---|
| 858 | |
---|
| 859 | col_lm1 = QCOL(r_ncol-1); |
---|
| 860 | col_l = col_lm1+q_nrow; |
---|
| 861 | for(j=0;j<q_nrow;j++){ |
---|
| 862 | tmp = c*col_lm1[j]+s*col_l[j]; |
---|
| 863 | col_l[j] = c*col_l[j]-s*col_lm1[j]; |
---|
| 864 | col_lm1[j] = tmp; |
---|
| 865 | } |
---|
| 866 | } |
---|
| 867 | } |
---|
| 868 | |
---|
| 869 | } |
---|
| 870 | r_ncol++; |
---|
| 871 | |
---|
| 872 | } |
---|
/*
 * errmsg -- report a fatal error through the host environment.
 *
 * Exactly one of three variants is compiled:
 *   S_Plus  : takes only the message and hands it to the PROBLEM/RECOVER
 *             macros (defined outside this file; presumably they raise an
 *             S-PLUS error -- confirm against the S-PLUS headers).
 *   Matlab  : formats "Error in <where>: <string>" and passes it to
 *             mexErrMsgTxt().
 *   default : prints the same text to stderr and terminates the process
 *             with EXIT_FAILURE.
 *
 * where   name of the routine reporting the problem (not used for S_Plus)
 * string  description of the problem
 */
#if defined (S_Plus)
static void errmsg(char *string){
  PROBLEM "%s\n", string RECOVER(NULL_ENTRY);
}
#elif defined(Matlab)
static void errmsg(char *where, char *string){
  char str[1024];
  /* Bounded formatting: an over-long where/string is truncated instead of
     overrunning the stack buffer. */
  snprintf(str, sizeof str, "Error in %s: %s\n", where, string);
  mexErrMsgTxt(str);
}
#else
static void errmsg(char *where, char *string){
  fprintf(stderr, "Error in %s: %s\n", where, string);
  exit(EXIT_FAILURE);
}
#endif
---|
| 889 | |
---|
| 890 | int assertFail(const char *ex, const char *file, const char *func, const int line) |
---|
| 891 | { |
---|
| 892 | static bool ignore = false; |
---|
| 893 | |
---|
| 894 | if (!ignore) { |
---|
| 895 | char str[1024]; |
---|
| 896 | sprintf(str, "%s:%u: %s(): failed assertion `%s'\n", |
---|
| 897 | file, line, func, ex); |
---|
| 898 | mexErrMsgTxt(str); |
---|
| 899 | } /* if */ |
---|
| 900 | return(1); |
---|
| 901 | } |
---|
| 902 | |
---|
| 903 | static void fortifyPrintf(const char *s) |
---|
| 904 | { |
---|
| 905 | mexPrintf(s); |
---|
| 906 | } |
---|
| 907 | |
---|
| 908 | void |
---|
| 909 | mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[]) |
---|
| 910 | { |
---|
| 911 | mxArray *array_ptr; |
---|
| 912 | int i; |
---|
| 913 | long dataSize; |
---|
| 914 | long dimensions; |
---|
| 915 | long succeed; |
---|
| 916 | long verbose = 0; |
---|
| 917 | long as_sub = 0; |
---|
| 918 | double lagrangian; |
---|
| 919 | double threshold; |
---|
| 920 | double *mx; |
---|
| 921 | double *y; |
---|
| 922 | double *yhat1; |
---|
| 923 | double *r; |
---|
| 924 | double *beta; |
---|
| 925 | int arg; |
---|
| 926 | |
---|
| 927 | #ifdef FORTIFY |
---|
| 928 | Fortify_SetOutputFunc(fortifyPrintf); |
---|
| 929 | Fortify_EnterScope(); |
---|
| 930 | #endif |
---|
| 931 | |
---|
| 932 | if (nrhs > 3) { |
---|
| 933 | mexErrMsgTxt("Syntax error: too many arguments\n"); |
---|
| 934 | return; |
---|
| 935 | } /* if */ |
---|
| 936 | |
---|
| 937 | if (nrhs < 3) { |
---|
| 938 | mexErrMsgTxt("Syntax error: arguments missing\n"); |
---|
| 939 | return; |
---|
| 940 | } /* if */ |
---|
| 941 | |
---|
| 942 | if (nlhs > 1) { |
---|
| 943 | mexErrMsgTxt("Syntax error: too many return arguments\n"); |
---|
| 944 | return; |
---|
| 945 | } /* if */ |
---|
| 946 | |
---|
| 947 | arg = 2; |
---|
| 948 | array_ptr = (mxArray *) prhs[arg]; |
---|
| 949 | if (mxGetClassID(array_ptr) != mxDOUBLE_CLASS) { |
---|
| 950 | mexErrMsgTxt("Threshold argument is not a number\n"); |
---|
| 951 | return; |
---|
| 952 | } /* if */ |
---|
| 953 | |
---|
| 954 | threshold = mxGetScalar(array_ptr); |
---|
| 955 | |
---|
| 956 | arg = 1; |
---|
| 957 | array_ptr = (mxArray *) prhs[arg]; |
---|
| 958 | if (mxGetClassID(array_ptr) != mxDOUBLE_CLASS) { |
---|
| 959 | mexErrMsgTxt("Label argument is not a number\n"); |
---|
| 960 | return; |
---|
| 961 | } /* if */ |
---|
| 962 | |
---|
| 963 | y = mxGetPr(array_ptr); |
---|
| 964 | |
---|
| 965 | arg = 0; |
---|
| 966 | array_ptr = (mxArray *) prhs[arg]; |
---|
| 967 | if (mxGetClassID(array_ptr) != mxDOUBLE_CLASS) { |
---|
| 968 | mexErrMsgTxt("Data set argument is not a double matrix\n"); |
---|
| 969 | return; |
---|
| 970 | } /* if */ |
---|
| 971 | |
---|
| 972 | dataSize = mxGetM(array_ptr); |
---|
| 973 | dimensions = mxGetN(array_ptr); |
---|
| 974 | |
---|
| 975 | mx = mxGetPr(array_ptr); |
---|
| 976 | |
---|
| 977 | arg = 0; |
---|
| 978 | plhs[arg] = mxCreateDoubleMatrix(dimensions, 1, mxREAL); |
---|
| 979 | |
---|
| 980 | yhat1 = new double[dataSize]; |
---|
| 981 | r = new double[dataSize]; |
---|
| 982 | beta = mxGetPr(plhs[arg]); |
---|
| 983 | |
---|
| 984 | for (i=0; i<dimensions; i++) { |
---|
| 985 | beta[i] = 0.0; |
---|
| 986 | } /* for */ |
---|
| 987 | |
---|
| 988 | if (dimensions > 0) { |
---|
| 989 | lasso(mx, &dataSize, &dimensions, &threshold, beta, y, |
---|
| 990 | yhat1, r, &lagrangian, &succeed, &verbose, &as_sub); |
---|
| 991 | } else { |
---|
| 992 | mexPrintf("Dataset with 0 dimensions\n"); |
---|
| 993 | } /* if */ |
---|
| 994 | |
---|
| 995 | delete [] yhat1; |
---|
| 996 | delete [] r; |
---|
| 997 | |
---|
| 998 | #ifdef FORTIFY |
---|
| 999 | Fortify_LeaveScope(); |
---|
| 1000 | #endif |
---|
| 1001 | |
---|
| 1002 | } |
---|
| 1003 | |
---|