source: prextra/wvotec.m @ 10

Last change on this file since 10 was 5, checked in by bduin, 14 years ago
File size: 3.2 KB
RevLine 
[5]1%WVOTEC Weighted combiner (Adaboost weights)
2%
3%  W = WVOTEC(A,V)   compute weigths and store
4%  W = WVOTEC(V,U)   Construct weighted combiner using weights U
5%
6%  INPUT
7%    A      Labeled dataset
8%    V      Parallel or stacked set of trained classifiers
9%    U      Set of classifier weights
10%
11%  OUTPUT
12%    W      Combined classifier
13%
14% DESCRIPTION
15% The set of trained classifiers V is combined using weighted
16% majority voting. If given the weights U are used. If not
17% given, the weights are computed from the classification
18% results of the labeled dataset A using 0.5*log((1-E)/E)
19% if E is the classifier error.
20%
21% SEE ALSO
22% MAPPINGS, DATASETS,
23
24% Copyright: R.P.W. Duin, r.p.w.duin@prtools.org
25% Faculty EWI, Delft University of Technology
26% P.O. Box 5031, 2600 GA Delft, The Netherlands
27
28function w = wvotec(a,v)
29
30prtrace(mfilename);
31
32if nargin < 1 | isempty(a)
33        w = mapping(mfilename,'untrained');
34elseif nargin < 2
35        error('Illegal call')
36elseif isdataset(a)                    % train or classify
37        if ~strcmp(v.mapping_file,mfilename) % training
38                if isparallel(v)                   % parallel combiner
39                        n = 0;
40                        e = zeros(1,length(v.data));
41                        for j=1:length(v.data)
42                                vv = v.data{j};
43                                d = a(:,n+1:n+size(vv,1))*vv*classc;
44                                e(j) = testmc(d);
45                                n = n+size(vv,1);
46                        end
47                elseif isstacked(v)                % stacked combiner
48                        e = zeros(1,length(v.data));
49                        for j=1:length(v.data)
50                                vv = v.data{j};
51                                e(j) = testmc(a,vv);
52                        end
53                else
54                        error('Classifier combination expected')
55                end
56                                  % Find weights                                                               
57                L = find(e < 1-max(getprior(a))); % take classifier better than prior
58                alf = zeros(1,length(e));
59                alf(L) = log((1-e(L))./e(L))/2;
60                alf = alf/sum(alf);
61                                  % construct the classifier
62                [m,k,c] = getsize(a);
63                w = mapping(mfilename,'trained',{v,alf},getlabels(vv),k,c);
64                w = setname(w,'Weighted Voting');
65        else                                 % testing
66                alf = v.data{2};                   % get the weights
67                u = v.data{1};                     % get the set of classifiers
68                m = size(a,1);
69                dtot = zeros(m,size(v,2));
70                if isparallel(u)                   % parallel combiner
71                        n = 0;
72                        for j=1:length(u.data)           % weight them
73                                vv = u.data{j};
74                                aa = a(:,n+1:n+size(vv,1));
75                                d = a(:,n+1:n+size(vv,1))*vv;
76                                [dd,jj] = max(+d,[],2);
77                                dd = zeros(size(dtot));
78                                dd([1:m]'+(jj-1)*m) = alf(j);
79                                dtot = dtot + dd;
80                                n = n+size(vv,1);
81                        end
82                elseif isstacked(u)                % stacked combiner
83                        for j=1:length(u.data)           % weight them
84                                vv = u.data{j};
85                                d = a*vv;
86                                [dd,jj] = max(+d,[],2);
87                                dd = zeros(size(dtot));
88                                dd([1:m]'+(jj-1)*m) = alf(j);
89                                dtot = dtot + dd;
90                        end
91                else
92                        error('Classifier combination expected')
93                end
94                w = setdat(d,dtot);
95  end
96 
97else                  % store classifier from given weights
98   
99  ismapping(a);
100  u = v;              % the weights
101  v = a;              % the combined classifier
102  n = length(v.data);
103  if length(u) ~= n
104    error('Wrong number of weights given')
105  end
106  [k,c] = getsize(v.data{1});
107        w = mapping(mfilename,'trained',{v,u},getlabels(v{1}),k,c);
108        w = setname(w,'Weighted Voting');
109end
110
111               
112                       
Note: See TracBrowser for help on using the repository browser.