%GATEM Gate classifier
%
%   W = GATEM(A,WGATE,W)
%
% Make a gate classifier WGATE that splits the feature space into subsets,
% train a classifier W{i} on each of the subsets and combine the results
% again. W is thus a cell array (cell vector) of untrained classifiers.
% Both the gate classifier and the secondary classifiers are trained on
% dataset A. This GATEM is similar to a mixture of experts, where each
% classifier focuses on a specific area of the feature space.
%
% Example: a = gendatb;   % generate a 2-class banana dataset
%          w = gatem(a, ldc, {parzenc, qdc});
%                          % split the feature space in two with an ldc,
%                          % train a parzenc in one half and a qdc in
%                          % the other half
%          scatterd(a); plotc(w)
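%
% As a further illustration (assuming the standard PRTools routines
% gendatb and testc are available), the trained combiner can be evaluated
% on an independent test set:
%
%          t = gendatb;              % independent test set
%          e = t*w*testc             % estimated classification error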
%
% In this version the ordering of the data that is passed to the
% classifiers W{i} follows the class ordering defined by
% getlabels(A*WGATE). Probably this should be made more flexible in a
% future version.
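%
% In other words, W{1} is trained on (and later applied to) the objects
% that the gate assigns to its first class, W{2} on those assigned to its
% second class, and so on.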

% Copyright: D.M.J. Tax, D.M.J.Tax@prtools.org
% Faculty EWI, Delft University of Technology
% P.O. Box 5031, 2600 GA Delft, The Netherlands

function w = gatem(a,wgate,W)

if nargin < 1 || isempty(a)
  % When no inputs are given, we are expected to return an empty,
  % untrained mapping:
  W = mapping(mfilename,{wgate,W});
  W = setname(W,'Gate classifier');
  return
end

if ~strcmp(getmapping_file(wgate),mfilename) % training
  % Training consists of training the gate classifier, splitting the
  % dataset A according to the output labels, and training the
  % individual classifiers on the resulting smaller datasets.

  % First train the gate (if it is not trained already):
  if ~istrained(wgate)
    wgate = a*wgate;
  end
  gatelabels = getlabels(wgate);
  [k,c] = size(wgate);
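  % k is the dimensionality of the input space, c the number of classes
  % distinguished by the gate (and thus the number of secondary classifiers).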

  % Do some checking
  if length(W)~=c
    error('The number of classifiers in W should be equal to the number of classes');
  end

  % Now map the data through the classifier:
  out = a*wgate;
  % and extract the indices of the sub-datasets:
  [dummy,I] = max(+out,[],2);
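  % I(j) is now the index of the gate class (output column) with the
  % highest gate output for object j, i.e. the secondary classifier that
  % object j is assigned to.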

  % Train the secondary mappings:
  for i=1:c
    % Empty classifiers and already trained classifiers do not have to be
    % trained:
    if ~isempty(W{i}) && ~istrained(W{i})
      % extract the sub-dataset,
      b = a(find(I==i),:);
      % be careful, this can go wrong:
      if isempty(b)
        warning('Classifier %d did not get any training data.',i);
      else
        % and train the mapping:
        W{i} = b*W{i};
      end
    end

    % Match the label ordering of this classifier to that of the gate:
    if ~isempty(W{i})
      labels{i} = matchlablist(getlabels(W{i}),gatelabels);
    else
      labels{i} = [];
    end
  end

  % Now store everything:
  V.wgate = wgate;
  V.W = W;
  V.labels = labels;
  w = mapping(mfilename,'trained',V,gatelabels,k,c);
  w = setname(w,'Gate classifier');

else % evaluation
  % Get the stored data out:
  V = getdata(wgate);
  [n,k] = size(a);

  % Apply the gate classifier:
  out = +(a*V.wgate);
  [dummy,lab] = max(out,[],2);
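  % lab(j) is the gate class index for object j. The gate outputs in OUT
  % are overwritten below by the outputs of the corresponding secondary
  % classifier; for classes without a secondary classifier the gate
  % outputs are kept.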

  % Now apply the secondary classifiers:
  for i=1:length(V.W)
    if ~isempty(V.W{i})
      % extract the objects assigned to this classifier,
      I = find(lab==i);
      if ~isempty(I) % so, if there is data to classify:
        % compute the output:
        tmpout = +(a(I,:)*V.W{i});
        % and store it in the columns that correspond to the gate's class
        % ordering (V.labels{i} maps this classifier's labels to the gate's):
        out(I,V.labels{i}) = tmpout;
      end
    end
  end

  % Fill in the data, keeping all other fields in the dataset intact:
  w = setdata(a,out);
  w = set(w,'featlab',getlabels(wgate),'featsize',getsize_out(wgate));

end

return