
rf_common.hxx
/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/


#ifndef VIGRA_RF_COMMON_HXX
#define VIGRA_RF_COMMON_HXX

namespace vigra
{


struct ClassificationTag
{};

struct RegressionTag
{};

namespace detail
{
    class RF_DEFAULT;
}
inline detail::RF_DEFAULT& rf_default();
namespace detail
{
/* \brief singleton default tag class
 *
 * use the rf_default() factory function to obtain the tag.
 * \sa RandomForest<>::learn()
 */
class RF_DEFAULT
{
  private:
    RF_DEFAULT()
    {}
  public:
    friend RF_DEFAULT& ::vigra::rf_default();

    /** workaround for the automatic choice of the decision tree
     *  stack entry.
     */
};

/* \brief chooses between the default type and the user-supplied type
 *
 * This is an internal class and you shouldn't really care about it.
 * It is used inside RandomForest::learn() to substitute defaults for
 * arguments given as rf_default().
 * Usage:
 * \code
 * // example: use the container type supplied by the user, or
 * // ArrayVector<int> if rf_default() was specified as the argument
 * template<class Container_t>
 * void do_some_foo(Container_t in)
 * {
 *     typedef ArrayVector<int> Default_Container_t;
 *     Default_Container_t default_value;
 *
 *     // if the user didn't care and in was of type RF_DEFAULT,
 *     // then default_value is chosen.
 *     do_some_more_foo(Value_Chooser<Container_t, Default_Container_t>
 *                          ::choose(in, default_value));
 * }
 * \endcode
 */
template<class T, class C>
class Value_Chooser
{
  public:
    typedef T type;

    static T & choose(T & t, C &)
    {
        return t;
    }
};

template<class C>
class Value_Chooser<detail::RF_DEFAULT, C>
{
  public:
    typedef C type;

    static C & choose(detail::RF_DEFAULT &, C & c)
    {
        return c;
    }
};



} //namespace detail


/**\brief factory function to return a RF_DEFAULT tag
 * \sa RandomForest<>::learn()
 */
detail::RF_DEFAULT& rf_default()
{
    static detail::RF_DEFAULT result;
    return result;
}
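
/* Editor's sketch (not part of the original header): the calling side of
 * this default mechanism. do_some_foo is the hypothetical function from
 * the Value_Chooser documentation above; any parameter accepting
 * rf_default() behaves analogously.
 * \code
 * ArrayVector<int> my_container;
 * do_some_foo(my_container);  // the user-supplied container is used
 * do_some_foo(rf_default());  // the internal default is substituted
 * \endcode
 */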

/** tags used with the RandomForestOptions class
 * \sa RF_Traits::Option_t
 */
enum RF_OptionTag   { RF_EQUAL,
                      RF_PROPORTIONAL,
                      RF_EXTERNAL,
                      RF_NONE,
                      RF_FUNCTION,
                      RF_LOG,
                      RF_SQRT,
                      RF_CONST,
                      RF_ALL};


/** \addtogroup MachineLearning
**/
//@{

/**\brief Options object for the random forest
 *
 * usage:
 * RandomForestOptions a =  RandomForestOptions()
 *                              .param1(value1)
 *                              .param2(value2)
 *                              ...
 *
 * This class only contains options/parameters that are not problem
 * dependent. The ProblemSpec class contains methods to set class weights
 * if necessary.
 *
 * Note that the return value of all methods is *this, which makes
 * chaining of options as above possible.
 */
class RandomForestOptions
{
  public:
    /**\name sampling options*/
    /*\{*/
    // look at the member access functions for documentation
    double       training_set_proportion_;
    int          training_set_size_;
    int (*training_set_func_)(int);
    RF_OptionTag training_set_calc_switch_;

    bool         sample_with_replacement_;
    RF_OptionTag stratification_method_;
    /*\}*/

    /**\name general random forest options
     *
     * these usually will be used by most split functors and
     * stopping predicates
     */
    /*\{*/
    RF_OptionTag mtry_switch_;
    int          mtry_;
    int (*mtry_func_)(int);

    bool predict_weighted_;
    int  tree_count_;
    int  min_split_node_size_;
    bool prepare_online_learning_;
    /*\}*/

    typedef ArrayVector<double>                 double_array;
    typedef std::map<std::string, double_array> map_type;

    int serialized_size() const
    {
        return 12;
    }


    bool operator==(RandomForestOptions const & rhs) const
    {
        bool result = true;
        #define COMPARE(field) result = result && (this->field == rhs.field);
        COMPARE(training_set_proportion_);
        COMPARE(training_set_size_);
        COMPARE(training_set_calc_switch_);
        COMPARE(sample_with_replacement_);
        COMPARE(stratification_method_);
        COMPARE(mtry_switch_);
        COMPARE(mtry_);
        COMPARE(tree_count_);
        COMPARE(min_split_node_size_);
        COMPARE(predict_weighted_);
        #undef COMPARE

        return result;
    }
    bool operator!=(RandomForestOptions const & rhs_) const
    {
        return !(*this == rhs_);
    }
    template<class Iter>
    void unserialize(Iter const & begin, Iter const & end)
    {
        Iter iter = begin;
        vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
                           "RandomForestOptions::unserialize(): "
                           "wrong number of parameters");
        #define PULL(item_, type_) item_ = type_(*iter); ++iter;
        PULL(training_set_proportion_, double);
        PULL(training_set_size_, int);
        ++iter; //PULL(training_set_func_, double);
        PULL(training_set_calc_switch_, (RF_OptionTag)int);
        PULL(sample_with_replacement_, 0 !=);
        PULL(stratification_method_, (RF_OptionTag)int);
        PULL(mtry_switch_, (RF_OptionTag)int);
        PULL(mtry_, int);
        ++iter; //PULL(mtry_func_, double);
        PULL(tree_count_, int);
        PULL(min_split_node_size_, int);
        PULL(predict_weighted_, 0 !=);
        #undef PULL
    }
    template<class Iter>
    void serialize(Iter const & begin, Iter const & end) const
    {
        Iter iter = begin;
        vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
                           "RandomForestOptions::serialize(): "
                           "wrong number of parameters");
        #define PUSH(item_) *iter = double(item_); ++iter;
        PUSH(training_set_proportion_);
        PUSH(training_set_size_);
        if(training_set_func_ != 0)
        {
            PUSH(1);
        }
        else
        {
            PUSH(0);
        }
        PUSH(training_set_calc_switch_);
        PUSH(sample_with_replacement_);
        PUSH(stratification_method_);
        PUSH(mtry_switch_);
        PUSH(mtry_);
        if(mtry_func_ != 0)
        {
            PUSH(1);
        }
        else
        {
            PUSH(0);
        }
        PUSH(tree_count_);
        PUSH(min_split_node_size_);
        PUSH(predict_weighted_);
        #undef PUSH
    }
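
    /* Editor's sketch (not in the original header) of a serialization
     * round trip: serialize() writes the 12 numeric fields into any
     * random-access range of length serialized_size(), and unserialize()
     * restores them. `opts`, `buffer` and `restored` are hypothetical names.
     * \code
     * RandomForestOptions opts = RandomForestOptions().tree_count(100);
     * ArrayVector<double> buffer(opts.serialized_size());
     * opts.serialize(buffer.begin(), buffer.end());
     * RandomForestOptions restored;
     * restored.unserialize(buffer.begin(), buffer.end());
     * // restored now compares equal to opts (the function pointers are
     * // neither serialized nor compared by operator==)
     * \endcode
     */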

    void make_from_map(map_type & in) // -> const: .operator[] -> .find
    {
        #define PULL(item_, type_) item_ = type_(in[#item_][0]);
        #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
        PULL(training_set_proportion_, double);
        PULL(training_set_size_, int);
        PULL(mtry_, int);
        PULL(tree_count_, int);
        PULL(min_split_node_size_, int);
        PULLBOOL(sample_with_replacement_, bool);
        PULLBOOL(prepare_online_learning_, bool);
        PULLBOOL(predict_weighted_, bool);

        PULL(training_set_calc_switch_, (RF_OptionTag)(int));

        PULL(stratification_method_, (RF_OptionTag)(int));
        PULL(mtry_switch_, (RF_OptionTag)(int));

        /* don't pull the function pointers */
        //PULL(mtry_func_, int);
        //PULL(training_set_func_, int);
        #undef PULL
        #undef PULLBOOL
    }
    void make_map(map_type & in) const
    {
        #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_));
        #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_ != 0));
        PUSH(training_set_proportion_, double);
        PUSH(training_set_size_, int);
        PUSH(mtry_, int);
        PUSH(tree_count_, int);
        PUSH(min_split_node_size_, int);
        PUSH(sample_with_replacement_, bool);
        PUSH(prepare_online_learning_, bool);
        PUSH(predict_weighted_, bool);

        PUSH(training_set_calc_switch_, RF_OptionTag);
        PUSH(stratification_method_, RF_OptionTag);
        PUSH(mtry_switch_, RF_OptionTag);

        PUSHFUNC(mtry_func_, int);
        PUSHFUNC(training_set_func_, int);
        #undef PUSH
        #undef PUSHFUNC
    }
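
    /* Editor's sketch of the map-based round trip: make_map() stores every
     * field under its own name as a one-element double_array, and
     * make_from_map() reads it back. `opts`, `m` and `restored` are
     * hypothetical names.
     * \code
     * RandomForestOptions opts = RandomForestOptions().features_per_node(RF_LOG);
     * RandomForestOptions::map_type m;
     * opts.make_map(m);          // m["mtry_switch_"][0] == double(RF_LOG)
     * RandomForestOptions restored;
     * restored.make_from_map(m);
     * \endcode
     */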


    /**\brief create a RandomForestOptions object with default initialisation.
     *
     * look at the other member functions for more information on default
     * values
     */
    RandomForestOptions()
    :
        training_set_proportion_(1.0),
        training_set_size_(0),
        training_set_func_(0),
        training_set_calc_switch_(RF_PROPORTIONAL),
        sample_with_replacement_(true),
        stratification_method_(RF_NONE),
        mtry_switch_(RF_SQRT),
        mtry_(0),
        mtry_func_(0),
        predict_weighted_(false),
        tree_count_(255),
        min_split_node_size_(1),
        prepare_online_learning_(false)
    {}

    /**\brief specify stratification strategy
     *
     * default: RF_NONE
     * possible values: RF_EQUAL, RF_PROPORTIONAL,
     *                  RF_EXTERNAL, RF_NONE
     * RF_EQUAL:        get an equal number of samples per class.
     * RF_PROPORTIONAL: sample proportionally to the fraction of each class
     *                  in the population
     * RF_EXTERNAL:     the strata_weights_ field of the ProblemSpec_t object
     *                  has been set externally. (defunct)
     */
    RandomForestOptions & use_stratification(RF_OptionTag in)
    {
        vigra_precondition(in == RF_EQUAL ||
                           in == RF_PROPORTIONAL ||
                           in == RF_EXTERNAL ||
                           in == RF_NONE,
                           "RandomForestOptions::use_stratification(): "
                           "input must be RF_EQUAL, RF_PROPORTIONAL, "
                           "RF_EXTERNAL or RF_NONE");
        stratification_method_ = in;
        return *this;
    }

    RandomForestOptions & prepare_online_learning(bool in)
    {
        prepare_online_learning_ = in;
        return *this;
    }

    /**\brief sample from the training population with or without replacement?
     *
     * <br> Default: true
     */
    RandomForestOptions & sample_with_replacement(bool in)
    {
        sample_with_replacement_ = in;
        return *this;
    }

    /**\brief specify the fraction of the total number of samples
     * used per tree for learning.
     *
     * This value should be in [0.0, 1.0] if sampling without
     * replacement has been specified.
     *
     * <br> default: 1.0
     */
    RandomForestOptions & samples_per_tree(double in)
    {
        training_set_proportion_ = in;
        training_set_calc_switch_ = RF_PROPORTIONAL;
        return *this;
    }

    /**\brief directly specify the number of samples per tree
     *
     * This value should not be higher than the total number of
     * samples if sampling without replacement has been specified.
     */
    RandomForestOptions & samples_per_tree(int in)
    {
        training_set_size_ = in;
        training_set_calc_switch_ = RF_CONST;
        return *this;
    }

    /**\brief use an external function to calculate the number of samples
     * each tree should be learnt with.
     *
     * \param in function pointer that takes the number of rows in the
     *           learning data and outputs the number of samples per tree.
     */
    RandomForestOptions & samples_per_tree(int (*in)(int))
    {
        training_set_func_ = in;
        training_set_calc_switch_ = RF_FUNCTION;
        return *this;
    }
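
    /* Editor's sketch for the function-pointer overload above; half_of
     * is a hypothetical user-supplied sampling function.
     * \code
     * int half_of(int row_count) { return row_count / 2; }
     *
     * RandomForestOptions opts = RandomForestOptions().samples_per_tree(&half_of);
     * \endcode
     */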

    /**\brief weight each tree's vote with the number of samples in its leaf node
     */
    RandomForestOptions & predict_weighted()
    {
        predict_weighted_ = true;
        return *this;
    }

    /**\brief use a built-in mapping to calculate mtry
     *
     * Use one of the built-in mappings to calculate mtry from the number
     * of columns in the input feature data.
     * \param in possible values:
     *        - RF_LOG (the number of features considered for each split is
     *          \f$ 1+\lfloor \log(n_f)/\log(2) \rfloor \f$ as in Breiman's original paper),
     *        - RF_SQRT (default, the number of features considered for each split is
     *          \f$ \lfloor \sqrt{n_f} + 0.5 \rfloor \f$),
     *        - RF_ALL (all features are considered for each split)
     */
    RandomForestOptions & features_per_node(RF_OptionTag in)
    {
        vigra_precondition(in == RF_LOG ||
                           in == RF_SQRT ||
                           in == RF_ALL,
                           "RandomForestOptions::features_per_node(): "
                           "input must be RF_LOG, RF_SQRT or RF_ALL");
        mtry_switch_ = in;
        return *this;
    }
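
    /* Editor's note: for example, with \f$ n_f = 100 \f$ features,
     * RF_SQRT considers \f$ \lfloor \sqrt{100} + 0.5 \rfloor = 10 \f$
     * candidate features per split, while RF_LOG considers
     * \f$ 1 + \lfloor \log(100)/\log(2) \rfloor = 1 + 6 = 7 \f$.
     */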

    /**\brief set mtry to a constant value
     *
     * mtry is the number of columns/variates/variables randomly chosen
     * from which the best split is selected.
     */
    RandomForestOptions & features_per_node(int in)
    {
        mtry_ = in;
        mtry_switch_ = RF_CONST;
        return *this;
    }

    /**\brief use an external function to calculate mtry
     *
     * \param in function pointer that takes the number of columns
     *           of the feature matrix and outputs mtry
     */
    RandomForestOptions & features_per_node(int (*in)(int))
    {
        mtry_func_ = in;
        mtry_switch_ = RF_FUNCTION;
        return *this;
    }

    /**\brief how many trees to create?
     *
     * <br> Default: 255.
     */
    RandomForestOptions & tree_count(unsigned int in)
    {
        tree_count_ = in;
        return *this;
    }

    /**\brief number of examples required for a node to be split.
     *
     * When the number of examples in a node is below this number,
     * the node is not split even if class separation is not yet perfect.
     * Instead, the node returns the proportion of each class
     * (among the remaining examples) during the prediction phase.
     * <br> Default: 1 (complete growing)
     */
    RandomForestOptions & min_split_node_size(int in)
    {
        min_split_node_size_ = in;
        return *this;
    }
};
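
/* Editor's sketch of the chained-options usage described in the class
 * documentation above; all parameter values are illustrative only.
 * \code
 * RandomForestOptions opts = RandomForestOptions()
 *                                .tree_count(100)
 *                                .features_per_node(RF_SQRT)
 *                                .samples_per_tree(0.8)
 *                                .sample_with_replacement(false)
 *                                .min_split_node_size(5);
 * \endcode
 */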


/* \brief problem types
 */
enum Problem_t {REGRESSION, CLASSIFICATION, CHECKLATER};


/** \brief problem specification class for the random forest.
 *
 * This class contains all the problem specific parameters the random
 * forest needs for learning. Specification of an instance of this class
 * is optional, as all necessary fields will be computed prior to learning
 * if not specified.
 *
 * If needed, usage is similar to that of RandomForestOptions.
 */

template<class LabelType = double>
class ProblemSpec
{

  public:

    /** \brief problem class
     */
    typedef LabelType Label_t;
    ArrayVector<Label_t> classes;

    typedef ArrayVector<double>                 double_array;
    typedef std::map<std::string, double_array> map_type;

    int column_count_;       // number of features
    int class_count_;        // number of classes
    int row_count_;          // number of samples

    int actual_mtry_;        // mtry used in training
    int actual_msample_;     // number of in-bag samples per tree

    Problem_t problem_type_; // classification or regression

    int used_;               // whether this ProblemSpec is valid
    ArrayVector<double> class_weights_; // if classes have different importance
    int is_weighted_;        // whether class_weights_ are used
    double precision_;       // termination criterion for regression loss
    int response_size_;

    template<class T>
    void to_classlabel(int index, T & out) const
    {
        out = T(classes[index]);
    }
    template<class T>
    int to_classIndex(T index) const
    {
        return std::find(classes.begin(), classes.end(), index) - classes.begin();
    }

    #define EQUALS(field) field(rhs.field)
    ProblemSpec(ProblemSpec const & rhs)
    :
        EQUALS(column_count_),
        EQUALS(class_count_),
        EQUALS(row_count_),
        EQUALS(actual_mtry_),
        EQUALS(actual_msample_),
        EQUALS(problem_type_),
        EQUALS(used_),
        EQUALS(class_weights_),
        EQUALS(is_weighted_),
        EQUALS(precision_),
        EQUALS(response_size_)
    {
        std::back_insert_iterator<ArrayVector<Label_t> >
            iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
    }
    #undef EQUALS
    #define EQUALS(field) field(rhs.field)
    template<class T>
    ProblemSpec(ProblemSpec<T> const & rhs)
    :
        EQUALS(column_count_),
        EQUALS(class_count_),
        EQUALS(row_count_),
        EQUALS(actual_mtry_),
        EQUALS(actual_msample_),
        EQUALS(problem_type_),
        EQUALS(used_),
        EQUALS(class_weights_),
        EQUALS(is_weighted_),
        EQUALS(precision_),
        EQUALS(response_size_)
    {
        std::back_insert_iterator<ArrayVector<Label_t> >
            iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
    }
    #undef EQUALS

    #define EQUALS(field) (this->field = rhs.field);
    ProblemSpec & operator=(ProblemSpec const & rhs)
    {
        EQUALS(column_count_);
        EQUALS(class_count_);
        EQUALS(row_count_);
        EQUALS(actual_mtry_);
        EQUALS(actual_msample_);
        EQUALS(problem_type_);
        EQUALS(used_);
        EQUALS(is_weighted_);
        EQUALS(precision_);
        EQUALS(response_size_)
        class_weights_.clear();
        std::back_insert_iterator<ArrayVector<double> >
            iter2(class_weights_);
        std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
        classes.clear();
        std::back_insert_iterator<ArrayVector<Label_t> >
            iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
        return *this;
    }

    template<class T>
    ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs)
    {
        EQUALS(column_count_);
        EQUALS(class_count_);
        EQUALS(row_count_);
        EQUALS(actual_mtry_);
        EQUALS(actual_msample_);
        EQUALS(problem_type_);
        EQUALS(used_);
        EQUALS(is_weighted_);
        EQUALS(precision_);
        EQUALS(response_size_)
        class_weights_.clear();
        std::back_insert_iterator<ArrayVector<double> >
            iter2(class_weights_);
        std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
        classes.clear();
        std::back_insert_iterator<ArrayVector<Label_t> >
            iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
        return *this;
    }
    #undef EQUALS

    template<class T>
    bool operator==(ProblemSpec<T> const & rhs)
    {
        bool result = true;
        #define COMPARE(field) result = result && (this->field == rhs.field);
        COMPARE(column_count_);
        COMPARE(class_count_);
        COMPARE(row_count_);
        COMPARE(actual_mtry_);
        COMPARE(actual_msample_);
        COMPARE(problem_type_);
        COMPARE(is_weighted_);
        COMPARE(precision_);
        COMPARE(used_);
        COMPARE(class_weights_);
        COMPARE(classes);
        COMPARE(response_size_)
        #undef COMPARE
        return result;
    }

    bool operator!=(ProblemSpec & rhs)
    {
        return !(*this == rhs);
    }


    size_t serialized_size() const
    {
        return 10 + class_count_ * int(is_weighted_ + 1);
    }


    template<class Iter>
    void unserialize(Iter const & begin, Iter const & end)
    {
        Iter iter = begin;
        vigra_precondition(end - begin >= 10,
                           "ProblemSpec::unserialize(): "
                           "wrong number of parameters");
        #define PULL(item_, type_) item_ = type_(*iter); ++iter;
        PULL(column_count_, int);
        PULL(class_count_, int);

        vigra_precondition(end - begin >= 10 + class_count_,
                           "ProblemSpec::unserialize(): 1");
        PULL(row_count_, int);
        PULL(actual_mtry_, int);
        PULL(actual_msample_, int);
        PULL(problem_type_, Problem_t);
        PULL(is_weighted_, int);
        PULL(used_, int);
        PULL(precision_, double);
        PULL(response_size_, int);
        if(is_weighted_)
        {
            vigra_precondition(end - begin == 10 + 2*class_count_,
                               "ProblemSpec::unserialize(): 2");
            class_weights_.insert(class_weights_.end(),
                                  iter,
                                  iter + class_count_);
            iter += class_count_;
        }
        classes.insert(classes.end(), iter, end);
        #undef PULL
    }


    template<class Iter>
    void serialize(Iter const & begin, Iter const & end) const
    {
        Iter iter = begin;
        vigra_precondition(end - begin == serialized_size(),
                           "ProblemSpec::serialize(): "
                           "wrong number of parameters");
        #define PUSH(item_) *iter = double(item_); ++iter;
        PUSH(column_count_);
        PUSH(class_count_);
        PUSH(row_count_);
        PUSH(actual_mtry_);
        PUSH(actual_msample_);
        PUSH(problem_type_);
        PUSH(is_weighted_);
        PUSH(used_);
        PUSH(precision_);
        PUSH(response_size_);
        if(is_weighted_)
        {
            std::copy(class_weights_.begin(),
                      class_weights_.end(),
                      iter);
            iter += class_count_;
        }
        std::copy(classes.begin(),
                  classes.end(),
                  iter);
        #undef PUSH
    }

    void make_from_map(map_type & in) // -> const: .operator[] -> .find
    {
        #define PULL(item_, type_) item_ = type_(in[#item_][0]);
        PULL(column_count_, int);
        PULL(class_count_, int);
        PULL(row_count_, int);
        PULL(actual_mtry_, int);
        PULL(actual_msample_, int);
        PULL(problem_type_, (Problem_t)int);
        PULL(is_weighted_, int);
        PULL(used_, int);
        PULL(precision_, double);
        PULL(response_size_, int);
        class_weights_ = in["class_weights_"];
        #undef PULL
    }
    void make_map(map_type & in) const
    {
        #define PUSH(item_) in[#item_] = double_array(1, double(item_));
        PUSH(column_count_);
        PUSH(class_count_);
        PUSH(row_count_);
        PUSH(actual_mtry_);
        PUSH(actual_msample_);
        PUSH(problem_type_);
        PUSH(is_weighted_);
        PUSH(used_);
        PUSH(precision_);
        PUSH(response_size_);
        in["class_weights_"] = class_weights_;
        #undef PUSH
    }

    /**\brief set default values (i.e., mark all fields as not yet set)
     */
    ProblemSpec()
    :   column_count_(0),
        class_count_(0),
        row_count_(0),
        actual_mtry_(0),
        actual_msample_(0),
        problem_type_(CHECKLATER),
        used_(false),
        is_weighted_(false),
        precision_(0.0),
        response_size_(1)
    {}


    ProblemSpec & column_count(int in)
    {
        column_count_ = in;
        return *this;
    }

    /**\brief supply with class labels
     *
     * the preprocessor will not need to calculate the labels in this case.
     */
    template<class C_Iter>
    ProblemSpec & classes_(C_Iter begin, C_Iter end)
    {
        classes.clear();
        int size = end - begin;
        for(int k = 0; k < size; ++k, ++begin)
            classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
        class_count_ = size;
        return *this;
    }

    /** \brief supply with class weights
     *
     * this is the only case where you would really have to
     * create a ProblemSpec object.
     */
    template<class W_Iter>
    ProblemSpec & class_weights(W_Iter begin, W_Iter end)
    {
        class_weights_.clear();
        class_weights_.insert(class_weights_.end(), begin, end);
        is_weighted_ = true;
        return *this;
    }


    void clear()
    {
        used_ = false;
        classes.clear();
        class_weights_.clear();
        column_count_ = 0;
        class_count_ = 0;
        actual_mtry_ = 0;
        actual_msample_ = 0;
        problem_type_ = CHECKLATER;
        is_weighted_ = false;
        precision_ = 0.0;
        response_size_ = 0;
    }

    bool used() const
    {
        return used_ != 0;
    }
};
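
/* Editor's sketch of the one situation in which a ProblemSpec must be
 * built by hand, namely user-defined class weights (see class_weights()
 * above); the labels and weights are illustrative values.
 * \code
 * int    labels[]  = {0, 1};
 * double weights[] = {1.0, 10.0};   // up-weight the rarer class
 * ProblemSpec<int> spec;
 * spec.classes_(labels, labels + 2)
 *     .class_weights(weights, weights + 2);
 * \endcode
 */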


//@}



/**\brief Standard early stopping criterion
 *
 * Stop if region.size() < min_split_node_size_
 */
class EarlyStoppStd
{
  public:
    int min_split_node_size_;

    template<class Opt>
    EarlyStoppStd(Opt opt)
    :   min_split_node_size_(opt.min_split_node_size_)
    {}

    template<class T>
    void set_external_parameters(ProblemSpec<T> const &, int /* tree_count */ = 0, bool /* is_weighted_ */ = false)
    {}

    template<class Region>
    bool operator()(Region & region)
    {
        return region.size() < min_split_node_size_;
    }

    template<class WeightIter, class T, class C>
    bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */)
    {
        return false;
    }
};
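
/* Editor's sketch of how the predicate is queried; any type with a
 * size() method can serve as the Region argument of operator() above.
 * \code
 * EarlyStoppStd stop(RandomForestOptions().min_split_node_size(10));
 * ArrayVector<int> region(5);        // stands in for a node's sample set
 * bool dont_split = stop(region);    // true, because 5 < 10
 * \endcode
 */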


} // namespace vigra

#endif // VIGRA_RF_COMMON_HXX