source: distools/clevald.m @ 18

Last change on this file since 18 was 18, checked in by bduin, 13 years ago

clevald, parzenddc, parzend_map and kem added

File size: 6.7 KB
RevLine 
%CLEVALD Classifier evaluation (learning curve) for dissimilarity data
%
%   E = CLEVALD(D,CLASSF,TRAINSIZES,REPSIZE,NREPS,T)
%
% INPUT
%   D          Square dissimilarity dataset
%   CLASSF     Untrained classifiers to be evaluated (cell array, or a
%              single untrained mapping)
%   TRAINSIZES Vector of class sizes, used to generate subsets of D
%              (default [2,3,5,7,10,15,20,30,50,70,100])
%   REPSIZE    Representation set size per class (>=1), or fraction (<1)
%              of the training set (default [], the full training set)
%   NREPS      Number of repetitions (default 1)
%   T          Test dataset (default [], use the remaining objects in D).
%              T is indexed as T(:,R), so its columns should correspond to
%              the objects of D.
%
% OUTPUT
%   E          Error structure (see PLOTE) containing training and test
%              errors
%
% DESCRIPTION
% Generates at random, for all class sizes defined in TRAINSIZES, training
% sets out of the dissimilarity dataset D. The representation set is either
% equal to the training set (REPSIZE = []), a fraction of it (REPSIZE < 1),
% or a random subset of it with a given size per class (REPSIZE >= 1). This
% set is used for training the untrained classifiers CLASSF. The resulting
% trained classifiers are tested on the training objects and on the
% left-over test objects, or, if supplied, on the test set T. This
% procedure is then repeated NREPS times.
%
% Class sizes in TRAINSIZES that are not smaller than the smallest class
% in D are removed (with a warning).
%
% The returned structure E contains several fields for annotating the plot
% produced by PLOTE. They may be changed by the user. Removing the field
% 'apperror' (RMFIELD(E,'apperror')) suppresses drawing the error curves
% for the training set.
%
% Within a repetition, training objects are drawn without replacement;
% every repetition draws again from the full dataset D. Training sets are
% generated such that within each run the larger training sets include the
% smaller ones, and such that the same training sets are used for all
% classifiers.
%
% This function uses the RAND random generator, so its results are
% reproducible if the RAND seed is reset beforehand (see RAND).
% If CLASSF uses RANDN, its seed should be reset as well.
%
% SEE ALSO
% MAPPINGS, DATASETS, CLEVAL, TESTC, PLOTE

% R.P.W. Duin, r.p.w.duin@prtools.org
% Faculty EWI, Delft University of Technology
% P.O. Box 5031, 2600 GA Delft, The Netherlands

function e = clevald(a,classf,learnsizes,repsize,nreps,t)

  prtrace(mfilename);

  % Defaults for the optional arguments.
  if (nargin < 6)
    t = [];
  end
  if (nargin < 5) || isempty(nreps)
    nreps = 1;
  end
  if (nargin < 4)
    repsize = [];
  end
  if (nargin < 3) || isempty(learnsizes)
    learnsizes = [2,3,5,7,10,15,20,30,50,70,100];
  end
  if ~iscell(classf), classf = {classf}; end

  % Assert that all is right.
  isdataset(a); issquare(a); ismapping(classf{1});
  if (~isempty(t)), isdataset(t); end
  % Validate NREPS before doing any work (also rejects NaN).
  if ~(nreps >= 1)
    error('Number of repetitions NREPS should be >= 1.');
  end

  % Remove requested class sizes that are not smaller than the size of
  % the smallest class.
  mc = classsizes(a); [m,k,c] = getsize(a);
  toolarge = find(learnsizes >= min(mc));
  if (~isempty(toolarge))
    prwarning(2,['training set class sizes ' num2str(learnsizes(toolarge)) ...
                 ' not smaller than the minimal class size in D; they are removed']);
    learnsizes(toolarge) = [];
  end
  learnsizes = learnsizes(:)';
  % Without any size left there is nothing to evaluate; fail clearly here
  % instead of with an index error further on.
  if isempty(learnsizes)
    error('No training set class sizes left: all are not smaller than the minimal class size in D.');
  end
  nlearns = length(learnsizes);
  maxsize = max(learnsizes);

  % Fill the error structure.
  nw = length(classf(:));
  datname = getname(a);

  e.n        = nreps;
  e.error    = zeros(nw,nlearns);
  e.std      = zeros(nw,nlearns);
  e.apperror = zeros(nw,nlearns);
  e.appstd   = zeros(nw,nlearns);
  e.xvalues  = learnsizes;
  e.xlabel   = 'Training set size per class';
  e.names    = [];
  if (nreps > 1)
    e.ylabel = ['Averaged error (' num2str(nreps) ' experiments)'];
  else
    e.ylabel = 'Error';
  end
  if (~isempty(datname))
    if isempty(repsize)
      e.title = [datname ', Rep. Set = Train Set'];
    elseif repsize < 1
      e.title = [datname ', Rep. size = ' num2str(repsize) ' Train size'];
    else
      e.title = [datname ', Rep. size = ' num2str(repsize) ' per class'];
    end
  end
  % If the range of sizes is large, use a log-plot for X. Use max/min so
  % that unsorted TRAINSIZES are handled as well.
  if (max(learnsizes)/min(learnsizes) > 20)
    e.plot = 'semilogx';
  end

  % Report progress.
  s1 = sprintf('clevald: %i classifiers: ',nw);
  prwaitbar(nw,s1);

  % Store the seed, to reset the random generator later for different
  % classifiers.
  seed = rand('state');

  % Loop over all classifiers (with index WI).
  for wi = 1:nw

    if (~isuntrained(classf{wi}))
      error('Classifiers should be untrained.')
    end
    name = getname(classf{wi});
    e.names = char(e.names,name);
    prwaitbar(nw,wi,[s1 name]);

    e1 = zeros(nreps,nlearns);   % test set errors
    e0 = zeros(nreps,nlearns);   % apparent (training set) errors

    % Take care that all classifiers use the same training sets.
    rand('state',seed); seed2 = seed;

    s2 = sprintf('clevald: %i repetitions: ',nreps);
    prwaitbar(nreps,s2);

    for i = 1:nreps

      prwaitbar(nreps,i,[s2 int2str(i)]);

      % JR(CI,:) holds randomly permuted indices of the objects of class
      % CI. Every training set of this repetition uses a prefix of that
      % row, so larger sets contain the smaller ones.
      JR = zeros(c,maxsize);
      for ci = 1:c
        JC = findnlab(a,ci);
        % Restore the generator state saved after the previous draw, so
        % that any random numbers consumed by classifier training in
        % between do not change the training sets.
        rand('state',seed2);
        JD = JC(randperm(mc(ci)));
        JR(ci,:) = JD(1:maxsize)';
        seed2 = rand('state');
      end

      s3 = sprintf('clevald: %i sizes: ',nlearns);
      prwaitbar(nlearns,s3);

      for j = 1:nlearns

        nj = learnsizes(j);
        prwaitbar(nlearns,j,[s3 int2str(j) ' (' int2str(nj) ')']);

        % J: training objects (rows), R: representation objects (columns).
        J = [];
        R = [];
        for ci = 1:c
          J = [J;JR(ci,1:nj)'];
          if isempty(repsize)
            R = [R JR(ci,1:nj)];
          elseif repsize < 1
            R = [R JR(ci,1:ceil(repsize*nj))];
          else
            R = [R JR(ci,1:min(nj,repsize))];
          end
        end

        w = a(J,R)*classf{wi};
        e0(i,j) = a(J,R)*w*testc;
        if (isempty(t))
          % Test on all objects of D that are not in the training set.
          Jt = true(m,1);
          Jt(J) = false;
          Jt = find(Jt);
          e1(i,j) = a(Jt,R)*w*testc;
        else
          e1(i,j) = t(:,R)*w*testc;
        end

      end
      prwaitbar(0);

    end
    prwaitbar(0);

    % Average errors over the repetitions; the spread is the standard
    % deviation of the mean (zero for a single repetition).
    e.error(wi,:)    = mean(e1,1);
    e.apperror(wi,:) = mean(e0,1);
    e.std(wi,:)      = std(e1,0,1)/sqrt(nreps);
    e.appstd(wi,:)   = std(e0,0,1)/sqrt(nreps);
  end
  prwaitbar(0);

  % The first row of e.names comes from the initial empty value; remove it.
  e.names(1,:) = [];

return
237
Note: See TracBrowser for help on using the repository browser.