source: prextra/loglc2.m @ 52

Last change on this file since 52 was 37, checked in by dtax, 12 years ago

Allow also to classify double matrices.

File size: 5.8 KB
Line 
function [B, ll] = loglc2(A, W)
%LOGLC2 Logistic Linear Classifier
%
%   W = LOGLC2(A, L)      Train on dataset A with regularization L
%   W = LOGLC2([], L)     Untrained mapping that remembers L
%   B = LOGLC2(A, W)      Apply the trained mapping W to A
%
% INPUT
%   A   Dataset (for evaluation a double matrix is accepted as well)
%   L   Regularization parameter (L2), passed as the second argument
%       when training. The same argument position holds the trained
%       mapping W when evaluating.
%
% OUTPUT
%   W   Logistic linear classifier (when training), or
%   B   Dataset holding the class posteriors (when evaluating)
%   ll  Always empty; kept so that two-output calls do not fail
%
% DESCRIPTION
% Computation of the linear classifier for the dataset A by maximizing the
% L2-regularized likelihood criterion using the logistic (sigmoid) function.
% The default value for L is 0.
%
% Training requires the minFunc optimizer on the Matlab path; evaluating
% an already trained classifier does not.
%
%  SEE ALSO
%  MAPPINGS, DATASETS, LDC, FISHERC

    name = 'Logistic2';

    % The second output is never computed by this routine. Define it up
    % front so that [B, ll] = loglc2(...) works in every branch.
    ll = [];

    % Default regularization parameter. Also covers an explicit empty
    % value, which would otherwise reach the optimizer as lambda = [].
    if nargin < 2 || (isa(W, 'double') && isempty(W))
        W = 0;
    end

    % Handle untrained calls like W = loglc2([]);
    if nargin == 0 || isempty(A)
        if isa(W, 'double')
            % Store the regularization parameter in the untrained mapping
            % so that it is handed back to this routine at training time.
            B = mapping(mfilename, {W});
        else
            B = mapping(mfilename);
        end
        B = setname(B, name);
        return;
    end

    % Handle training on dataset A (use A * loglc2, A * loglc2([]), and loglc2(A))
    if isa(A, 'dataset') && isa(W, 'double')

        % The optimizer is only needed for training
        if ~exist('minFunc', 'file')
            error('LOGLC2 requires the minFunc optimizer. Please download it from www.di.ens.fr/~mschmidt/Software/minFunc.html and add it to the Matlab path.');
        end

        islabtype(A, 'crisp');
        isvaldfile(A, 1, 2);
        A = testdatasize(A, 'features');
        A = setprior(A, getprior(A));
        [~, k, c] = getsize(A);

        % Objects without a label (numeric label 0) would be turned into
        % an extra class by the trainer, so that the weight matrix no
        % longer matches the label list; leave them out.
        X = +A';
        nlab = getnlab(A)';
        labeled = nlab > 0;

        % Train the logistic regressor
        [data.E, data.E_bias] = train_logreg(X(:, labeled), nlab(labeled), W);
        B = mapping(mfilename, 'trained', data, getlablist(A), k, c);
        B = setname(B, name);

    % Handle evaluation of a trained LOGLC2 W for a dataset A
    elseif isa(W, 'mapping') && (isa(A, 'dataset') || isa(A, 'double'))

        % Evaluate logistic classifier
        [~, test_post] = eval_logreg(+A', W.data.E, W.data.E_bias);
        A = dataset(A);
        B = setdata(A, test_post', getlabels(W));

    % This should not happen
    else
        error('Illegal call');
    end
end
64
function [E, E_bias] = train_logreg(train_X, train_labels, lambda, E_init, E_bias_init)
%TRAIN_LOGREG Fit the weights of a logistic regressor with minFunc (L-BFGS)
%
%   [E, E_bias] = train_logreg(train_X, train_labels, lambda, E_init, E_bias_init)
%
% train_X is either a numeric matrix with one column per training object,
% or a cell array in which every cell holds row indices into E.
% train_labels are re-indexed to 1..K through UNIQUE. lambda is handed
% unchanged to the objective function. E_init and E_bias_init are
% optional starting values.
%
% Returns the D x K weight matrix E and the 1 x K bias vector E_bias.

    dense = ~iscell(train_X);

    % Number of input dimensions
    if dense
        D = size(train_X, 1);
    else
        D = 0;
        for i = 1:length(train_X)
            D = max(D, max(train_X{i}));
        end
    end

    % Map the labels onto consecutive class indices
    [lablist, ~, train_labels] = unique(train_labels);
    K = length(lablist);

    % Starting point: supplied values, or small random weights and zero bias
    if nargin >= 4 && ~isempty(E_init)
        E = E_init;
    else
        E = randn(D, K) * .0001;
    end
    if nargin >= 5 && ~isempty(E_bias_init)
        E_bias = E_bias_init;
    else
        E_bias = zeros(1, K);
    end

    % Positive part of the gradient; it does not depend on the weights,
    % so it is computed once and passed on to the objective function.
    pos_E = zeros(D, K);
    pos_E_bias = zeros(1, K);
    if ~dense
        for i = 1:length(train_X)
            rows = train_X{i};
            cls = train_labels(i);
            pos_E(rows, cls) = pos_E(rows, cls) + 1;
        end
    end
    for k = 1:K
        members = (train_labels == k);
        if dense
            pos_E(:, k) = sum(train_X(:, members), 2);
        end
        pos_E_bias(k) = sum(members);
    end

    % Perform learning using L-BFGS
    options = struct('Method', 'lbfgs', ...
                     'Display', 'on', ...
                     'TolFun', 1e-4, ...
                     'TolX', 1e-4, ...
                     'MaxIter', 5000);
    x = [E(:); E_bias(:)];
    if dense
        x = minFunc(@logreg_grad, x, options, train_X, train_labels, lambda, pos_E, pos_E_bias);
    else
        % NOTE(review): logreg_discrete_grad is not defined in the part of
        % the file visible here - confirm that it is on the Matlab path.
        x = minFunc(@logreg_discrete_grad, x, options, train_X, train_labels, lambda, pos_E, pos_E_bias);
    end

    % Unpack the solution vector: weights first, bias last
    n_weights = D * K;
    E = reshape(x(1:n_weights), [D K]);
    E_bias = reshape(x(n_weights + 1:end), [1 K]);
end
122
123
function [est_labels, posterior] = eval_logreg(test_X, E, E_bias)
%EVAL_LOGREG Apply a logistic regressor to test data
%
%   [est_labels, posterior] = eval_logreg(test_X, E, E_bias)
%
% test_X is either a numeric matrix with one column per object, or a cell
% array whose cells are themselves cell arrays of row indices into E.
% est_labels holds, per object, the index of the class with the highest
% score. posterior (only computed when requested) holds the normalized
% class posteriors, one column per object.

    % Unnormalized class scores without the bias term
    if iscell(test_X)
        n_objects = length(test_X);
        scores = zeros(length(E_bias), n_objects);
        for obj = 1:n_objects
            parts = test_X{obj};
            for p = 1:length(parts)
                scores(:, obj) = scores(:, obj) + sum(E(parts{p}, :), 1)';
            end
        end
    else
        scores = E' * test_X;
    end
    scores = bsxfun(@plus, scores, E_bias');

    % The winning class per object; the maximum is reused below
    [top_score, est_labels] = max(scores, [], 1);

    % Compute posterior: subtract the column maximum before
    % exponentiating so that exp cannot overflow
    if nargout > 1
        posterior = exp(bsxfun(@minus, scores, top_score));
        posterior = bsxfun(@rdivide, posterior, sum(posterior, 1));
    end
end
146
147
function [C, dC] = logreg_grad(x, train_X, train_labels, lambda, pos_E, pos_E_bias)
%LOGREG_GRAD Cost and gradient of L2-regularized logistic regressor
%
%   [C, dC] = logreg_grad(x, train_X, train_labels, lambda, pos_E, pos_E_bias)
%
% INPUT
%   x             Parameter vector: the D x K weights (column-major)
%                 followed by the K bias values
%   train_X       D x N data matrix, one column per object
%   train_labels  Class index (1..K) per object
%   lambda        L2 regularization weight (applied to all of x,
%                 the bias values included)
%   pos_E         Optional precomputed D x K positive gradient part
%   pos_E_bias    Optional precomputed 1 x K positive bias gradient part
%
% OUTPUT
%   C    Negative conditional log-likelihood plus lambda * sum(x .^ 2)
%   dC   Gradient of C with respect to x (only computed when requested)
%
% The log-likelihood is evaluated directly in the log domain
% (log-sum-exp). Taking the logarithm of an exponentiated and clamped
% probability would make the cost saturate as soon as the probability of
% the true class underflows, which leaves cost and gradient inconsistent.

    % Decode solution
    [D, N] = size(train_X);
    K = numel(x) / (D + 1);
    E = reshape(x(1:D * K), [D K]);
    E_bias = reshape(x(D * K + 1:end), [1 K]);

    % Compute log p(y|x); the column maximum is subtracted first so that
    % exp cannot overflow and every column sum is at least one
    log_gamma = bsxfun(@plus, E' * train_X, E_bias');
    log_gamma = bsxfun(@minus, log_gamma, max(log_gamma, [], 1));
    log_gamma = bsxfun(@minus, log_gamma, log(sum(exp(log_gamma), 1)));

    % Compute conditional log-likelihood: pick the entry of the true
    % class in every column
    true_class = sub2ind([K N], reshape(train_labels, 1, []), 1:N);
    C = -sum(log_gamma(true_class)) + lambda .* sum(x .^ 2);

    % Only compute gradient when required
    if nargout > 1

        % Compute positive part of gradient (unless it was precomputed)
        if nargin < 5 || isempty(pos_E)
            pos_E = zeros(D, K);
            for k = 1:K
                pos_E(:, k) = sum(train_X(:, train_labels == k), 2);
            end
        end
        if nargin < 6 || isempty(pos_E_bias)
            pos_E_bias = zeros(1, K);
            for k = 1:K
                pos_E_bias(k) = sum(train_labels == k);
            end
        end

        % Compute negative part of gradient
        gamma = exp(log_gamma);
        neg_E = train_X * gamma';
        neg_E_bias = sum(gamma, 2)';

        % Compute gradient
        dC = -[pos_E(:) - neg_E(:); pos_E_bias(:) - neg_E_bias(:)] + 2 .* lambda .* x;
    end
end
Note: See TracBrowser for help on using the repository browser.