source: prextra/loglc2.m @ 99

Last change on this file since 99 was 82, checked in by dtax, 11 years ago

Now better?

File size: 5.8 KB
Line 
function [B, ll] = loglc2(A, W)
%LOGLC2 Logistic Linear Classifier
%
%   W = LOGLC2(A, L)
%   W = A * LOGLC2([], L)
%   D = T * W
%
% INPUT
%   A   Training dataset
%   L   Regularization parameter (L2), default 0
%   T   Dataset or double matrix (one object per row) to classify
%
% OUTPUT
%   W   Logistic linear classifier
%   D   Dataset holding, for every object of T, the posterior probability
%       of each class
%
% DESCRIPTION
% Computation of the linear classifier for the dataset A by maximizing the
% L2-regularized likelihood criterion using the logistic (softmax) function.
% The default value for L is 0; an empty L is treated as 0. The penalty is
% applied to all parameters, bias terms included. LOGLC2([], L) returns an
% untrained mapping that remembers L. Training requires the minFunc
% optimizer on the Matlab path. The second output LL is always empty.
%
% SEE ALSO
% MAPPINGS, DATASETS, LDC, FISHERC

    name = 'Logistic2';
    % No second output is computed; assign it so [B, ll] = ... never fails.
    ll = [];

    % Default (or empty numeric) regularization parameter. Setting W here
    % also guarantees W is defined in the dispatch conditions below.
    if nargin < 2 || (isa(W, 'double') && isempty(W))
        W = 0;
    end

    % Untrained call, e.g. W = loglc2([], L): store L as a mapping
    % parameter so that A * W later trains with that same L.
    if nargin == 0 || isempty(A)
        B = prmapping(mfilename, {W});
        B = setname(B, name);
        return;
    end

    if isdataset(A) && isa(W, 'double')
        % Training on dataset A (A * loglc2, A * loglc2([], L), loglc2(A, L)).
        % Only training needs minFunc, so the check lives here.
        if ~exist('minFunc', 'file')
            error('LOGLC2 requires the minFunc optimizer. Please download it from www.di.ens.fr/~mschmidt/Software/minFunc.html and add it to the Matlab path.');
        end
        islabtype(A, 'crisp');
        isvaldfile(A, 1, 2);
        A = testdatasize(A, 'features');
        A = setprior(A, getprior(A));
        [~, k, c] = getsize(A);

        % train_logreg works on one object per column, hence the transposes
        [data.E, data.E_bias] = train_logreg(+A', getnlab(A)', W);
        B = prmapping(mfilename, 'trained', data, getlablist(A), k, c);
        B = setname(B, name);

    elseif (isdataset(A) || isa(A, 'double')) && ismapping(W)
        % Evaluation of a trained LOGLC2 mapping W on A
        [~, test_post] = eval_logreg(+A', W.data.E, W.data.E_bias);
        A = prdataset(A);
        B = setdata(A, test_post', getlabels(W));

    else
        % This should not happen
        error('Illegal call');
    end
end

function [E, E_bias] = train_logreg(train_X, train_labels, lambda, E_init, E_bias_init)
%TRAIN_LOGREG Fit an L2-regularized multi-class logistic regressor with L-BFGS
%
%   [E, E_bias] = train_logreg(train_X, train_labels, lambda, E_init, E_bias_init)
%
% train_X is either a D x N matrix (one object per column) or a cell array
% whose entries are feature-index vectors. In the cell case, D is the largest
% index that occurs. The labels are mapped onto 1..K with UNIQUE.
%
% Without E_init, E starts at randn(D, K) * 1e-4, which draws from the
% global random stream. Without E_bias_init, E_bias starts at zeros(1, K).
% The returned E is D x K and E_bias is 1 x K.
%
% NOTE(review): the cell-array path calls logreg_discrete_grad, which is not
% defined in this file; it must come from elsewhere on the path. Confirm.

    % Number of feature dimensions
    if iscell(train_X)
        D = 0;
        for obj = 1:length(train_X)
            D = max(D, max(train_X{obj}));
        end
    else
        D = size(train_X, 1);
    end

    % Relabel the classes as consecutive integers 1..K
    [classes, ~, train_labels] = unique(train_labels);
    K = length(classes);

    % Starting point of the optimization
    if exist('E_init', 'var') && ~isempty(E_init)
        E = E_init;
    else
        E = randn(D, K) * .0001;
    end
    if exist('E_bias_init', 'var') && ~isempty(E_bias_init)
        E_bias = E_bias_init;
    else
        E_bias = zeros(1, K);
    end

    % Data-dependent ("positive") part of the gradient. It does not depend on
    % the parameters, so it is computed once and passed to the objective.
    pos_E = zeros(D, K);
    if iscell(train_X)
        for obj = 1:length(train_X)
            feats = train_X{obj};
            cls = train_labels(obj);
            pos_E(feats, cls) = pos_E(feats, cls) + 1;
        end
    else
        for cls = 1:K
            pos_E(:, cls) = sum(train_X(:, train_labels == cls), 2);
        end
    end
    pos_E_bias = zeros(1, K);
    for cls = 1:K
        pos_E_bias(cls) = sum(train_labels == cls);
    end

    % L-BFGS optimization with minFunc (display off)
    options = struct('Method', 'lbfgs', ...
                     'Display', 'off', ...
                     'TolFun', 1e-4, ...
                     'TolX', 1e-4, ...
                     'MaxIter', 5000);
    if iscell(train_X)
        objective = @logreg_discrete_grad;
    else
        objective = @logreg_grad;
    end
    theta = minFunc(objective, [E(:); E_bias(:)], options, ...
                    train_X, train_labels, lambda, pos_E, pos_E_bias);

    % Unpack the parameter vector [E(:); E_bias(:)]
    nWeights = D * K;
    E = reshape(theta(1:nWeights), [D K]);
    E_bias = reshape(theta(nWeights + 1:end), [1 K]);
end


function [est_labels, posterior] = eval_logreg(test_X, E, E_bias)
%EVAL_LOGREG Apply a trained multi-class logistic regressor
%
%   [est_labels, posterior] = eval_logreg(test_X, E, E_bias)
%
% test_X is either a D x N matrix (one object per column) or a cell array
% in which test_X{i}{j} indexes rows of E. E is D x K and E_bias is 1 x K.
%
% est_labels (1 x N) holds, per object, the column index 1..K of the
% highest score. posterior (K x N) holds the softmax-normalized scores and
% is only computed when requested.

    % Class scores: one row per class, one column per object
    if iscell(test_X)
        scores = zeros(length(E_bias), length(test_X));
        for obj = 1:length(test_X)
            parts = test_X{obj};
            for p = 1:length(parts)
                scores(:, obj) = scores(:, obj) + sum(E(parts{p}, :), 1)';
            end
        end
        scores = bsxfun(@plus, scores, E_bias');
    else
        scores = bsxfun(@plus, E' * test_X, E_bias');
    end

    % Hard labels: index of the best-scoring class
    [~, est_labels] = max(scores, [], 1);

    % Softmax posteriors. Each column is shifted by its maximum first, so
    % that exp cannot overflow.
    if nargout > 1
        shifted = bsxfun(@minus, scores, max(scores, [], 1));
        expScores = exp(shifted);
        posterior = bsxfun(@rdivide, expScores, sum(expScores, 1));
    end
end


function [C, dC] = logreg_grad(x, train_X, train_labels, lambda, pos_E, pos_E_bias)
%LOGREG_GRAD Gradient of L2-regularized logistic regressor
%
%   [C, dC] = logreg_grad(x, train_X, train_labels, lambda, pos_E, pos_E_bias)
%
% Objective and gradient of the L2-regularized multi-class logistic
% (softmax) regressor, in the form minFunc expects.
%
% INPUT
%   x             Parameter vector [E(:); E_bias(:)]; E is D x K, E_bias 1 x K
%   train_X       D x N data matrix, one object per column
%   train_labels  Class index in 1..K for each of the N objects
%   lambda        L2 penalty weight; the penalty lambda*sum(x.^2) also
%                 covers the bias terms
%   pos_E         (optional) D x K per-class sums of train_X
%   pos_E_bias    (optional) 1 x K per-class object counts
%                 Both are recomputed when missing or empty.
%
% OUTPUT
%   C    Negative conditional log-likelihood plus the L2 penalty
%   dC   Gradient of C with respect to x (only computed when requested)

    % Decode solution
    [D, N] = size(train_X);
    K = numel(x) / (D + 1);
    E = reshape(x(1:D * K), [D K]);
    E_bias = reshape(x(D * K + 1:end), [1 K]);

    % Class scores, shifted per object so that the largest score is 0.
    % This prevents overflow in exp.
    Z = bsxfun(@plus, E' * train_X, E_bias');
    Z = bsxfun(@minus, Z, max(Z, [], 1));
    gamma = exp(Z);
    % Each column contains exp(0) = 1, so the normalizer is >= 1.
    norm_const = sum(gamma, 1);
    % gamma(k, n) = p(y = k | x_n)
    gamma = bsxfun(@rdivide, gamma, norm_const);

    % Negative conditional log-likelihood via log-softmax:
    %   log p(y_n | x_n) = Z(y_n, n) - log(sum_k exp(Z(k, n))).
    % Working in the log domain keeps C exact when p(y_n | x_n) underflows
    % to 0, rather than saturating at -log(realmin). C then stays consistent
    % with the gradient dC below.
    idx = sub2ind([K N], train_labels(:)', 1:N);
    C = -sum(Z(idx) - log(norm_const));
    C = C + lambda .* sum(x .^ 2);

    % Only compute gradient when required
    if nargout > 1

        % Positive (data-dependent) part of the gradient, unless precomputed
        if ~exist('pos_E', 'var') || isempty(pos_E)
            pos_E = zeros(D, K);
            for k = 1:K
                pos_E(:, k) = sum(train_X(:, train_labels == k), 2);
            end
        end
        if ~exist('pos_E_bias', 'var') || isempty(pos_E_bias)
            pos_E_bias = zeros(1, K);
            for k = 1:K
                pos_E_bias(k) = sum(train_labels == k);
            end
        end

        % Negative (model expectation) part of the gradient
        neg_E = train_X * gamma';
        neg_E_bias = sum(gamma, 2)';

        % Gradient of the negative log-likelihood plus d/dx of lambda*||x||^2
        dC = -[pos_E(:) - neg_E(:); pos_E_bias(:) - neg_E_bias(:)] + 2 .* lambda .* x;
    end
end
Note: See TracBrowser for help on using the repository browser.