IRootLab
An Open-Source MATLAB toolbox for vibrational biospectroscopy
gridsearch.m
Go to the documentation of this file.
%> @brief Grid Search
%>
%> Grid search is a simple iterative optimization method that does not use the gradient of the objective function (F(.)).
%> Instead, it evaluates F(.) at every point of a grid and picks the best point. A new, finer grid is then formed
%> around that point and the process repeats.
%>
%> The idea came from the guide that comes with LibSVM [1].
%>
%> Rather than passing a vector from the domain to an objective function, this grid search assigns each combination of
%> parameter values to the corresponding properties of a copy of the mold classifier (@ref gridsearch::clssr) and
%> evaluates all the resulting blocks at once using a @ref reptt_blockcube.
%>
%> Grid search is not restricted to SVM, nor to classifiers.
%>
%> Not published in the GUI at the moment.
%>
%> <h3>References:</h3>
%> [1] http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf.
%>
classdef gridsearch < as
    properties
        %> SGS. If not supplied, the @ref data property will be expected to
        %> have at least two elements. Another option is to use obsidxs
        %> instead
        sgs;
        %> Mold classifier. Each grid point evaluates a copy of this object with the grid parameters assigned to it.
        clssr;
        %> =(automatic). Cell array. Molds for the recording.
        %>
        %> If not passed, the logs are obtained from a ttlogprovider.
        %>
        %> If passed, make sure that the first log is the rates log and the second one is the times log, unless they
        %> are titled 'rates' and 'times3' respectively.
        %>
        %> Log titles will become fields inside sovaluess.values
        %>
        log_mold;
        %> (Optional) Block to post-process the test data. For example, a @ref grag_classes_first.
        postpr_test;
        %> Block to post-process the estimation issued by the classifier. Examples:
        %> @arg a @ref decider
        %> @arg a @ref block_cascade_base consisting of a @ref decider followed by a @ref grag_classes_vote
        %> There isn't a default, this must be provided
        postpr_est;

        %> =0. Number of times to zoom in around the best point
        no_refinements = 0;
        %> =3. Maximum number of moves per refinement level. A move happens when the chosen point lies on an edge of
        %> the grid; in this case, the search space is shifted to have the chosen point in the middle, without refining
        maxmoves = 3;
        %> Array of gridsearchparam objects (maximum of 3)
        params = gridsearchparam.empty;
        %> Parameters specifications in a cell; each row is passed to the @ref gridsearchparam constructor
        %>
        %> If provided, will override @ref gridsearch::params
        paramspecs;
        %> =0. Whether to run in parallel mode!
        %>
        %> @sa reptt_blockcube::flag_parallel
        flag_parallel = 0;
        %> Chooser object. If empty, a default chooser is created by make_defaults()
        chooser;
    end;

    methods
        function o = gridsearch()
            o.classtitle = 'Grid Search';
            o.flag_ui = 1;
            o.flag_params = 1;
        end;

        %> Adds parameter using @ref gridsearchparam constructor with varargin
        function o = add_param(o, varargin)
            if numel(o.params) >= 3
                irerror('Grid search can handle a maximum of 3 variables!');
            end;
            o.params(end+1) = gridsearchparam(varargin{:});
        end;

        %> Validates the object before running: requires 1 to 3 parameters, and all parameters must be numeric
        %> if refinements are requested.
        function o = assert(o)
            no_dims = numel(o.params);
            if no_dims < 1
                irerror('No paramaters for gridsearch!');
            end;

            % add_param() enforces this limit too, but params may also be assigned directly or built from paramspecs
            if no_dims > 3
                irerror('Grid search can handle a maximum of 3 variables!');
            end;

            if o.no_refinements > 0 && ~all([o.params.flag_numeric])
                irerror('In order to refine search, all parameters must be numeric!');
            end;
        end;

        %> Fills in log_mold, chooser and (from paramspecs) params when they were not provided.
        %>
        %> @param data Dataset used to obtain the default logs
        function o = make_defaults(o, data)
            if isempty(o.log_mold)
                ott = ttlogprovider();
                o.log_mold = ott.get_ttlogs(data);
            end;

            if isempty(o.chooser)
                ch = chooser(); %#ok<*CPROP,*PROP>
                ch.rate_maxloss = 0.001;
                ch.time_mingain = 0.4500;

                % If no log is titled 'rates', the first log is taken as the rates log
                % (otherwise the chooser_base default, 'rates', is kept)
                if ~any(cellfun(@(x) strcmp(x.title, 'rates'), o.log_mold))
                    if numel(o.log_mold) < 1
                        irerror('Cannot determine the rates log: log_mold is empty!');
                    end;
                    ch.ratesname = o.log_mold{1}.title;
                end;

                % If no log is titled 'times3', the second log is taken as the times log
                % (otherwise the chooser_base default, 'times3', is kept)
                if ~any(cellfun(@(x) strcmp(x.title, 'times3'), o.log_mold))
                    if numel(o.log_mold) < 2
                        irerror('Cannot determine the times log: log_mold needs a log titled ''times3'' or at least two logs!');
                    end;
                    ch.timesname = o.log_mold{2}.title;
                end;
                o.chooser = ch;
            end;

            if ~isempty(o.paramspecs)
                o.params = gridsearchparam.empty();
                for i = 1:size(o.paramspecs, 1)
                    o.params(i) = gridsearchparam(o.paramspecs{i, :});
                end;
            end;

            % Won't make default post-processors anymore
            % if isempty(o.postpr_est)
            %     o.postpr_est = def_postpr_est();
            %     o.postpr_test = def_postpr_test(); % Overrides pospr_test because need a harmonic pair
            % end;
        end;
    end;

    methods(Access=protected)
        %> Runs the grid search.
        %>
        %> Each iteration evaluates every grid point using a @ref reptt_blockcube, stores the results as a
        %> sovalues in log.sovaluess, then either moves the grid (best point on an edge), refines it around the
        %> best point, or stops.
        %>
        %> @param data Dataset passed to the block cube
        %> @return log A log_gridsearch object
        function log = do_use(o, data)
            o = o.make_defaults(data);
            o.assert();

            % Template cube; only block_mold changes from one iteration to the next
            moldcube = reptt_blockcube();
            moldcube.log_mold = o.log_mold;
            moldcube.sgs = o.sgs;
            moldcube.flag_parallel = o.flag_parallel;
            moldcube.postpr_test = o.postpr_test;
            moldcube.postpr_est = o.postpr_est;

            params = o.params;
            nj = numel(params);

            log = log_gridsearch();

            if o.flag_parallel
                parallel_open();
            end;

            % main loop
            irefin = 0;
            iiter = 1;
            imove = 0;
            nExpected = o.no_refinements+1; % Expected iterations
            ipro = progress2_open('GRIDSEARCH', [], 0, nExpected);
            while 1
                % Grid size is recomputed every iteration because move_to()/shrink_around() replace the params
                nvv = zeros(1, nj);
                for j = 1:nj
                    nvv(j) = numel(params(j).values);
                end;
                nv = prod(nvv);

                s_it = sprintf('Iteration: %d (refinement: %d; move: %d)', iiter, irefin, imove);

                irverbose ('**************', 2);
                irverbose(['************** Grid search ', s_it], 2);
                irverbose ('**************', 2);

                % Creates sovalues
                sov = sovalues();
                sov.title = s_it;
                sov.chooser = o.chooser;
                for j = 1:nj
                    p = params(j);
                    ax = raxisdata();
                    ax.label = p.get_label();
                    ax.values = p.get_values_numeric();
                    ax.ticks = p.get_ticklabels();
                    ax.legends = p.get_legends();
                    sov.ax(j) = ax;
                end;

                % Fresh grid-shaped cells every iteration, so no stale entries survive a grid change.
                % A single parameter gives a column (nvv x 1) grid.
                if nj == 1
                    gridsize = [nvv, 1];
                else
                    gridsize = nvv;
                end;
                molds = cell(gridsize);
                specs = cell(gridsize);
                idxss = cell(gridsize);

                % make block_cube
                idxs = cell(1, nj);
                for q = 1:nv
                    % Decodes linear index q into one value index per parameter (first parameter varies fastest)
                    r = q;
                    blk = o.clssr;
                    s_spec = '';
                    for j = 1:nj
                        p = params(j);
                        idx = mod(r-1, nvv(j))+1;
                        r = floor((r-1)/nvv(j))+1;

                        % NOTE(review): eval() kept because p.name is presumably allowed to be a nested property
                        % path, which blk.(p.name) would not support -- TODO confirm
                        eval(sprintf('blk.%s = p.get_value(idx);', p.name)); % Sets block value
                        s_spec = cat(2, s_spec, iif(j > 1, ', ', ''), p.name, '=', p.get_value_string(idx));
                        idxs{j} = idx;
                    end;
                    blk.title = s_spec;

                    if nj == 1
                        idxs{2} = 1;
                    end;

                    % populate the block_cube
                    molds{idxs{:}} = blk;
                    specs{idxs{:}} = s_spec;
                    idxss{idxs{:}} = idxs;
                end;

                % Runs stuff
                cube = moldcube;
                cube.block_mold = molds;
                cubelog = cube.use(data);

                % Collects results
                sov = sov.read_log_cube(cubelog, []);
                sov = sov.set_field('spec', specs);
                sov = sov.set_field('mold', molds);
                sov = sov.set_field('idxs', idxss);

                if nj == 1
                    sov.ax(2) = raxisdata_singleton();
                end;

                log.sovaluess(iiter) = sov;

                [item, idxs] = sov.choose_one(); %#ok<ASGLU>
                idxs = reshape(cell2mat(idxs(1:nj)), 1, []);
                % A singleton dimension is always at index 1; moving it would only re-evaluate the same grid
                flag_edge = any((idxs == 1 | idxs == nvv) & nvv > 1); % Whether any value was on the edge
                flag_shrink = 0;
                flag_moved = 0;
                if flag_edge
                    if imove >= o.maxmoves
                        irverbose(sprintf('Still hit the edge after %d moves', imove));
                        flag_shrink = 1;
                    else
                        irverbose('Hit the edge, will move to have best point in the centre', 1);
                        for j = 1:nj
                            params(j) = params(j).move_to(idxs(j));
                        end;
                        imove = imove+1;
                        nExpected = nExpected+1;
                        flag_moved = 1;
                    end;
                else
                    flag_shrink = 1;
                end;

                if flag_shrink && irefin < o.no_refinements
                    % Prepares for refinement
                    for j = 1:nj
                        params(j) = params(j).shrink_around(idxs(j));
                    end;
                    irefin = irefin+1;
                    imove = 0;
                elseif ~flag_moved
                    break;
                end;
                iiter = iiter+1;

                ipro = progress2_change(ipro, [], [], iiter, nExpected);
            end;
            progress2_close(ipro);

            if o.flag_parallel
                % I don't need a try..."finally" for this, not critical, really
                parallel_close();
            end;
        end;
    end;


    methods(Access=protected)
        %> Returns no_points equally spaced values spanning len, centred at centre
        function v = get_ticks(o, centre, len, no_points) %#ok<INUSL>
            x1 = centre-len/2;
            x2 = centre+len/2;
            v = linspace(x1, x2, no_points);
        end

    end;
end
Base Sub-dataset Generation Specification (SGS) class.
Definition: sgs.m:6
Group Aggregator - Classes - Vote.
function parallel_close()
function irerror(in s)
Block that resolves estimator posterior probabilities into classes.
Definition: decider.m:10
Classifiers base class.
Definition: clssr.m:6
Group Aggregator - Classes - First row.
Analysis Session (AS) base class.
Definition: as.m:6
Grid Search.
Definition: gridsearch.m:20
Cascade block: sequence of blocks represented by a block.
function get_value(in o, in idx)
REpeated Train-Test - Block Cube.