[5] | 1 | %EMC EM Classifier using semi-supervised data
|
---|
| 2 | %
|
---|
| 3 | % W = EMC(A,B,CLASSF,LABTYPE,FID)
|
---|
| 4 | % W = A*EMC([],B,CLASSF,LABTYPE,FID)
|
---|
| 5 | %
|
---|
| 6 | % INPUT
|
---|
| 7 | % A Labeled dataset used for training
|
---|
| 8 | % B Additional unlabeled dataset
|
---|
| 9 | % CLASSF Untrained classifier (default QDC)
|
---|
| 10 | % LABTYPE Label type to be used (crisp (default) or soft)
|
---|
| 11 | % FID File ID to write progress to (default [], see PRPROGRESS)
|
---|
| 12 | %
|
---|
| 13 | % OUTPUT
|
---|
| 14 | % W Trained classifier
|
---|
| 15 | %
|
---|
| 16 | % DESCRIPTION
|
---|
| 17 | % Using the EM algorithm, the classifier CLASSF is trained iteratively
|
---|
| 18 | % on the joint dataset [A;B]. In each step the labels of A are reset
|
---|
| 19 | % to their initial values. Any labels already present in B are ignored.
|
---|
| 20 | % Labels of LABTYPE 'soft' are not supported by all classifiers.
|
---|
| 21 | %
|
---|
| 22 | % SEE ALSO
|
---|
| 23 | % DATASETS, MAPPINGS, EMCLUST, PRPROGRESS
|
---|
| 24 |
|
---|
| 25 | % Copyright: R.P.W. Duin, r.p.w.duin@prtools.org
|
---|
| 26 | % Faculty EWI, Delft University of Technology
|
---|
| 27 | % P.O. Box 5031, 2600 GA Delft, The Netherlands
|
---|
| 28 |
|
---|
function w = emc(a,b,classf,labtype,fid)
if nargin < 5, fid = []; end   % FID is accepted and stored in the untrained mapping; no progress is reported below
if nargin < 4 || isempty(labtype), labtype = 'crisp'; end
if nargin < 3 || isempty(classf), classf = qdc; end
if nargin < 2, b = []; end

% No labeled data given: return an untrained mapping that stores all
% parameters (including FID), so that A*W later calls
% EMC(A,B,CLASSF,LABTYPE,FID) with a matching number of inputs.
if nargin < 1 || isempty(a)
  w = mapping(mfilename,'untrained',{b,classf,labtype,fid});
  w = setname(w,'EM CLassifier');
  return
end

% ---- Training ----
islabtype(a,'crisp','soft');
isvaldset(a,1,2);   % at least 2 objects per class, 2 classes

% Without unlabeled data this reduces to plain supervised training.
if isempty(b)
  w = a*classf;
  return
end
if size(a,2) ~= size(b,2)
  error('Datasets should have same number of features')
end

% Validate LABTYPE once, before it is used by SETLABTYPE or to size LAB.
if ~ischar(labtype) || ~any(strcmp(labtype,{'crisp','soft'}))
  error('Wrong LABTYPE given')
end
iscrisp = strcmp(labtype,'crisp');

nclass  = getsize(a,3);   % number of classes in A
epsilon = 1e-6;           % convergence threshold on the label change
lablist = getlablist(a);  % original labels, restored on the final classifier
p       = getprior(a);

% Relabel A by class index 1..nclass so that A and the classified B share
% one numeric label space; keep the original priors.
a = setlabels(a,getnlab(a));
a = setprior(a,p);
a = setlabtype(a,labtype);

% Previous labels of B: a column of class indices (crisp) or an
% nobjects-by-nclass matrix of soft labels.
if iscrisp
  lab = zeros(size(b,1),1);
else
  lab = zeros(size(b,1),nclass);
end

b = prdataset(+b);   % discard whatever labels B came with
w = a*classf;        % initial classifier from the labeled data only

% EM iterations: label B with the current classifier, retrain on [A;B],
% and stop once the labels of B change by no more than EPSILON.
% NOTE(review): there is no iteration cap; crisp labels that oscillate
% between two configurations would keep this loop running.
change = 1;
while change > epsilon
  d = b*w;
  if iscrisp
    labb   = d*labeld;                 % assigned class indices
    change = mean(lab ~= labb);        % fraction of objects relabeled
  else
    labb   = +(d*classc);              % normalized outputs as a plain matrix
    change = mean(mean((labb-lab).^2));
  end
  lab = labb;
  b   = setlabtype(b,labtype,lab);
  ab  = [a; b];                        % A already carries LABTYPE (set above)
  w   = ab*classf;
end

% Map the internal class indices back to the original label list of A.
J = getlabels(w);
w = setlabels(w,lablist(J,:));
---|