source: prextra/emc.m @ 109

Last change on this file since 109 was 109, checked in by bduin, 9 years ago
File size: 2.2 KB
Line 
%EMC EM Classifier using semi-supervised data
%
%   W = EMC(A,B,CLASSF,LABTYPE,FID)
%   W = A*EMC([],B,CLASSF,LABTYPE,FID)
%
% INPUT
%   A       Labeled dataset used for training
%   B       Additional unlabeled dataset
%   CLASSF  Untrained classifier (default QDC)
%   LABTYPE Label type to be used (crisp (default) or soft)
%   FID     File ID to write progress to (default [], see PRPROGRESS)
%
% OUTPUT
%   W      Trained classifier
%
% DESCRIPTION
% Using the EM algorithm, the classifier CLASSF is trained iteratively
% on the joint dataset [A;B]. In each step the objects of B are relabeled
% by the current classifier, after which CLASSF is retrained; the labels
% of A always keep their initial values. Any labels B may carry initially
% are ignored. If B is empty, W = A*CLASSF.
% Note that not all classifiers support LABTYPE 'soft'.
%
% SEE ALSO
% DATASETS, MAPPINGS, EMCLUST, PRPROGRESS

% Copyright: R.P.W. Duin, r.p.w.duin@prtools.org
% Faculty EWI, Delft University of Technology
% P.O. Box 5031, 2600 GA Delft, The Netherlands

function w = emc(a,b,classf,labtype,fid)

    % --- Defaults ---------------------------------------------------------
    % FID must be a parameter because it is stored in the untrained mapping
    % below. The stored data are presumably handed back to this function as
    % trailing arguments when that mapping is trained (PRTools convention);
    % confirm. FID is not otherwise used in this body.
    if nargin < 5, fid = []; end
    if nargin < 4 || isempty(labtype), labtype = 'crisp'; end
    if nargin < 3 || isempty(classf), classf = qdc; end
    if nargin < 2, b = []; end

    % --- Untrained call: return a mapping that stores the parameters ------
    if nargin < 1 || isempty(a)
        w = mapping(mfilename,'untrained',{b,classf,labtype,fid});
        w = setname(w,'EM Classifier');
        return
    end

    % --- Input checks -----------------------------------------------------
    islabtype(a,'crisp','soft');
    % Presumably: at least 1 object per class and at least 2 classes
    % (ISVALDSET(A,N,C) argument order); confirm.
    isvaldset(a,1,2);
    if isempty(b)
        % Nothing unlabeled to exploit: plain supervised training.
        w = a*classf;
        return
    end
    if size(a,2) ~= size(b,2)
        error('Datasets should have same number of features')
    end
    % Reject an unsupported LABTYPE before any training is done.
    if ~(ischar(labtype) && any(strcmp(labtype,{'crisp','soft'})))
        error('Wrong LABTYPE given')
    end

    % --- Set up -----------------------------------------------------------
    nclass  = getsize(a,3);
    epsilon = 1e-6;   % stop once the change in B's labels drops to this
    change  = 1;      % forces at least one EM iteration

    % Train on numeric labels (indices into LABLIST); the original label
    % names are put back on the final classifier.
    nlab    = getnlab(a);
    lablist = getlablist(a);
    p = getprior(a);
    a = setlabels(a,nlab);
    a = setprior(a,p);          % keep A's priors across the relabeling
    a = setlabtype(a,labtype);  % A now has LABTYPE; reused as-is in the loop

    % Labels currently assigned to B (all zero before the first E-step).
    if strcmp(labtype,'crisp')
        lab = zeros(size(b,1),1);
    else
        lab = zeros(size(b,1),nclass);
    end

    b = prdataset(+b);   % drop whatever labels B carried: they are ignored
    w = a*classf;        % initial classifier: labeled data only

    % --- EM iterations ----------------------------------------------------
    % NOTE(review): there is no iteration limit; if B's crisp labels keep
    % oscillating this loop may not terminate. Consider adding a cap.
    while change > epsilon
        % E-step: relabel B with the current classifier and measure how
        % much its labels changed.
        d = b*w;
        if strcmp(labtype,'crisp')
            labb   = d*labeld;
            change = mean(lab ~= labb);              % fraction changed
        else
            labb   = d*classc;
            change = mean(mean((+(labb-lab)).^2));   % mean squared change
        end
        lab = labb;

        % M-step: retrain on A (original labels) plus the relabeled B.
        b = setlabtype(b,labtype,lab);
        w = [a; b]*classf;
    end

    % Map the numeric training labels back to A's original label list.
    J = getlabels(w);
    w = setlabels(w,lablist(J,:));
87       
Note: See TracBrowser for help on using the repository browser.