%> @brief Linear and Quadratic discriminant
%>
%> Fits a Gaussian to each class. In the linear case, a pooled co-variance
%> matrix is calculated. In the quadratic case, each class has its own
%> co-variance matrix.
%>
%> The problem with MATLAB's classify() is that it is not possible to call
%> the training and use steps separately.
9 classdef clssr_d < clssr
properties
    %> ='linear'. Possibilities: 'linear' or 'quadratic'.
    type = 'linear';
    %> =1. Whether or not to use priors. Setting to ZERO is a way to account for unbalanced classes.
    % NOTE(review): declaration lines were lost in a corrupted extraction; the
    % property name below is reconstructed from its doc comment -- confirm
    % against the upstream IRootLab source.
    flag_use_priors = 1;
end
properties(SetAccess=protected)
    %> Covariance factor such that SigmaHat = R'*R: a single matrix pooled
    %> across classes (linear case) or a cell of per-class matrices
    %> (quadratic case).
    R;
    %> log(det(SigmaHat)): scalar (linear case) or [nc]x1 vector of per-class
    %> values (quadratic case). Used to avoid over/underflow in the density.
    logDetSigma;
    %> [nc]x[nf] matrix. Class means
    means;
    %> Probability of belonging to each class - calculated from training data
    priors;
end
35 properties(Access=private)
function o = clssr_d(o)
    % Constructor: sets the human-readable classifier title.
    % (String literal rejoined; it had been broken across lines by a
    % corrupted extraction.)
    o.classtitle = 'Gaussian fit';
end
% Better not implement these things
%> If title is not empty, will not mess with description too much
% function s = get_description(o)
%     if ~isempty(o.title)
%         s = get_description@clssr(o);
%     else
%         s = [get_description@clssr(o), ' type = ', o.type];
55 methods(Access=protected)
%> Bits extracted from MATLAB's classify()
function o = do_train(o, data)
    % Estimates per-class means, priors and covariance factor(s) from the
    % TRAINING dataset. Bits extracted from MATLAB's classify().
    %
    % 'linear':    one pooled covariance factor o.R (SigmaHat = R'*R)
    % 'quadratic': one covariance factor per class, o.R{k}
    %
    % NOTE(review): this body was reconstructed from a corrupted extraction;
    % the control-flow skeleton (tic, if/elseif/else, loop headers, svd
    % calls, end statements) was restored from the surviving fragments --
    % confirm against the upstream IRootLab source.
    o.classlabels = data.classlabels;
    t = tic();

    if strcmp(o.type, 'linear')
        o.means = zeros(data.nc, data.nf);
        for k = data.nc:-1:1 % Backwards for allocation (grows o.priors once)
            Xtemp = data.X(data.classes == k-1, :); % data.classes is 0-based
            nonow = size(Xtemp, 1);
            o.priors(k) = nonow; % class frequency; normalized below
            o.means(k, :) = mean(Xtemp, 1); % dim 1: correct even for a one-row class
        end
        o.priors = o.priors/sum(o.priors);

        % Pooled estimate of covariance. Do not do pivoting, so that A can be
        % computed without unpermuting. Instead use SVD to find rank of R.
        [Q, R] = qr(data.X - o.means(data.classes+1, :), 0); %#ok<*PROP>
        o.R = R/sqrt(data.no-data.nc); % SigmaHat = R'*R
        s = svd(o.R);
        if any(s <= max(data.no, data.nf)*eps(max(s)))
            irerror(sprintf('The pooled covariance matrix of TRAINING must be positive definite. There are probably too few spectra (%d) or too many variables (%d)!', data.no, data.nf));
        end
        o.logDetSigma = 2*sum(log(s)); % avoid over/underflow
    elseif strcmp(o.type, 'quadratic')
        o.means = zeros(data.nc, data.nf);
        o.R = cell(1, data.nc);
        o.logDetSigma = zeros(data.nc, 1);
        for k = 1:data.nc
            Xtemp = data.X(data.classes == k-1, :);
            nonow = size(Xtemp, 1);
            o.priors(k) = nonow;
            o.means(k, :) = mean(Xtemp, 1);

            % Stratified estimate of covariance. Do not do pivoting, so that A
            % can be computed without unpermuting. Instead use SVD to find rank.
            [Q, Rk] = qr(bsxfun(@minus, Xtemp, o.means(k, :)), 0);
            o.R{k} = Rk/sqrt(nonow-1); % SigmaHat = R'*R
            s = svd(o.R{k});
            if any(s <= max(nonow, data.nf)*eps(max(s)))
                irerror(sprintf(['The covariance of each class in TRAINING must ', ...
                    'be positive definite. There are probably too few spectra in ', ...
                    'class "%s" (%d) or too many variables (%d)!'], ...
                    data.classlabels{k}, size(Xtemp, 1), data.nf));
            end
            o.logDetSigma(k) = 2*sum(log(s)); % avoid over/underflow
        end
        o.priors = o.priors/sum(o.priors);
    else
        irerror(sprintf('Unknown type: %s', o.type));
    end
    o.time_train = toc(t);
end
126 %> With bits from MATLAB classify()
127 function est = do_use(o, data)
129 est.classlabels = o.classlabels;
130 est = est.copy_from_data(data);
134 nc = numel(o.classlabels);
136 posteriors = zeros(data.no, nc);
140 % MVN relative log posterior density, by group, for each sample
142 A = bsxfun(@minus, data.X, o.means(k,:)) / o.R;
143 posteriors(:,k) = o.priors(k)*exp(-.5*(sum(A.*A, 2)+o.logDetSigma));
148 A = bsxfun(@minus, data.X, o.means(k, :))/o.R{k};
150 % MVN relative log posterior density, by group, for each sample
151 % D(:,k) = log(prior(k)) - .5*(sum(A .* A, 2) + logDetSigma(k));
152 posteriors(:,k) = o.priors(k)*exp(-.5*(sum(A.*A, 2)+o.logDetSigma(k)));
153 % posteriors(:,k) = log(o.priors(k))-.5*(sum(A.*A, 2)+o.logDetSigma(k));
157 posteriors = normalize_rows(posteriors);