source: distools/nnerr.m @ 114

Last change on this file since 114 was 10, checked in by bduin, 14 years ago
File size: 2.5 KB
Line 
1%NNERR Exact expected NN error from a dissimilarity matrix (1)
2%
3%   E = NNERR(D,M)
4%
5% INPUT
6%   D   NxN dissimilarity dataset
7%   M   Vector with desired number of objects to be selected
8%
9% OUTPUT
10%   E   Expected NN error
11%
12%   E = NNERR(D)
13%
14% In this case a set of training set sizes is used to produce
15% a full learning curve. E can be plotted by PLOTE.
16%
17% DESCRIPTION
18% An exact computation is made of the expected NN error for a random
19% selection of M objects for training. D should be a dataset containing
20% a labeled square dissimilarity matrix.
21%
22
23% Copyright: R.P.W. Duin, r.duin@ieee.org
24% and Elzbieta Pekalska, ela.pekalska@googlemail.com
25% Faculty EWI, Delft University of Technology and
26% School of Computer Science, University of Manchester
27
function e = nnerr(d,n)
% E = NNERR(D,N)
%
% Expected 1-NN error when N objects are drawn at random as training set
% from the square, labeled dissimilarity dataset D.
%
%   N empty / omitted : E is a struct describing a learning curve with the
%                       fields error, xvalues, title, xlabel, ylabel, plot.
%   N a vector        : E is a row vector, one error per requested size.
%   N a scalar        : E is a single expected error.
%
% Raises an error if any requested size is not a positive integer smaller
% than the number of objects.

  if nargin < 2, n = []; end
  m = size(d,2);                 % number of objects (columns of D)

  if ~isempty(n)
    if any(n(:) >= m)
      error('Training set sizes should be smaller than sample size')
    end
    % Sizes below 1 or non-integer sizes would otherwise fail further down
    % with an obscure dimension / index error, so reject them up front.
    if any(n(:) < 1) || any(n(:) ~= round(n(:)))
      error('Training set sizes should be positive integers')
    end
  end

  if isempty(n)
    % Full learning curve. To gain speed the exact error is computed on a
    % sparse grid of training set sizes only; sizes in between are filled
    % in by linear interpolation.
    L = [1:20 22:2:40 45:5:60 70:10:100 120:20:300 350:50:1000 1100:100:10000];
    L = [L(L < m-1) m-1];        % keep grid points below m-1, always end at m-1
    f = zeros(1,m-1);
    prwaitbar(max(L),'Compute Learning Curve')
    for i = 1:length(L)
      prwaitbar(max(L),L(i));
      f(L(i)) = feval(mfilename,d,L(i));    % exact value at this grid point
      if (i > 1) && (L(i)-L(i-1) > 1)
        % Interpolate strictly between the two grid points. A dedicated
        % index is used so the input argument n is not overwritten, and the
        % exactly computed end points are left untouched.
        lo = L(i-1);
        hi = L(i);
        j  = lo+1:hi-1;
        f(j) = f(lo) + (f(hi)-f(lo))*(j-lo)/(hi-lo);
      end
    end
    prwaitbar(0)

    e.error   = f;
    e.xvalues = [1:length(e.error)];
    e.title   = 'Learning curve 1-NN rule';
    e.xlabel  = 'Size training set';
    e.ylabel  = 'Expected classification error';
    e.plot    = 'semilogx';

  elseif numel(n) > 1

    % Several sizes requested: evaluate each one separately.
    e = zeros(1,numel(n));
    for i = 1:numel(n)
      e(i) = feval(mfilename,d,n(i));
    end

  else

    % q(k): weight of the k-th nearest neighbor, evaluated in the log domain
    % with gamln to avoid overflow of the factorial products:
    %   q(k) = (prod(m-k-n+2:m-k+1) - prod(m-k-n+1:m-k)) / prod(m-n+1:m)
    % Entries beyond k = m-n+1 stay zero.
    q = zeros(1,m);
    for k = 1:m-n
      q(k) = (exp(gamln(m-k+2)-gamln(m-k+1-n+1)-gamln(m+1)+gamln(m-n+1)) ...
            - exp(gamln(m-k+1)-gamln(m-k-n+1)-gamln(m+1)+gamln(m-n+1)));
    end
    % For k = m-n+1 the second product contains a zero factor, so only the
    % first term remains.
    k = m-n+1;
    q(k) = (exp(gamln(m-k+2)-gamln(m-k+1-n+1)-gamln(m+1)+gamln(m-n+1)));

    % NOTE(review): called without output argument; presumably PRTools
    % raises an error here when d is not a dataset -- confirm.
    isdataset(d);
    nlab = getnlab(d);
    d = d + diag(repmat(inf,1,m));   % inf on the diagonal: an object sorts last w.r.t. itself
    [DD,L] = sort(+d,2);             % per row: object indices ordered by distance
    L = nlab(L);                     % replace neighbor indices by their numeric labels
    % R(k): fraction of objects whose k-th nearest neighbor has another label
    R = mean(L ~= repmat(nlab,1,m),1);
    e = q*R';

  end

return
89
function x = gamln(y)
% Guarded GAMMALN: yields 1 when the test y == 0 or the test y < 0 holds,
% and gammaln(y) in every other case.
  x = 1;
  if y == 0
    return
  end
  if y < 0
    return
  end
  x = gammaln(y);
return
Note: See TracBrowser for help on using the repository browser.