%CROSSVALD Cross-validation error for dissimilarity representations
%
%   [ERR,STD_ERR]       = CROSSVALD(D,CLASSF,N,K,ITER,FID)
%   [ERR,CERR,NLAB_OUT] = CROSSVALD(D,CLASSF,N,K,1,FID)
%
% INPUT
%   D       Input dissimilarity dataset
%   CLASSF  Untrained classifier to be tested.
%   N       Number of dataset divisions (default: N==number of
%           samples, leave-one-out)
%   K       Desired size of the representation set (default: use all)
%   ITER    Number of iterations (default: 1)
%   FID     File descriptor for progress report file (default: [],
%           which is handed unchanged to PRPROGRESS)
%
% OUTPUT
%   ERR       Average test error weighted by class priors.
%   CERR      Unweighted test errors per class
%   STD_ERR   Standard deviation in the error
%   NLAB_OUT  Assigned numeric labels
%
% DESCRIPTION
% Cross-validation estimation of the error and the instability of the
% untrained classifier CLASSF using the dissimilarity dataset D. The set is
% randomly permuted and divided in N (almost) equally sized parts. Note that
% for a dissimilarity matrix, the division has to be applied both to rows and
% to columns. The classifier is trained on N-1 parts and the remaining part is
% used for testing. This is rotated over all parts.
%
% ERR is the weighted average of the class errors over the class priors. CERR
% are the class error frequencies. D and/or CLASSF may be cell arrays of
% datasets and classifiers. In that case ERR is an array with on position
% ERR(i,j) the error of the j-th classifier for the i-th dataset. In this
% mode, CERR and NLAB_OUT are returned in cell arrays.
%
% If ITER > 1 the routine is run ITER times and results are averaged. The
% standard deviation of the error is returned in STD_ERR (the second output).
% No labels are collected in this mode: a third output, if requested, is
% returned empty.
%
% NOTE
% D is a square dissimilarity matrix for which the representation set has to be
% reduced to a K-element subset from the training set. This is done by random
% selection. If K is not chosen the entire training set is used.
%
% Progress is reported in file FID. Use FID=0 for no report and FID=1
% for report in the command window.
%
% EXAMPLE
% A = GENDATB(100);
% D = SQRT(DISTM(A));
% [E,S] = CROSSVALD (D, ldc([],1e-2,1e-6)*LOGDENS, 10, [], 5);
%
% SEE ALSO
% DATASETS, MAPPINGS, TESTC

% Copyright: R.P.W. Duin, r.duin@ieee.org
% and Elzbieta Pekalska, ela.pekalska@googlemail.com
% Faculty EWI, Delft University of Technology and
% School of Computer Science, University of Manchester


function [err,cerr,nlabout] = crossvald(data,classf,n,kred,iter,fid)

prtrace(mfilename);

% Defaults for omitted or empty arguments.
if nargin < 6, fid = []; end
if nargin < 5 | isempty(iter), iter = 1; end
if nargin < 4 | isempty(kred), kred = []; end
if nargin < 3, n = []; end

% ---------------------------------------------------------------------
% Repeated cross-validation: call this routine ITER times with ITER = 1
% and return the mean and standard deviation of the errors.
% ---------------------------------------------------------------------
if iter ~= 1
    eer = cell(1,iter);
    s = sprintf('crossvald, %i iterations: ',iter);
    prwaitbar(iter,s);
    for j = 1:iter
        prwaitbar(iter,j,[s num2str(j)]);
        eer{j} = feval(mfilename,data,classf,n,kred,1,fid);
    end
    prwaitbar(0);
    % Stack the per-run error arrays along the third dimension.
    fe = zeros(size(eer{1},1),size(eer{1},2),iter);
    for j = 1:iter
        fe(:,:,j) = eer{j};
    end
    err  = mean(fe,3);
    cerr = std(fe,[],3);   % second output is STD_ERR in this mode
    nlabout = [];          % no labels are collected over repetitions;
                           % assigned so that a 3-output call does not fail
    return
end

if iscell(classf) | iscell(data)
    % -----------------------------------------------------------------
    % Cell arrays of datasets and/or classifiers: evaluate every
    % (dataset, classifier) pair by a recursive call.
    % -----------------------------------------------------------------

    % Remember the generator state; it is restored before every pair so
    % that all pairs start from the same random state.
    seed = rand('state');
    if ~iscell(classf), classf = {classf}; end
    if ~iscell(data), data = {data}; end
    if isdataset(classf{1}) & ismapping(data{1}) % correct for old order
        dd = data; data = classf; classf = dd;
    end
    numc = length(classf);
    numd = length(data);
    cerr    = cell(numd,numc);
    nlabout = cell(numd,numc);

    prprogress(fid,['\n%i-fold crossvalidation: ' ...
        '%i classifiers, %i datasets\n'],n,numc,numd);

    e = zeros(numd,numc);
    if numc > 1
        snumc = sprintf('crossvald: %i classifiers: ',numc);
        % The bar counts classifiers, so its total is numc (not iter,
        % which is always 1 in this branch).
        prwaitbar(numc,snumc);
    end

    for jc = 1:numc

        if numc > 1, prwaitbar(numc,jc,[snumc num2str(jc)]); end
        prprogress(fid,' %s\n',getname(classf{jc}));
        if numd > 1
            snumd = sprintf('crossvald: %i datasets: ',numd);
            % The bar counts datasets, so its total is numd.
            prwaitbar(numd,snumd);
        end

        for jd = 1:numd

            if numd > 1, prwaitbar(numd,jd,[snumd num2str(jd)]); end
            prprogress(fid,' %s',getname(data{jd}));

            rand('state',seed);
            [ee,cc,nn] = feval(mfilename,data{jd},classf{jc},n,kred,1,fid);
            e(jd,jc) = ee;
            cerr(jd,jc) = {cc};
            nlabout(jd,jc) = {nn};

        end
        if numd > 1, prwaitbar(0); end

    end
    if nargout == 0
        fprintf('\n %i-fold cross validation result for',n);
        disperror(data,classf,e);
    else
        err = e;
    end
    if numc > 1, prwaitbar(0); end

else
    % -----------------------------------------------------------------
    % Single dataset, single classifier.
    % -----------------------------------------------------------------

    if isdataset(classf) & ismapping(data) % correct for old order
        dd = data; data = classf; classf = dd;
    end
    % Called without an output argument; presumably PRTools raises an
    % error here when the argument has the wrong type (not visible in
    % this file - TODO confirm).
    isdataset(data);
    ismapping(classf);
    [m,k,c] = getsize(data);
    if isempty(n), n = m; end

    if n > m
        warning('Number of batches too large: reset to leave-one-out')
        n = m;
    elseif n <= 1
        error('Wrong size for number of batches')
    end
    if ~isuntrained(classf)
        error('Classifier should be untrained')
    end

    % J is the random object order; K(p) is the fold number of the object
    % at position p of that order.
    J = randperm(m);
    N = classsizes(data);

    % attempt to find a more equal distribution over the classes
    if all(N > n)
        K = zeros(1,m);
        for i = 1:length(N)
            L = findnlab(data(J,:),i);   % positions (in permuted order) of class i
            K(L) = mod(0:N(i)-1,n)+1;    % deal the class round-robin over the folds
        end
    else
        K = mod(1:m,n)+1;
    end

    nlabout = zeros(m,1);
    % The representation-set selection uses its own generator state
    % (rstate2), which is swapped in and out around RANDPERM below so that
    % the main generator state is left untouched by it.
    rstate2 = rand('state');
    prprogress(fid,'%5.0f ',n);
    sfolds = sprintf('crossvald: %i folds: ',n);
    prwaitbar(n,sfolds);
    for i = 1:n
        prwaitbar(n,i,[sfolds num2str(i)]);
        OUT  = find(K==i);
        JOUT = J(OUT);              % test objects of this fold
        JIN  = J; JIN(OUT) = [];    % training objects of this fold

        if ~isempty(kred)
            if length(JIN) < kred
                error('Training set too small for desired size representation set.')
            end
            rstate1 = rand('state');
            rand('state',rstate2);
            RED = randperm(length(JIN));
            rstate2 = rand('state');
            rand('state',rstate1);
            % Here we reduce the repr. set but don't take care of an
            % equal distribution over classes
            RED  = RED(1:kred);
            JINF = JIN(RED);
        else
            JINF = JIN;
        end

        if (iscell(classf.data) & length(classf.data) > 0 & isparallel(classf.data{1}))
            % Parallel classifier: the dataset is taken to be a horizontal
            % concatenation of square matrices, so the selected columns
            % are repeated in every block.
            nobj = size(data,1);
            nblk = size(data,2)/nobj;
            if (nblk ~= floor(nblk))
                error('Dataset should be a concatenation of square matrices.')
            end
            % One row of block offsets per selected column. The row count
            % is numel(JINF): this equals kred when kred is given and
            % stays valid when kred is empty (entire training set used).
            JINF = repmat(JINF(:),1,nblk) + repmat((0:nblk-1)*nobj,numel(JINF),1);
            JINF = JINF(:);
        end

        dlearn = data(JIN,JINF);
        dtest  = data(JOUT,JINF);
        w = dlearn*classf;                            % Training
        % Testing
        [mx,nlabout(JOUT)] = max(+(dtest*w),[],2);
        % nlabout contains class assignments according to w
        % needs conversion to nlabs
        L1 = matchlablist(getlablist(data),getlabels(w));
        [LL,L2] = sort(L1);
        nlabout(JOUT) = L2(nlabout(JOUT));
        s = sprintf('\b\b\b\b\b%5.0f',i);
        prprogress(fid,s);
    end
    prwaitbar(0)
    prprogress(fid,'\b\b\b\b\b\b\b\b\b\b\b',n);

    % Error frequency per class, then the prior-weighted average.
    f = zeros(1,c);
    for j = 1:c
        Jc = findnlab(data,j);
        f(j) = sum(nlabout(Jc)~=j)/length(Jc);
    end
    e = f*getprior(data)';
    if nargout > 0
        err  = e;
        cerr = f;
    else
        disp([num2str(n) '-fold cross validation error on ' num2str(size(data,1)) ' objects: ' num2str(e)])
    end
end
return