IRootLab
An Open-Source MATLAB toolbox for vibrational biospectroscopy
clssr_d.m
Go to the documentation of this file.
1 %> @brief Linear and Quadratic discriminant
2 %>
3 %> Fits a Gaussian to each class. In the linear case, a pooled covariance matrix is calculated. In the quadratic case, each class has its
4 %> own covariance matrix.
5 %>
6 %> The problem with MATLAB's classify() is that it is not possible to call the training and use separately.
7 %>
8 %> @sa uip_clssr_d.m
classdef clssr_d < clssr
    properties
        %> ='linear'. Possibilities: 'linear' or 'quadratic'.
        type = 'linear';
        %> =1. Whether or not to use priors. Setting to ZERO is a way to account for unbalanced classes.
        flag_use_priors = 1;
    end

    properties(SetAccess=protected)
        % Inverse of pooled covariance matrix (linear case).
        % NOTE(review): never assigned within this class body.
        invcov;
        % Determinant of pooled covariance matrix (linear case).
        % NOTE(review): never assigned within this class body.
        detcov;
        % Inverse of covariance matrices per class (quadratic case) (cell of matrices).
        % NOTE(review): never assigned within this class body.
        invcovs;
        % Determinant of covariance matrices per class (quadratic case) (vector of scalars).
        % NOTE(review): never assigned within this class body.
        detcovs;
        % [nc]x[nf] matrix. Class means (overwritten with a numeric matrix by do_train())
        means = {};
        % 1x[nc] vector summing to 1. Class proportions in the training data, or
        % uniform when flag_use_priors is zero
        priors;
        % Triangular factor(s) such that SigmaHat = R'*R.
        % Linear case: one matrix; quadratic case: 1x[nc] cell of matrices
        R = {};
        % log(det(SigmaHat)). Linear case: scalar; quadratic case: [nc]x1 vector
        logDetSigma;
    end


    properties(Access=private)
        % NOTE(review): X and classes are not referenced within this class body.
        X;
        classes;
    end

    methods
        function o = clssr_d(o)
            o.classtitle = 'Gaussian fit';
        end

%         Better not implement these things     %> If title is not empty, will not mess with description too much
%         function s = get_description(o)
%             if ~isempty(o.title)
%                 s = get_description@clssr(o);
%             else
%                 s = [get_description@clssr(o), ' type = ', o.type];
%             end;
%         end;
    end

    methods(Access=protected)

        %> @brief Fits class means, priors and covariance factor(s) to the training data
        %>
        %> Bits extracted from MATLAB's classify()
        %>
        %> Reads data.X, data.classes (zero-based class indexes, as implied by the
        %> "== k-1" / "+1" indexing below), data.classlabels, data.no, data.nc and data.nf.
        %>
        %> Raises (through irerror()) if a covariance estimate is not positive definite
        %> or if o.type is neither 'linear' nor 'quadratic'.
        %>
        %> @param data Training dataset
        %> @return o Trained classifier
        function o = do_train(o, data)
            o.classlabels = data.classlabels;

            t = tic();

            % Both are reset here so that re-training an already trained object with
            % fewer classes does not leave stale entries behind
            o.means = zeros(data.nc, data.nf);
            o.priors = zeros(1, data.nc);
            for k = 1:data.nc
                Xtemp = data.X(data.classes == k-1, :);

                % Explicit dimension: mean() of a single-row matrix would otherwise
                % average across the variables and return a scalar
                o.means(k, :) = mean(Xtemp, 1);
                if o.flag_use_priors
                    o.priors(k) = size(Xtemp, 1);
                else
                    o.priors(k) = 1;
                end
            end
            o.priors = o.priors/sum(o.priors);

            switch o.type
                case 'linear'
                    dof = data.no-data.nc; % degrees of freedom of the pooled estimate

                    % Pooled estimate of covariance. Do not do pivoting, so that A can be
                    % computed without unpermuting. Instead use SVD to find rank of R.
                    [Q, R] = qr(data.X - o.means(data.classes+1, :), 0); %#ok<*PROP>
                    s = svd(R);
                    if dof <= 0 || any(s <= max(data.no, data.nf) * eps(max(s)))
                        irerror(sprintf('The pooled covariance matrix of TRAINING must be positive definite. There are probably too few spectra (%d) or too many variables (%d)!', data.no, data.nf));
                    end

                    o.R = R / sqrt(dof); % SigmaHat = R'*R
                    % The singular values are scaled like o.R so that logDetSigma is
                    % really log(det(SigmaHat)); summing logs avoids over/underflow
                    o.logDetSigma = 2*sum(log(s / sqrt(dof)));

                case 'quadratic'
                    o.R = cell(1, data.nc);
                    o.logDetSigma = zeros(data.nc, 1);

                    for k = 1:data.nc
                        Xtemp = data.X(data.classes == k-1, :);
                        nonow = size(Xtemp, 1);

                        % A covariance estimate needs at least two rows (divides by nonow-1)
                        flag_bad = nonow < 2;
                        if ~flag_bad
                            % Stratified estimate of covariance. Do not do pivoting, so that A
                            % can be computed without unpermuting. Instead use SVD to find rank
                            % of R.
                            [Q, Rk] = qr(bsxfun(@minus, Xtemp, o.means(k, :)), 0);
                            o.R{k} = Rk / sqrt(nonow - 1); % SigmaHat = R'*R
                            s = svd(o.R{k});
                            flag_bad = any(s <= max(nonow, data.nf) * eps(max(s)));
                        end

                        if flag_bad
                            irerror(sprintf(['The covariance of each class in TRAINING must ',...
                                'be positive definite. There are probably too few spectra in ', ...
                                'class "%s" (%d) or too many variables (%d)!'], ...
                                data.classlabels{k}, size(Xtemp, 1), data.nf));
                        end
                        o.logDetSigma(k) = 2*sum(log(s)); % avoid over/underflow
                    end

                otherwise
                    irerror(sprintf('Unknown type: %s', o.type));
            end

            o.time_train = toc(t);
        end


        %> @brief Calculates the class posterior probabilities for each row of data.X
        %>
        %> With bits from MATLAB classify()
        %>
        %> The relative log posterior of each class is computed first; the largest value
        %> of each row is subtracted before exponentiating so that at least one class
        %> keeps a non-zero value. This per-row constant cancels out in the row
        %> normalization, so the result is mathematically the same as exponentiating directly.
        %>
        %> @param data Dataset to be classified
        %> @return est estimato object whose X is a [data.no]x[nc] matrix of posteriors
        function est = do_use(o, data)
            est = estimato();
            est.classlabels = o.classlabels;
            est = est.copy_from_data(data);

            t = tic();

            nc = numel(o.classlabels);

            logpost = zeros(data.no, nc);

            switch o.type
                case 'linear'
                    % MVN relative log posterior density, by group, for each sample
                    for k = 1:nc
                        A = bsxfun(@minus, data.X, o.means(k, :)) / o.R;
                        logpost(:, k) = log(o.priors(k)) - .5*(sum(A.*A, 2) + o.logDetSigma);
                    end

                case 'quadratic'
                    % MVN relative log posterior density, by group, for each sample
                    for k = 1:nc
                        A = bsxfun(@minus, data.X, o.means(k, :)) / o.R{k};
                        logpost(:, k) = log(o.priors(k)) - .5*(sum(A.*A, 2) + o.logDetSigma(k));
                    end

                otherwise
                    % Same check as in do_train(): "type" is public and may have been changed
                    irerror(sprintf('Unknown type: %s', o.type));
            end

            % Shift each row by its maximum to prevent all classes underflowing to zero
            posteriors = exp(bsxfun(@minus, logpost, max(logpost, [], 2)));
            posteriors = normalize_rows(posteriors);

            est.X = posteriors;
            % NOTE(review): "o" is not an output of this method; unless clssr is a
            % handle class this assignment is discarded on return -- confirm.
            o.time_use = toc(t);
        end

    end
end