%CROSSVALD Cross-validation error for dissimilarity representations
%
%   [ERR,STD_ERR] = CROSSVALD(D,CLASSF,N,K,ITER,FID)
%   [ERR,CERR,NLAB_OUT] = CROSSVALD(D,CLASSF,N,K,1,FID)
%
% INPUT
%   D         Input dissimilarity dataset
%   CLASSF    Untrained classifier to be tested
%   N         Number of dataset divisions (default: N = number of
%             samples, i.e. leave-one-out)
%   K         Desired size of the representation set (default: use all)
%   ITER      Number of iterations (default: 1)
%   FID       File descriptor for progress report file (default: 0)
%
% OUTPUT
%   ERR       Average test error weighted by the class priors
%   CERR      Unweighted test errors per class
%   STD_ERR   Standard deviation of the error over the ITER runs
%   NLAB_OUT  Assigned numeric labels
%
% DESCRIPTION
% Cross-validation estimate of the error and the instability of the untrained
% classifier CLASSF using the dissimilarity dataset D. The set is randomly
% permuted and divided into N (almost) equally sized parts. Note that for a
% dissimilarity matrix the division has to be applied to both the rows and the
% columns. The classifier is trained on N-1 parts and the remaining part is
% used for testing. This is rotated over all parts.
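%
% For a single fold with training indices JIN and test indices JOUT this
% amounts to the following (an illustrative sketch only, not the literal code
% of the routine):
%   W = D(JIN,JIN)*CLASSF;      % train on training rows and training columns
%   E = TESTC(D(JOUT,JIN),W);   % test rows, training (representation) columns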
%
% ERR is the weighted average of the class errors over the class priors. CERR
% contains the class error frequencies. D and/or CLASSF may be cell arrays of
% datasets and classifiers. In that case ERR is an array of errors in which
% ERR(i,j) is the error of the j-th classifier on the i-th dataset. In this
% mode, CERR and NLAB_OUT are returned as cell arrays.
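%
% For example (an illustrative sketch; the datasets D1, D2 and the classifiers
% are arbitrary choices):
%   E = CROSSVALD({D1,D2},{knnc([],1),fisherc},10);
%   % E(2,1) is then the 10-fold error of the 1-NN rule on D2.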
%
% If ITER > 1 the routine is run ITER times and the results are averaged. The
% standard deviation of the error is returned in STD_ERR.
%
% NOTE
% D is a square dissimilarity matrix for which the representation set may be
% reduced to a K-element subset of the training set. This is done by random
% selection. If K is not given, the entire training set is used.
%
% Progress is reported in file FID, default FID = 0: no report. Use FID = 1
% for a report in the command window.
%
% EXAMPLE
%   A = GENDATB(100);
%   D = SQRT(DISTM(A));
%   [E,S] = CROSSVALD(D,ldc([],1e-2,1e-6)*LOGDENS,10,[],5);
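%   % Illustrative variation, assuming the same D as above: leave-one-out
%   % with a representation set of 20 training objects per fold
%   E = CROSSVALD(D,knnc([],1),[],20);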
%
% SEE ALSO
% DATASETS, MAPPINGS, TESTC

% Copyright: R.P.W. Duin, r.duin@ieee.org
% and Elzbieta Pekalska, ela.pekalska@googlemail.com
% Faculty EWI, Delft University of Technology and
% School of Computer Science, University of Manchester

function [err,cerr,nlabout] = crossvald(data,classf,n,kred,iter,fid)

prtrace(mfilename);

if nargin < 6, fid = []; end
if nargin < 5 | isempty(iter), iter = 1; end
if nargin < 4, kred = []; end
if nargin < 3, n = []; end

if iter ~= 1
  % Run the single-iteration routine ITER times; return the mean error and,
  % as the second output, its standard deviation (see STD_ERR in the help).
  eer = cell(1,iter);
  s = sprintf('crossvald, %i iterations: ',iter);
  prwaitbar(iter,s);
  for j = 1:iter
    prwaitbar(iter,j,[s num2str(j)]);
    eer{j} = feval(mfilename,data,classf,n,kred,1,fid);
  end
  prwaitbar(0);
  fe = zeros(size(eer{1},1),size(eer{1},2),iter);
  for j = 1:iter
    fe(:,:,j) = eer{j};
  end
  err = mean(fe,3);
  std_err = std(fe,[],3);
  cerr = std_err;       % the second output carries STD_ERR in this mode
  return
end

% datasets or classifiers are cell arrays
if iscell(classf) | iscell(data)

  seed = rand('state');
  if ~iscell(classf), classf = {classf}; end
  if ~iscell(data), data = {data}; end
  if isdataset(classf{1}) & ismapping(data{1})  % correct for old argument order
    dd = data; data = classf; classf = dd;
  end
  numc = length(classf);
  numd = length(data);
  cerr = cell(numd,numc);
  nlabout = cell(numd,numc);

  prprogress(fid,['\n%i-fold crossvalidation: ' ...
      '%i classifiers, %i datasets\n'],n,numc,numd);

  e = zeros(numd,numc);
  if numc > 1
    snumc = sprintf('crossvald: %i classifiers: ',numc);
    prwaitbar(numc,snumc);
  end

  for jc = 1:numc

    if numc > 1, prwaitbar(numc,jc,[snumc num2str(jc)]); end
    prprogress(fid,' %s\n',getname(classf{jc}));
    if numd > 1
      snumd = sprintf('crossvald: %i datasets: ',numd);
      prwaitbar(numd,snumd);
    end

    for jd = 1:numd

      if numd > 1, prwaitbar(numd,jd,[snumd num2str(jd)]); end
      prprogress(fid,' %s',getname(data{jd}));

      % restore the seed so that every classifier/dataset pair uses the same folds
      rand('state',seed);
      [ee,cc,nn] = feval(mfilename,data{jd},classf{jc},n,kred,1,fid);
      e(jd,jc) = ee;
      cerr(jd,jc) = {cc};
      nlabout(jd,jc) = {nn};

    end
    if numd > 1, prwaitbar(0); end
    %fprintf(fid,'\n');

  end
  if nargout == 0
    fprintf('\n %i-fold cross validation result for',n);
    disperror(data,classf,e);
  else
    err = e;
  end
  if numc > 1, prwaitbar(0); end

else

  if isdataset(classf) & ismapping(data)  % correct for old argument order
    dd = data; data = classf; classf = dd;
  end
  % discheck(data,[],0);
  isdataset(data);
  ismapping(classf);
  [m,k,c] = getsize(data);
  lab = getlab(data);
  if isempty(n), n = m; end

  if n > m
    warning('Number of batches too large: reset to leave-one-out')
    n = m;
  elseif n <= 1
    error('Wrong size for number of batches')
  end
  if ~isuntrained(classf)
    error('Classifier should be untrained')
  end
  J = randperm(m);
  N = classsizes(data);

  % attempt to find a more equal distribution of the folds over the classes
  if all(N > n)
    % assign the objects of every class separately to the n folds
    K = zeros(1,m);
    for i = 1:length(N)
      L = findnlab(data(J,:),i);
      M = mod(0:N(i)-1,n)+1;
      K(L) = M;
    end
  else
    K = mod(1:m,n)+1;
  end

  nlabout = zeros(m,1);
  rstate2 = rand('state');
  prprogress(fid,'%5.0f ',n);
  sfolds = sprintf('crossvald: %i folds: ',n);
  prwaitbar(n,sfolds);
  for i = 1:n
    prwaitbar(n,i,[sfolds num2str(i)]);
    OUT = find(K==i);
    JOUT = J(OUT);
    JIN = J; JIN(OUT) = [];
    if ~isempty(kred)
      if length(JIN) < kred
        error('Training set too small for the desired size of the representation set.')
      end
      % Select the representation set with a separate random stream (rstate2),
      % so that this selection does not interfere with other uses of rand.
      rstate1 = rand('state');
      rand('state',rstate2);
      RED = randperm(length(JIN));
      rstate2 = rand('state');
      rand('state',rstate1);
      RED = RED(1:kred);  % reduce the representation set; no attempt is made
                          % to distribute it equally over the classes
      JINF = JIN(RED);
      JINR = JINF;
    else
      JINF = JIN;
    end
    if (iscell(classf.data) & length(classf.data) > 0 & isparallel(classf.data{1}))
      % a parallel combiner expects a concatenation of square dissimilarity
      % matrices; select the same representation columns in every block
      m = size(data,1);
      k = size(data,2)/m;
      if (k ~= floor(k))
        error('Dataset should be a concatenation of square matrices.')
      end
      JINF = repmat(JINF(:),1,k) + repmat([0:k-1]*m,length(JINF),1);
      JINF = JINF(:);
    end
    dlearn = data(JIN,JINF);
    dtest  = data(JOUT,JINF);
    w = dlearn*classf;                           % training
    [mx,nlabout(JOUT)] = max(+(dtest*w),[],2);   % testing
    % nlabout contains class assignments according to w;
    % they need conversion to the numeric labels of the dataset
    L1 = matchlablist(getlablist(data),getlabels(w));
    [LL,L2] = sort(L1);
    nlabout(JOUT) = L2(nlabout(JOUT));
    s = sprintf('\b\b\b\b\b%5.0f',i);
    prprogress(fid,s);
  end
  prwaitbar(0)
  prprogress(fid,'\b\b\b\b\b\b\b\b\b\b\b',n);
  %L = matchlablist(getlablist(data),getlabels(w));
  f = zeros(1,c);
  for j = 1:c
    J = findnlab(data,j);
    f(j) = sum(nlabout(J)~=j)/length(J);
  end
  e = f*getprior(data)';
  if nargout > 0
    err = e;
    cerr = f;
  else
    disp([num2str(n) '-fold cross validation error on ' num2str(size(data,1)) ' objects: ' num2str(e)])
  end
end
return