source: prextra/wloglc2.m @ 128

Last change on this file since 128 was 87, checked in by dtax, 11 years ago

Weighted logistic

File size: 5.1 KB
RevLine 
%WLOGLC2 Weighted Logistic Linear Classifier
%
%   W = WLOGLC2(A, L, V)
%   [W, LL] = WLOGLC2(A, L, V)
%
% INPUT
%   A   Dataset
%   L   Regularization parameter (L2), default 0
%   V   Object weights: one non-negative value per object in A
%       (default: all ones, i.e. unweighted)
%
% OUTPUT
%   W   Weighted Logistic linear classifier
%   LL  Not computed; always returned empty (kept for interface compatibility)
%
% DESCRIPTION 
% Computation of the linear classifier for the dataset A by maximizing the
% weighted, L2-regularized likelihood criterion using the logistic (softmax)
% function. Object i contributes V(i) times its log-likelihood to the
% criterion. The weights are rescaled internally to sum to the number of
% objects, so only their relative sizes matter. The penalty L*||x||^2 is
% applied to all parameters, bias terms included. The default value for L
% is 0. Training requires the minFunc optimizer on the MATLAB path;
% evaluating an already trained classifier does not.
%
%  SEE ALSO
%  MAPPINGS, DATASETS, LDC, FISHERC

function [b, ll] = wloglc2(a, lambda, v)

name = 'Weighted Logistic2';
% The second output is never computed. Assign it up front so that requesting
% it does not fail in any branch.
ll = [];

if nargin < 3
   v = [];
end
if nargin < 2 || isempty(lambda)
   lambda = 0;
end
if nargin == 0 || isempty(a)
   % No data: return the untrained mapping.
   b = prmapping(mfilename, {lambda, v});
   b = setname(b, name);
   return;
end

if ~ismapping(lambda)
   % Training. Only this branch needs the external optimizer.
   if ~exist('minFunc', 'file')
      error('WLOGLC2 requires the minFunc optimizer. Please download it from www.di.ens.fr/~mschmidt/Software/minFunc.html and add it to the Matlab path.');
   end
   islabtype(a, 'crisp');
   isvaldfile(a, 1, 2);
   a = testdatasize(a, 'features');
   a = setprior(a, getprior(a));
   [n, k, c] = getsize(a);

   % Default to uniform weights, validate, and rescale to sum to n.
   v = wloglc2_weights(v, n);

   % Train the logistic regressor (objects as columns).
   [data.E, data.E_bias] = train_logreg(+a', getnlab(a)', lambda, v);
   b = prmapping(mfilename, 'trained', data, getlablist(a), k, c);
   b = setname(b, name);

else
   % Evaluation: here the second argument holds the trained mapping.
   W = getdata(lambda);
   [~, test_post] = eval_logreg(+a', W.E, W.E_bias);
   a = prdataset(a);
   b = setdata(a, test_post', getlabels(lambda));
end

end

function v = wloglc2_weights(v, n)
% Return the object weights as an n x 1 column rescaled to sum to n.
% An empty V means uniform weights. Raises an error on a length mismatch,
% negative or NaN entries, or a non-positive total.
   if isempty(v)
      v = ones(n, 1);
   end
   v = v(:);
   if numel(v) ~= n
      error('wloglc2:weights', ...
         'The number of weights (%d) must equal the number of objects (%d).', ...
         numel(v), n);
   end
   if any(v < 0) || ~(sum(v) > 0)
      error('wloglc2:weights', ...
         'Weights must be non-negative and have a positive sum.');
   end
   v = n * v ./ sum(v);
end
71
function [E, E_bias] = train_logreg(train_X, train_labels, lambda, v, E_init, E_bias_init)
%TRAIN_LOGREG Fit a weighted, L2-regularized multiclass logistic regressor.
%
%   [E, E_BIAS] = TRAIN_LOGREG(TRAIN_X, TRAIN_LABELS, LAMBDA, V, E_INIT, E_BIAS_INIT)
%
% TRAIN_X is D x N with one object per column. TRAIN_LABELS holds one label
% per column and is remapped to consecutive indices 1..K with UNIQUE. V holds
% the per-object weights. E_INIT (D x K) and E_BIAS_INIT (1 x K) are optional
% starting points. When omitted or empty, E starts at small random values and
% E_BIAS at zero. The stacked parameters [E(:); E_BIAS(:)] are optimized with
% minFunc (L-BFGS) on the objective implemented by LOGREG_GRAD.

   nDims = size(train_X, 1);
   [classList, ~, classIdx] = unique(train_labels);
   nClasses = length(classList);

   % Starting point (E is drawn before E_bias is set).
   if nargin >= 5 && ~isempty(E_init)
       E = E_init;
   else
       E = randn(nDims, nClasses) * .0001;
   end
   if nargin >= 6 && ~isempty(E_bias_init)
       E_bias = E_bias_init;
   else
       E_bias = zeros(1, nClasses);
   end

   % The label-dependent ("positive") part of the gradient does not change
   % during optimization. Compute it once and hand it to LOGREG_GRAD.
   weightedX = bsxfun(@times, v', train_X);
   classSums = zeros(nDims, nClasses);
   classWeights = zeros(1, nClasses);
   for c = 1:nClasses
       inClass = (classIdx == c);
       classSums(:, c) = sum(weightedX(:, inClass), 2);
       classWeights(c) = sum(v(inClass));
   end

   % L-BFGS, silent (no per-iteration output).
   opts = struct('Method', 'lbfgs', 'Display', 'off', ...
                 'TolFun', 1e-4, 'TolX', 1e-4, 'MaxIter', 5000);
   theta = minFunc(@logreg_grad, [E(:); E_bias(:)], opts, ...
                   train_X, classIdx, lambda, v, classSums, classWeights);

   % Unstack the solution.
   E = reshape(theta(1:nDims * nClasses), nDims, nClasses);
   E_bias = reshape(theta(nDims * nClasses + 1:end), 1, nClasses);
end
115
116
function [est_labels, posterior] = eval_logreg(test_X, E, E_bias)
%EVAL_LOGREG Apply a multiclass logistic regressor.
%
%   [EST_LABELS, POSTERIOR] = EVAL_LOGREG(TEST_X, E, E_BIAS)
%
% TEST_X is either a D x N matrix (one object per column) or a cell array
% with one cell per object. In the cell form, each object is itself a cell
% of row indices into E, and the selected rows of E are summed. E is D x K
% and E_BIAS is 1 x K.
% EST_LABELS (1 x N) holds the index of the highest-scoring class per object.
% POSTERIOR (K x N) holds the softmax of the scores and is computed only
% when requested.

   if iscell(test_X)
       scores = zeros(length(E_bias), length(test_X));
       for obj = 1:length(test_X)
           parts = test_X{obj};
           for p = 1:length(parts)
               scores(:, obj) = scores(:, obj) + sum(E(parts{p}, :), 1)';
           end
       end
       scores = bsxfun(@plus, scores, E_bias');
   else
       scores = bsxfun(@plus, E' * test_X, E_bias');
   end

   [peak, est_labels] = max(scores, [], 1);

   if nargout > 1
       % Subtract each column's maximum before exponentiating to avoid overflow.
       posterior = exp(bsxfun(@minus, scores, peak));
       posterior = bsxfun(@rdivide, posterior, sum(posterior, 1));
   end
end
139
140
function [C, dC] = logreg_grad(x, train_X, train_labels, lambda, v, pos_E, pos_E_bias)
%LOGREG_GRAD Cost and gradient of the weighted, L2-regularized logistic regressor
%
%   [C, dC] = LOGREG_GRAD(X, TRAIN_X, TRAIN_LABELS, LAMBDA, V, POS_E, POS_E_BIAS)
%
% X             Stacked parameters [E(:); E_BIAS(:)]; E is D x K, E_BIAS is 1 x K
% TRAIN_X       D x N data matrix, one object per column
% TRAIN_LABELS  Class index in 1..K for each column of TRAIN_X
% LAMBDA        L2 penalty strength, applied to all of X (biases included)
% V             Per-object weights (N elements)
% POS_E, POS_E_BIAS  Optional precomputed label-dependent part of the
%               gradient (weighted per-class sums of TRAIN_X and of V).
%               Computed here when omitted or empty.
%
%   C  = -sum_n V(n) * log p(TRAIN_LABELS(n) | TRAIN_X(:,n)) + LAMBDA * ||X||^2
%   dC = gradient of C with respect to X

   % Decode solution
   [D, N] = size(train_X);
   K = numel(x) / (D + 1);
   E = reshape(x(1:D * K), [D K]);
   E_bias = reshape(x(D * K + 1:end), [1 K]);
   v = v(:);

   % Compute p(y|x) (K x N). Subtract the column maximum for numerical stability.
   gamma = bsxfun(@plus, E' * train_X, E_bias');
   gamma = exp(bsxfun(@minus, gamma, max(gamma, [], 1)));
   gamma = bsxfun(@rdivide, gamma, max(sum(gamma, 1), realmin));

   % Weighted conditional log-likelihood. The weight must multiply the
   % log-probability rather than sit inside the log. Inside the log it only
   % adds a constant, leaving C unweighted and inconsistent with the weighted
   % gradient below.
   true_idx = sub2ind([K N], train_labels(:)', 1:N);
   C = -sum(v' .* log(max(gamma(true_idx), realmin)));
   C = C + lambda .* sum(x .^ 2);

   % Only compute gradient when required
   if nargout > 1
      VX = bsxfun(@times, v', train_X);

      % Positive (label-dependent) part of the gradient
      if nargin < 6 || isempty(pos_E)
         pos_E = zeros(D, K);
         for k = 1:K
            pos_E(:, k) = sum(VX(:, train_labels == k), 2);
         end
      end
      if nargin < 7 || isempty(pos_E_bias)
         pos_E_bias = zeros(1, K);
         for k = 1:K
            pos_E_bias(k) = sum(v(train_labels == k));
         end
      end

      % Negative (model-expectation) part of the gradient
      neg_E = VX * gamma';
      neg_E_bias = gamma * v;

      dC = -[pos_E(:) - neg_E(:); pos_E_bias(:) - neg_E_bias(:)] + 2 .* lambda .* x;
   end
end
Note: See TracBrowser for help on using the repository browser.