source: distools/clevald.m @ 76

Last change on this file since 76 was 21, checked in by bduin, 13 years ago
File size: 8.6 KB
Line 
1%CLEVALD Classifier evaluation (learning curve) for dissimilarity data
2%
3%   E = CLEVALD(D,CLASSF,TRAINSIZES,REPSIZE,NREPS,T,TESTFUN)
4%
5% INPUT
6%   D          Square dissimilarity dataset
7%   CLASSF     Classifiers to be evaluated (cell array)
8%   TRAINSIZE  Vector of training set sizes, used to generate subsets of D
9%              (default [2,3,5,7,10,15,20,30,50,70,100]). TRAINSIZE is per
10%              class unless D has no priors set or has soft labels.
11%   REPSIZE    Representation set size per class (>=1), or fraction (<1)
12%              (default total, training set)
13%   NREPS      Number of repetitions (default 1)
14%   T          Test dataset (default [], use remaining samples in A)
15%   TESTFUN    Mapping,evaluation function (default classification error)
16%
17% OUTPUT
18%   E          Error structure (see PLOTE) containing training and test
19%              errors
20%
21% DESCRIPTION
22% Generates at random, for all class sizes defined in TRAINSIZES, training
23% sets out of the dissimilarity dataset D. The representation set is either
24% equal to the training set (REPSIZE = []), or a fraction of it (REPSIZE <1)
25% or a random subset of it of a given size (REPSIZE>1). This set is used
26% for training the untrained classifiers CLASSF. The resulting trained
27% classifiers are tested on the training objects and on the left-over test
28% objects, or, if supplied, the testset T. This procedure is then repeated
29% NREPS times. The default test routine is classification error estimation
30% by TESTC([],'crisp').
31%
32% The returned structure E contains several fields for annotating the plot
33% produced by PLOTE. They may be changed by the users. Removal of the field
34% 'apperror' (RMFIELD(E,'apperror')) suppresses the draw of the error
35% curves for the training set.
36%
37% Training set generation is done "with replacement" and such that for each
38% run the larger training sets include the smaller ones and that for all
39% classifiers the same training sets are used.
40%
41% This function uses the RAND random generator and thereby reproduces
42% if its seed is reset (see RAND).
43% If CLASSF uses RANDN, its seed should be reset as well.
44%
45% SEE ALSO
46% MAPPINGS, DATASETS, CLEVAL, TESTC, PLOTE
47
48% R.P.W. Duin, r.p.w.duin@prtools.org
49% Faculty EWI, Delft University of Technology
50% P.O. Box 5031, 2600 GA Delft, The Netherlands
51
52function e = clevald(a,classf,learnsizes,repsize,nreps,t,testfun)
53
54        prtrace(mfilename);
55
56  if (nargin < 7) | isempty(testfun)
57    testfun = testc([],'crisp');
58  end;
59  if (nargin < 6)
60    t = [];
61  end;
62  if (nargin < 5) | isempty(nreps);
63    nreps = 1;
64  end;
65  if (nargin < 4)
66    repsize = [];
67  end
68  if (nargin < 3) | isempty(learnsizes);
69    learnsizes = [2,3,5,7,10,15,20,30,50,70,100];
70  end;
71  if ~iscell(classf), classf = {classf}; end
72
73  % Assert that all is right.
74  isdataset(a); issquare(a); ismapping(classf{1});
75  if (~isempty(t)), isdataset(t); end
76
77  % Remove requested class sizes that are larger than the size of the
78  % smallest class.
79
80        [m,k,c] = getsize(a);
81        if ~isempty(a,'prior') & islabtype(a,'crisp')
82                classs = true;
83                mc = classsizes(a);
84                toolarge = find(learnsizes >= min(mc));
85                if (~isempty(toolarge))
86                        prwarning(2,['training set class sizes ' num2str(learnsizes(toolarge)) ...
87                                                                         ' larger than the minimal class size; removed them']);
88                        learnsizes(toolarge) = [];
89                end
90        else
91                if islabtype(a,'crisp') & isempty(a,'prior')
92                        prwarning(1,['No priors found in dataset, class frequencies are used.' ...
93                        newline '            Training set sizes hold for entire dataset']);
94                end
95                classs = false;
96                toolarge = find(learnsizes >= m);
97                if (~isempty(toolarge))
98                        prwarning(2,['training set sizes ' num2str(learnsizes(toolarge)) ...
99                                                                         ' larger than number of objects; removed them']);
100                        learnsizes(toolarge) = [];
101                end
102        end
103  learnsizes = learnsizes(:)';
104
105  % Fill the error structure.
106
107  nw = length(classf(:));
108  datname = getname(a);
109
110  e.n       = nreps;
111  e.error   = zeros(nw,length(learnsizes));
112  e.std     = zeros(nw,length(learnsizes));
113  e.apperror   = zeros(nw,length(learnsizes));
114  e.appstd     = zeros(nw,length(learnsizes));
115  e.xvalues = learnsizes(:)';
116        if classs
117                e.xlabel   = 'Training set size per class';
118        else
119                e.xlabel   = 'Training set size';
120        end
121  e.names   = [];
122  if (nreps > 1)
123    e.ylabel= ['Averaged error (' num2str(nreps) ' experiments)'];
124  elseif (nreps == 1)
125    e.ylabel = 'Error';
126  else
127    error('Number of repetitions NREPS should be >= 1.');
128  end;
129  if (~isempty(datname))
130    if isempty(repsize)
131      e.title = [datname ', Rep. Set = Train Set'];
132    elseif repsize < 1
133      e.title = [datname ', Rep. size = ' num2str(repsize) ' Train size'];
134    else
135      e.title = [datname ', Rep. size = ' num2str(repsize) ' per class'];
136    end
137  end
138  if (learnsizes(end)/learnsizes(1) > 20)
139    e.plot = 'semilogx';        % If range too large, use a log-plot for X.
140  end
141
142  % Report progress.
143       
144        s1 = sprintf('cleval: %i classifiers: ',nw);
145        prwaitbar(nw,s1);
146
147  % Store the seed, to reset the random generator later for different
148  % classifiers.
149
150        seed = rand('state');
151
152  % Loop over all classifiers (with index WI).
153
154  for wi = 1:nw
155               
156    if (~isuntrained(classf{wi}))
157      error('Classifiers should be untrained.')
158    end
159    name = getname(classf{wi});
160    e.names = char(e.names,name);
161    prwaitbar(nw,wi,[s1 name]);
162
163    % E1 will contain the error estimates.
164
165    e1 = zeros(nreps,length(learnsizes));
166    e0 = zeros(nreps,length(learnsizes));
167
168    % Take care that classifiers use same training set.
169
170    rand('state',seed); seed2 = seed;
171
172                % For NREPS repetitions...
173               
174                s2 = sprintf('cleval: %i repetitions: ',nreps);
175                prwaitbar(nreps,s2);
176
177                for i = 1:nreps
178       
179                        prwaitbar(nreps,i,[s2 int2str(i)]);
180      % Store the randomly permuted indices of samples of class CI to use in
181      % this training set in JR(CI,:).
182
183                        if classs
184
185                                JR = zeros(c,max(learnsizes));
186                       
187                                for ci = 1:c
188
189                                        JC = findnlab(a,ci);
190
191                                        % Necessary for reproducable training sets: set the seed and store
192                                        % it after generation, so that next time we will use the previous one.
193                                        rand('state',seed2);
194
195                                        JD = JC(randperm(mc(ci)));
196                                        JR(ci,:) = JD(1:max(learnsizes))';
197                                        seed2 = rand('state');
198                                end
199                               
200                        elseif islabtype(a,'crisp')
201                               
202                                rand('state',seed2); % get seed for reproducable training sets
203                                % generate indices for the entire dataset taking care that in
204                                % the first 2c objects we have 2 objects for every class
205                                [a1,a2,I1,I2] = gendat(a,2*ones(1,c));
206                                JD = randperm(m-2*c);
207                                JR = [I1;I2(JD)];
208                                seed2 = rand('state'); % save seed for reproducable training sets
209                               
210                        else  % soft labels
211                               
212                                rand('state',seed2); % get seed for reproducable training sets
213                                JR = randperm(m);
214                                seed2 = rand('state'); % save seed for reproducable training sets
215                               
216                        end
217
218                        li = 0;                                                                         % Index of training set.
219                       
220                        nlearns = length(learnsizes);
221                        s3 = sprintf('cleval: %i sizes: ',nlearns);
222                        prwaitbar(nreps,s3);
223                       
224                        for j = 1:nlearns
225                               
226                                nj = learnsizes(j);
227                               
228                                prwaitbar(nlearns,j,[s3 int2str(j) ' (' int2str(nj) ')']);
229                                li = li + 1;
230
231        % J will contain the indices for this training set.
232
233        J = [];
234        R = [];
235                               
236                                if classs
237                                        for ci = 1:c
238                                                J = [J;JR(ci,1:nj)'];
239                                                if isempty(repsize)
240                                                        R = [R JR(ci,1:nj)];
241                                                elseif repsize < 1
242                                                        R = [R JR(ci,1:ceil(repsize*nj))];
243                                                else
244                                                        R = [R JR(ci,1:min(nj,repsize))];
245                                                end
246                                        end;
247                                else
248                                        J = JR(1:nj);
249                                        if isempty(repsize)
250                                                R = J;
251                                        elseif repsize < 1
252                                                R = JR(1:ceil(repsize*nj));
253                                        else
254                                                R = JR(1:min(nj,repsize));
255                                        end
256                                end;
257                                       
258                                trainset = a(J,R);
259                                trainset = setprior(trainset,getprior(trainset,0));
260                                w = trainset*classf{wi};                                        % Use right classifier.
261                                e0(i,li) = trainset*w*testfun;
262                                if (isempty(t))
263                                Jt = ones(m,1);
264                                        Jt(J) = zeros(size(J));
265                                        Jt = find(Jt);                                                          % Don't use training set for testing.
266                                        testset = a(Jt,R);
267                                        testset = setprior(testset,getprior(testset,0));
268                                        e1(i,li) = testset*w*testfun;
269                                else
270                                        testset = t(:,R);
271                                        testset = setprior(testset,getprior(testset,0));
272                                        e1(i,li) = testset*w*testfun;
273                                end
274
275                        end
276                        prwaitbar(0);
277
278                end
279                prwaitbar(0);
280
281    % Calculate average error and standard deviation for this classifier
282    % (or set the latter to zero if there's been just 1 repetition).
283
284                e.error(wi,:) = mean(e1,1);
285                e.apperror(wi,:) = mean(e0,1);
286                if (nreps == 1)
287                        e.std(wi,:) = zeros(1,size(e.std,2));
288                        e.appstd(wi,:) = zeros(1,size(e.appstd,2));
289                else
290                        e.std(wi,:) = std(e1)/sqrt(nreps);
291                        e.appstd(wi,:) = std(e0)/sqrt(nreps);
292                end
293        end
294        prwaitbar(0);
295
296        % The first element is the empty string [], remove it.
297        e.names(1,:) = [];
298
299return
300
Note: See TracBrowser for help on using the repository browser.