source: prextra/loglc2.m @ 32

Last change on this file since 32 was 32, checked in by lvandermaaten, 13 years ago

Added proper logistic classifier implementation (LOGLC2).

File size: 5.9 KB
Line 
function [B, ll] = loglc2(A, W)
%LOGLC2 Logistic Linear Classifier
%
%   W = LOGLC2(A, L)
%   W = A * LOGLC2([], L)
%   [W, LL] = LOGLC2(A, L)
%
% INPUT
%   A   Dataset with crisp labels (at least two classes)
%   L   Regularization parameter (L2), default 0; an empty L also means 0
%
% OUTPUT
%   W   Logistic linear classifier (an untrained mapping when A is empty)
%   LL  Currently always empty
%
% DESCRIPTION
% Computation of the linear classifier for the dataset A by maximizing the
% L2-regularized likelihood criterion using the logistic (softmax) function.
% The default value for L is 0. Training needs the minFunc optimizer
% (www.di.ens.fr/~mschmidt/Software/minFunc.html) on the Matlab path.
%
% A trained classifier W is applied with D = B * W, which returns a dataset
% holding one posterior probability per class for every object in B.
%
%  SEE ALSO
%  MAPPINGS, DATASETS, LDC, FISHERC

    name = 'Logistic regressor (implementation 2)';
    ll = [];   % assigned up front so [W, LL] = LOGLC2(...) works on every path

    % NOTE(review): these path edits run on every call, and '~/prtools' is a
    % machine-specific location -- consider removing them.
    addpath(genpath('minFunc'));
    addpath(genpath('~/prtools'));

    % Untrained mapping: W = LOGLC2, W = LOGLC2([]) or W = LOGLC2([], L).
    % L is stored in the mapping so that a later A * W trains with it.
    if nargin == 0 || isempty(A)
        if nargin < 2 || isempty(W)
            W = 0;
        end
        B = mapping(mfilename, {W});
        B = setname(B, name);
        return;
    end

    if isa(A, 'dataset') && (nargin < 2 || isa(W, 'double'))
        % Training: LOGLC2(A), LOGLC2(A, L) or A * LOGLC2([], L).
        if nargin < 2 || isempty(W)
            W = 0;
        end

        % Warn only when the optimizer is actually missing.
        if ~exist('minFunc', 'file')
            prwarning(1, 'LOGLC2 requires the minFunc optimizer. Please download it from www.di.ens.fr/~mschmidt/Software/minFunc.html and add it to the Matlab path.');
        end

        islabtype(A, 'crisp');
        isvaldfile(A, 1, 2);
        A = testdatasize(A, 'features');
        A = setprior(A, getprior(A));
        [~, k, c] = getsize(A);

        % The learner expects one column per object: transpose the plain
        % data matrix (+A), not the dataset object itself.
        [data.E, data.E_bias] = train_logreg((+A)', getnlab(A)', W);
        B = mapping(mfilename, 'trained', data, getlablist(A), k, c);
        B = setname(B, name);

    elseif isa(A, 'dataset') && isa(W, 'mapping')
        % Evaluation of a trained LOGLC2 mapping W on dataset A: the output
        % dataset holds the class posteriors, one feature per class.
        [~, test_post] = eval_logreg((+A)', W.data.E, W.data.E_bias);
        A = dataset(A);
        B = setdata(A, test_post', getlabels(W));

    else
        % This should not happen
        error('Illegal call');
    end
end
66
function [E, E_bias] = train_logreg(train_X, train_labels, lambda, E_init, E_bias_init)
%TRAIN_LOGREG Fit multinomial logistic-regression weights with minFunc.
%
%   [E, E_BIAS] = TRAIN_LOGREG(TRAIN_X, TRAIN_LABELS, LAMBDA, E_INIT, E_BIAS_INIT)
%
% TRAIN_X is either a numeric matrix with one column per sample, or a cell
% array with one entry of feature indices per sample. TRAIN_LABELS has one
% label per sample and is remapped to 1..K with UNIQUE. LAMBDA is passed to
% the objective as the L2 penalty weight. E_INIT / E_BIAS_INIT optionally
% give the starting point; otherwise E starts as small Gaussian noise (this
% draws from the global RANDN stream) and E_BIAS starts at zero.
% Returns E (D x K) and E_BIAS (1 x K).

    % minFunc provides the L-BFGS optimizer.
    addpath(genpath('minFunc'));

    uses_cells = iscell(train_X);

    % Dimensionality: rows of the matrix, or the largest index seen in the cells.
    if uses_cells
        D = 0;
        for i = 1:length(train_X)
            D = max(D, max(train_X{i}));
        end
    else
        D = size(train_X, 1);
    end

    % Relabel to consecutive class indices 1..K.
    [lablist, ~, train_labels] = unique(train_labels);
    K = length(lablist);

    % Starting point for the optimizer.
    if exist('E_init', 'var') && ~isempty(E_init)
        E = E_init;
    else
        E = randn(D, K) * .0001;
    end
    if exist('E_bias_init', 'var') && ~isempty(E_bias_init)
        E_bias = E_bias_init;
    else
        E_bias = zeros(1, K);
    end

    % Label-dependent ("positive") part of the gradient; it does not depend
    % on the parameters, so it is computed once and handed to the objective.
    pos_E = zeros(D, K);
    pos_E_bias = zeros(1, K);
    if uses_cells
        for i = 1:length(train_X)
            pos_E(train_X{i}, train_labels(i)) = pos_E(train_X{i}, train_labels(i)) + 1;
        end
    end
    for k = 1:K
        in_class = (train_labels == k);
        if ~uses_cells
            pos_E(:,k) = sum(train_X(:,in_class), 2);
        end
        pos_E_bias(k) = sum(in_class);
    end

    % Optimize with L-BFGS.
    options = struct('Method', 'lbfgs', 'Display', 'on', 'TolFun', 1e-4, ...
                     'TolX', 1e-4, 'MaxIter', 5000);
    if uses_cells
        objective = @logreg_discrete_grad;
    else
        objective = @logreg_grad;
    end
    x = minFunc(objective, [E(:); E_bias(:)], options, ...
                train_X, train_labels, lambda, pos_E, pos_E_bias);

    % Unpack the stacked parameter vector [E(:); E_bias(:)].
    n_weights = D * K;
    E = reshape(x(1:n_weights), [D K]);
    E_bias = reshape(x(n_weights + 1:end), [1 K]);
end
127
128
function [est_labels, posterior] = eval_logreg(test_X, E, E_bias)
%EVAL_LOGREG Apply a trained logistic regressor.
%
%   [EST_LABELS, POSTERIOR] = EVAL_LOGREG(TEST_X, E, E_BIAS)
%
% TEST_X is a numeric matrix with one column per sample, or a cell array
% with one entry per sample. EST_LABELS (1 x N) is the index of the highest
% scoring class for each sample. POSTERIOR (K x N, computed only when
% requested) holds the softmax of the class scores; each column sums to one.

    if iscell(test_X)
        % NOTE(review): each test_X{i} is used here as a nested cell of index
        % sets, whereas TRAIN_LOGREG treats train_X{i} as a plain index
        % vector -- confirm the intended cell format.
        n_samples = length(test_X);
        scores = zeros(length(E_bias), n_samples);
        for i = 1:n_samples
            sample = test_X{i};
            for j = 1:length(sample)
                scores(:,i) = scores(:,i) + sum(E(sample{j},:), 1)';
            end
        end
        scores = bsxfun(@plus, scores, E_bias');
    else
        scores = bsxfun(@plus, E' * test_X, E_bias');
    end

    % Winning class per sample, plus its score for the softmax shift below.
    [top_score, est_labels] = max(scores, [], 1);

    if nargout > 1
        % Softmax, shifted by each column's maximum to keep exp from overflowing.
        posterior = exp(bsxfun(@minus, scores, top_score));
        posterior = bsxfun(@rdivide, posterior, sum(posterior, 1));
    end
end
151
152
function [C, dC] = logreg_grad(x, train_X, train_labels, lambda, pos_E, pos_E_bias)
%LOGREG_GRAD Gradient of L2-regularized logistic regressor
%
%   [C, dC] = logreg_grad(x, train_X, train_labels, lambda, pos_E, pos_E_bias)
%
% Objective value and gradient of the L2-regularized multinomial logistic
% regressor, in the calling convention used by minFunc.
%
% INPUT
%   x             Parameters stacked as [E(:); E_bias(:)], where E is D x K
%                 and E_bias is 1 x K
%   train_X       D x N data matrix, one column per sample
%   train_labels  N class indices in 1..K
%   lambda        L2 penalty weight; it is applied to all of x, biases included
%   pos_E         Optional D x K per-class feature sums (recomputed if absent)
%   pos_E_bias    Optional 1 x K per-class sample counts (recomputed if absent)
%
% OUTPUT
%   C    Negative conditional log-likelihood + lambda * sum(x .^ 2)
%   dC   Gradient of C with respect to x (computed only when requested)

    % Decode solution
    [D, N] = size(train_X);
    K = numel(x) / (D + 1);
    E = reshape(x(1:D * K), [D K]);
    E_bias = reshape(x(D * K + 1:end), [1 K]);

    % Class scores, shifted by each column's maximum so exp cannot overflow.
    scores = bsxfun(@plus, E' * train_X, E_bias');
    scores = bsxfun(@minus, scores, max(scores, [], 1));
    expo = exp(scores);
    Z = sum(expo, 1);   % >= 1: the column maximum contributes exp(0) = 1

    % p(y|x) for every class and sample
    gamma = bsxfun(@rdivide, expo, max(Z, realmin));

    % Conditional log-likelihood via log-sum-exp: log p(y_n|x_n) =
    % score(y_n, n) - log Z(n). Unlike log(gamma), this stays exact when the
    % true class's probability underflows, so C and dC remain consistent.
    true_idx = sub2ind([K N], reshape(train_labels, 1, []), 1:N);
    C = -sum(scores(true_idx) - log(Z)) + lambda .* sum(x .^ 2);

    % Only compute gradient when required
    if nargout > 1

        % Positive (label-dependent) part of the gradient, if not supplied
        if nargin < 5 || isempty(pos_E)
            pos_E = zeros(D, K);
            for k = 1:K
                pos_E(:,k) = sum(train_X(:,train_labels == k), 2);
            end
        end
        if nargin < 6 || isempty(pos_E_bias)
            pos_E_bias = zeros(1, K);
            for k = 1:K
                pos_E_bias(k) = sum(train_labels == k);
            end
        end

        % Negative (model-expectation) part of the gradient
        neg_E = train_X * gamma';
        neg_E_bias = sum(gamma, 2)';

        % Compute gradient
        dC = -[pos_E(:) - neg_E(:); pos_E_bias(:) - neg_E_bias(:)] + 2 .* lambda .* x;
    end
end
Note: See TracBrowser for help on using the repository browser.