source: prextra/gatem.m @ 52

Last change on this file since 52 was 5, checked in by bduin, 14 years ago
File size: 3.4 KB
Line 
1%GATEM  Gate classifier
2%
3%       W = GATEM(A,WGATE,W)
4%
5% Make a gate classifier WGATE that splits the feature space into subsets,
6% train a classifier W{i} on each of the subsets and combine the results
7% again. Thus W is a cell-array (cell-vector) of untrained classifiers.
8% Both the gate classifier and the secondary classifiers are trained on
9% dataset A.  This gatem is like a mixture of experts, where each
10% classifier focuses on a particular region of the feature space.
11%
12% Example:  a = gendatb;    % generate a 2-class banana dataset
13%           w = gatem(a, ldc, {parzenc, qdc});
14%                           % split the feature space in two by an ldc
15%                           % train a parzenc in one half, and a qdc in
16%                           % the other half
17%           scatterd(a); plotc(w)
18%
19% In this version the ordering of the data that is passed to the
20% classifiers W{i} is based on the ordering as defined by
21% getlabels(A*WGATE).  Probably this should be made more flexible in a
22% future version.
23
24% Copyright: D.M.J. Tax, D.M.J.Tax@prtools.org
25% Faculty EWI, Delft University of Technology
26% P.O. Box 5031, 2600 GA Delft, The Netherlands
27
function w = gatem(a,wgate,W)

% The behaviour is selected by the arguments:
%   - A missing or empty: return an untrained mapping that stores
%     {WGATE,W} for later training.
%   - WGATE is not itself a gatem mapping: training. The gate is trained
%     on A (unless it is already trained), A is split according to the
%     gate output, and each untrained W{i} is trained on its own subset.
%   - WGATE is a gatem mapping: evaluation of that mapping on A.
%
% Output:
%   w   the untrained mapping, the trained mapping, or (in evaluation)
%       the dataset A with its data replaced by the combined outputs.

% Missing arguments default to empty so that the untrained-mapping branch
% below can be reached with fewer than three inputs.
if nargin < 3, W = []; end
if nargin < 2, wgate = []; end

if nargin < 1 || isempty(a)
    % When no data is given, we are expected to return an untrained
    % mapping. The result must be assigned to the output variable 'w'
    % (MATLAB is case-sensitive; 'W' is the input cell array).
    w = mapping(mfilename,{wgate,W});
    w = setname(w,'Gate classifier');
    return
end

if ~strcmp(getmapping_file(wgate),mfilename)  % training
    % Training consists of training the gate classifier, splitting the
    % dataset A according to the gate output, and training the
    % individual classifiers on the smaller datasets.

    % First train the gate (an already trained gate is used as is):
    if ~istrained(wgate)
        wgate = a*wgate;
    end
    gatelabels = getlabels(wgate);
    [k,c] = size(wgate);

    % There has to be one entry in W per gate output:
    if length(W)~=c
        error('The number of classifiers in W should be equal to the number of classes');
    end

    % Map the data through the gate ...
    out = a*wgate;
    % ... and assign every object to the gate output with the largest value:
    [dummy,I] = max(+out,[],2);

    % Train the secondary mappings:
    labels = cell(1,c);
    for i=1:c
        % Empty entries and already trained classifiers are left untouched.
        if ~isempty(W{i}) && ~istrained(W{i})
            % Objects that the gate assigned to output i:
            b = a(find(I==i),:);
            if isempty(b)
                % NOTE(review): W{i} stays untrained in this case; it is
                % still stored below and used during evaluation -- confirm
                % that this is the intended fallback.
                warning('One of the classifiers did not get any training data');
            else
                W{i} = b*W{i};
            end
        end

        % Match the labels of classifier i against the gate labels, so that
        % its outputs can be placed in the right columns when evaluating.
        if ~isempty(W{i})
            labels{i} = matchlablist(getlabels(W{i}),gatelabels);
        else
            labels{i} = [];
        end
    end

    % Store everything in the trained mapping:
    V.wgate = wgate;
    V.W = W;
    V.labels = labels;
    w = mapping(mfilename,'trained',V,gatelabels,k,c);
    w = setname(w,'Gate classifier');

else                  % evaluation
    % Get the stored gate, classifiers and label matches:
    V = getdata(wgate);

    % Apply the gate classifier; its output is also the default result for
    % objects that are not handled by a secondary classifier.
    out = +(a*V.wgate);
    [dummy,lab] = max(+out,[],2);

    % Apply the secondary classifiers to the objects the gate sent to them:
    for i=1:length(V.W)
        if ~isempty(V.W{i})
            I = find(lab==i);
            if ~isempty(I)  % so, if there is data to classify:
                tmpout = +(a(I,:)*V.W{i});
                % Write the outputs into the columns given by the stored
                % label match for classifier i.
                out(I,V.labels{i}) = tmpout;
            end
        end
    end

    % Fill in the data, keeping all other fields in the dataset intact:
    w = setdata(a,out);
    w = set(w,'featlab',getlabels(wgate),'featsize',getsize_out(wgate));

end
Note: See TracBrowser for help on using the repository browser.