source: distools/crossvald.m @ 65

Last change on this file since 65 was 10, checked in by bduin, 14 years ago
File size: 7.1 KB
RevLine 
[10]1%CROSSVALD Cross-validation error for dissimilarity representations
2%
3%   [ERR,STD_ERR] = CROSSVALD(D,CLASSF,N,K,ITER,FID)
4%   [ERR,CERR,NLAB_OUT] = CROSSVALD(D,CLASSF,N,K,1,FID)
5%
6% INPUT
7%   A          Input dataset
8%   CLASSF     Untrained classifier to be tested.
9%   N          Number of dataset divisions (default: N==number of
10%              samples, leave-one-out)
11%   K          Desired size of the representation set (default: use all)
12%   ITER       Number of iterations
13%   FID        File descriptor for progress report file (default: 0)
14%
15% OUTPUT
16%   ERR        Average test error weighted by class priors.
17%   CERR       Unweighted test errors per class
18%   STD_ERR    Standard deviation in the error
19%   NLAB_OUT   Assigned numeric labels
20%
21% DESCRIPTION
22% Cross-validation estimation of the error and the instability of the
23% untrained classifier CLASSF using the dissimilarity dataset D. The set is
24% randomly permutated and divided in N (almost) equally sized parts. Note that
25% for a dissimilarity matrix, the division has to be applied both to rows and
26% to columns. The classifier is trained on N-1 parts and the remaining part is
27% used for testing. This is rotated over all parts.
28%
29% ERR is their weighted avarage over the class priors. CERR are the class error
30% frequencies.  D and/or CLASSF may be cell arrays of datasets and classifiers.
31% In that case ERR is an array with errors with on position ERR(i,j) the error of
32% the j-th classifier for the i-th dataset. In this mode, CERR and NLAB_OUT are
33% returned in cell arrays.
34%
35% If ITER > 1 the routine is run ITER times and results are averaged. The
36% standard deviation of the error is returned in STD_ERR.
37%
38% NOTE
39% D is a square dissimilarity matrix for which the representation set has to be
40% reduced to a K-element subset from the training set. This is done by random
41% selection. If K is not chosen the entire training set is used.
42%
43% Progress is reported in file FID, default FID=0: no report.  Use FID=1
44% for report in the command window.
45%
46% EXAMPLE
47% A = GENDATB(100);
48% D = SQRT(DISTM(A));
49% [E,S] = CROSSVALD (D, ldc([],1e-2,1e-6)*LOGDENS, 10, [], 5);
50%
51% SEE ALSO
52% DATASETS, MAPPINGS, TESTC
53
54% Copyright: R.P.W. Duin, r.duin@ieee.org
55% and Elzbieta Pekalska, ela.pekalska@googlemail.com
56% Faculty EWI, Delft University of Technology and
57% School of Computer Science, University of Manchester
58
59
60function [err,cerr,nlabout] = crossvald(data,classf,n,kred,iter,fid)
61
62        prtrace(mfilename);
63
64        if nargin < 6, fid = []; end
65        if nargin < 5 | isempty(iter), iter = 1; end
66        if nargin < 4 | isempty(kred), kred = []; end
67        if nargin < 3, n = []; end
68   
69        if iter ~= 1
70                eer = cell(1,iter);
71    s = sprintf('crossvald, %i iterations: ',iter);
72    prwaitbar(iter,s);
73                for j = 1:iter
74      prwaitbar(iter,j,[s num2str(j)]);
75                        eer{j} = feval(mfilename,data,classf,n,kred,1,fid);
76                end
77    prwaitbar(0);
78                fe = zeros(size(eer{1},1),size(eer{1},2),iter);
79                for j=1:iter
80                        fe(:,:,j) = eer{j};
81                end
82                err = mean(fe,3);
83                std_err = std(fe,[],3);
84                cerr = std_err;
85                return
86        end
87       
88                       
89        % datasets or classifiers are cell arrays
90        if iscell(classf) | iscell(data)
91
92                seed = rand('state');
93                if ~iscell(classf), classf = {classf}; end
94                if ~iscell(data), data = {data}; end
95                if isdataset(classf{1}) & ismapping(data{1}) % correct for old order
96                        dd = data; data = classf; classf = dd;
97                end
98                numc = length(classf);
99                numd = length(data);
100                cerr = cell(numd,numc);
101                nlab_out = cell(numd,numc);
102
103                prprogress(fid,['\n%i-fold crossvalidation: ' ...
104                             '%i classifiers, %i datasets\n'],n,numc,numd);
105
106                e = zeros(numd,numc);
107    if numc > 1
108      snumc = sprintf('crossvald: %i classifiers: ',numc);
109      prwaitbar(iter,snumc);
110    end
111       
112                for jc = 1:numc
113
114      if numc > 1, prwaitbar(numc,jc,[snumc num2str(jc)]); end
115                        prprogress(fid,'  %s\n',getname(classf{jc}));
116      if numd > 1
117        snumd = sprintf('crossvald: %i datasets: ',numd);
118        prwaitbar(iter,snumd);
119      end
120           
121                        for jd = 1:numd
122
123        if numd > 1, prwaitbar(numd,jd,[snumd num2str(jd)]); end
124                                prprogress(fid,'    %s',getname(data{jd}));
125
126                                rand('state',seed);
127                                [ee,cc,nn] = feval(mfilename,data{jd},classf{jc},n,kred,1,fid);
128                                e(jd,jc) = ee;
129                                cerr(jd,jc) = {cc};
130                                nlabout(jd,jc) = {nn};
131
132                        end
133      if numd > 1, prwaitbar(0); end
134                        %fprintf(fid,'\n');
135
136                end
137                if nargout == 0
138                        fprintf('\n  %i-fold cross validation result for',n);
139                        disperror(data,classf,e);
140                else
141                        err = e;
142                end
143    if numc > 1, prwaitbar(0); end
144   
145        else
146
147                if isdataset(classf) & ismapping(data) % correct for old order
148                        dd = data; data = classf; classf = dd;
149                end
150   % discheck(data,[],0);
151                isdataset(data);
152                ismapping(classf);     
153                [m,k,c] = getsize(data);
154                lab = getlab(data);
155                if isempty(n), n = m; end
156
157                if n > m
158                        warning('Number of batches too large: reset to leave-one-out')
159                        n = m;
160                elseif n <= 1
161                        error('Wrong size for number of batches')
162                end
163                if ~isuntrained(classf)
164                        error('Classifier should be untrained')
165                end
166                J = randperm(m);
167                N = classsizes(data);
168
169                % attempt to find an more equal distribution over the classes
170                if all(N > n)
171
172                        K = zeros(1,m);
173
174                        for i = 1:length(N)
175
176                                L = findnlab(data(J,:),i);
177
178                                M = mod(0:N(i)-1,n)+1;
179
180                                K(L) = M;
181
182                        end
183
184                else
185           
186                        K = mod(1:m,n)+1;
187
188                end
189       
190                nlabout = zeros(m,1);
191                rstate2 = rand('state');
192    prprogress(fid,'%5.0f      ',n);
193    sfolds = sprintf('crossvald: %i folds: ',n);
194    prwaitbar(n,sfolds);
195                for i = 1:n
196      prwaitbar(n,i,[sfolds num2str(i)]);
197                        OUT = find(K==i);
198                        JOUT=J(OUT);
199                        JIN = J; JIN(OUT) = [];
200                        if ~isempty(kred)
201                                if length(JIN) < kred
202                                        error('Training set too small for desired size representation set.')
203                                end
204                                rstate1 = rand('state');
205                                rand('state',rstate2);
206                                RED     = randperm(length(JIN));
207                                rstate2 = rand('state');
208                                rand('state',rstate1);
209                                RED  = RED(1:kred);  % Here we reduce the repr. set but dont take care of an equal distribution over classes
210                                JINF = JIN(RED);
211                                JINR = JINF;
212                        else
213                                JINF = JIN;
214                        end
215                        if (iscell(classf.data) & length(classf.data) > 0 & isparallel(classf.data{1}))
216                                m = size(data,1);
217                                k = size(data,2)/m;
218                                if (k ~= floor(k))
219                                        error('Dataset should be a concatenation of square matrices.')
220                                end
221                                JINF = repmat(JINF(:),1,k) + repmat([0:k-1]*m,kred,1);
222                                JINF = JINF(:);
223                        end
224                        dlearn = data(JIN,JINF);
225                        dtest  = data(JOUT,JINF);
226                        w      = dlearn*classf;   % Training
227                                                  % Testing
228                        [mx,nlabout(JOUT)] = max(+(dtest*w),[],2);
229             % nlabout contains class assignments according to w
230                                                 % needs conversion to nlabs
231                        L1 = matchlablist(getlablist(data),getlabels(w));
232                        [LL,L2] = sort(L1);
233                        nlabout(JOUT) = L2(nlabout(JOUT));                                                                                                                                                     
234                        s = sprintf('\b\b\b\b\b%5.0f',i);
235                        prprogress(fid,s);
236                end
237    prwaitbar(0)
238    prprogress(fid,'\b\b\b\b\b\b\b\b\b\b\b',n);
239    %L = matchlablist(getlablist(data),getlabels(w));
240                for j=1:c
241                        J = findnlab(data,j);
242                        f(j) = sum(nlabout(J)~=j)/length(J);
243                end
244                e = f*getprior(data)';
245                if nargout > 0
246                        err  = e;
247                        cerr = f;
248                else
249                        disp([num2str(n) '-fold cross validation error on ' num2str(size(data,1)) ' objects: ' num2str(e)])
250                end
251        end
252        return
Note: See TracBrowser for help on using the repository browser.