
GClasses::GEnsemble Class Reference

This is a base class for ensembles that combine the predictions from multiple weighted models.

#include <GEnsemble.h>

Inheritance for GClasses::GEnsemble:
Base classes: GClasses::GSupervisedLearner, GClasses::GTransducer
Derived classes: GClasses::GAdaBoost, GClasses::GBag, GClasses::GBayesianModelAveraging, GClasses::GBayesianModelCombination


Public Member Functions

 GEnsemble (GRand &rand)
 General-purpose constructor.
 GEnsemble (GDomNode *pNode, GLearnerLoader &ll)
 Deserializing constructor.
virtual ~GEnsemble ()

Protected Member Functions

virtual void serializeBase (GDom *pDoc, GDomNode *pNode)
 Derived classes should call this method to serialize the base-class state as part of their implementation of the serialize method.
virtual void clearBase ()
 Calls clear on all of the models, and resets the accumulator buffer.
virtual void trainInner (GMatrix &features, GMatrix &labels)
 Sets up the accumulator buffer (ballot box) then calls trainInnerInner.
virtual void trainInnerInner (GMatrix &features, GMatrix &labels)=0
 Implement this method to train the ensemble.
virtual void predictInner (const double *pIn, double *pOut)
 See the comment for GSupervisedLearner::predictInner.
virtual void predictDistributionInner (const double *pIn, GPrediction *pOut)
 See the comment for GSupervisedLearner::predictDistributionInner.
void normalizeWeights ()
 Scales the weights of all the models so they sum to 1.0.
void castVote (double weight, const double *pOut)
 Adds the vote from one of the models.
void tally (GPrediction *pOut)
 Counts all the votes from the models in the ensemble, for when you need to know the full distribution.
void tally (double *pOut)
 Counts all the votes from the models in the ensemble, for when you only need to know the winner and not the distribution.

Protected Attributes

sp_relation m_pLabelRel
std::vector< GWeightedModel * > m_models
size_t m_nAccumulatorDims
double * m_pAccumulator

Detailed Description

This is a base class for ensembles that combine the predictions from multiple weighted models.


Constructor & Destructor Documentation

GClasses::GEnsemble::GEnsemble (GRand &rand)

General-purpose constructor.

GClasses::GEnsemble::GEnsemble (GDomNode *pNode, GLearnerLoader &ll)

Deserializing constructor.

virtual GClasses::GEnsemble::~GEnsemble ( ) [virtual]

Member Function Documentation

void GClasses::GEnsemble::castVote (double weight, const double *pOut) [protected]

Adds the vote from one of the models.
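The ballot-box idea behind castVote can be illustrated with a small stand-alone sketch. It assumes a single nominal label whose predicted value is a non-negative class index stored in pOut[0], and an accumulator with one slot per class. BallotBox is a hypothetical name; the real GEnsemble code also has to handle multiple label dimensions, which this sketch ignores.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the accumulator buffer (m_pAccumulator):
// one slot per class of a single nominal label.
struct BallotBox
{
	std::vector<double> votes;

	explicit BallotBox(std::size_t classCount) : votes(classCount, 0.0) {}

	// Clear every slot before tallying a new prediction.
	void reset()
	{
		for(double& v : votes)
			v = 0.0;
	}

	// Assumption: pOut[0] holds the model's predicted class index (>= 0).
	// The model's weight is added to that class's slot.
	void castVote(double weight, const double* pOut)
	{
		std::size_t cls = static_cast<std::size_t>(pOut[0]);
		assert(cls < votes.size());
		votes[cls] += weight;
	}
};
```

With weights 0.5, 0.25 and 0.25 voting for classes 1, 0 and 1, the slots end up at 0.25 for class 0 and 0.75 for class 1.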

virtual void GClasses::GEnsemble::clearBase ( ) [protected, virtual]

Calls clear on all of the models, and resets the accumulator buffer.

void GClasses::GEnsemble::normalizeWeights ( ) [protected]

Scales the weights of all the models so they sum to 1.0.
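A minimal sketch of what normalizing the weights means, using a hypothetical WeightedModel struct that carries only a weight. Each weight is divided by the total, so the relative importance of the models is preserved. How GEnsemble handles a zero total is not documented here; this sketch simply asserts that the total is positive.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for GWeightedModel: only the weight matters here.
struct WeightedModel
{
	double weight;
};

// Scale every model's weight so that the weights sum to 1.0.
// Assumption: weights are non-negative and not all zero.
void normalizeWeights(std::vector<WeightedModel>& models)
{
	double sum = 0.0;
	for(const WeightedModel& m : models)
		sum += m.weight;
	assert(sum > 0.0);
	for(WeightedModel& m : models)
		m.weight /= sum;
}
```

For example, weights 1 and 3 become 0.25 and 0.75.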

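virtual void GClasses::GEnsemble::predictDistributionInner (const double *pIn, GPrediction *pOut) [protected, virtual]

See the comment for GSupervisedLearner::predictDistributionInner.

virtual void GClasses::GEnsemble::predictInner (const double *pIn, double *pOut) [protected, virtual]

See the comment for GSupervisedLearner::predictInner.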
virtual void GClasses::GEnsemble::serializeBase (GDom *pDoc, GDomNode *pNode) [protected, virtual]

Derived classes should call this method to serialize the base-class state as part of their implementation of the serialize method.
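The calling pattern for serializeBase, where a derived class's serialize writes its own fields and delegates the shared ensemble state to the base class, can be sketched without the GDom classes. Here a std::map stands in for a GDomNode, and EnsembleBase, MyBag and the key names are hypothetical; the real serializeBase takes a GDom* and a GDomNode* and is virtual.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a GDomNode: a flat key/value map.
using Node = std::map<std::string, std::string>;

class EnsembleBase
{
public:
	virtual ~EnsembleBase() = default;
	virtual void serialize(Node& node) const = 0;

protected:
	std::vector<double> m_weights; // stands in for the weights held in m_models

	// Writes the state that every ensemble shares.
	void serializeBase(Node& node) const
	{
		node["modelCount"] = std::to_string(m_weights.size());
		for(std::size_t i = 0; i < m_weights.size(); i++)
			node["weight" + std::to_string(i)] = std::to_string(m_weights[i]);
	}
};

// Hypothetical derived ensemble with one extra field of its own.
class MyBag : public EnsembleBase
{
public:
	MyBag(std::vector<double> weights, int samples) : m_samples(samples)
	{
		assert(samples > 0);
		m_weights = std::move(weights);
	}

	void serialize(Node& node) const override
	{
		node["class"] = "MyBag";
		serializeBase(node); // let the base class write the shared state
		node["samples"] = std::to_string(m_samples);
	}

private:
	int m_samples;
};
```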

void GClasses::GEnsemble::tally (double *pOut) [protected]

Counts all the votes from the models in the ensemble, for when you only need to know the winner and not the distribution.

void GClasses::GEnsemble::tally (GPrediction *pOut) [protected]

Counts all the votes from the models in the ensemble, for when you need to know the full distribution.
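Assuming the accumulator holds one summed weight per class of a single nominal label, the two tally overloads differ only in what they report: the winning class, or the normalized distribution. The free functions below are hypothetical stand-ins, with a std::vector in place of GPrediction. Ties in tallyWinner go to the lowest class index, a choice made for this sketch rather than documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Winner only: write the index of the class with the most weight into pOut[0].
// Ties are resolved in favor of the lowest class index.
void tallyWinner(const std::vector<double>& votes, double* pOut)
{
	assert(!votes.empty());
	std::size_t best = 0;
	for(std::size_t c = 1; c < votes.size(); c++)
	{
		if(votes[c] > votes[best])
			best = c;
	}
	pOut[0] = static_cast<double>(best);
}

// Full distribution: normalize the vote totals so they sum to 1.0.
// A std::vector stands in for GPrediction here.
std::vector<double> tallyDistribution(const std::vector<double>& votes)
{
	double total = 0.0;
	for(double v : votes)
		total += v;
	assert(total > 0.0);
	std::vector<double> dist(votes.size());
	for(std::size_t c = 0; c < votes.size(); c++)
		dist[c] = votes[c] / total;
	return dist;
}
```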

virtual void GClasses::GEnsemble::trainInner (GMatrix &features, GMatrix &labels) [protected, virtual]

Sets up the accumulator buffer (ballot box) then calls trainInnerInner.

Implements GClasses::GSupervisedLearner.

virtual void GClasses::GEnsemble::trainInnerInner (GMatrix &features, GMatrix &labels) [protected, pure virtual]

Implement this method to train the ensemble.

Implemented in GClasses::GBag and GClasses::GAdaBoost.
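trainInner and trainInnerInner follow the template-method pattern: the base class prepares the ballot box and then hands control to the pure virtual hook that each concrete ensemble overrides. The sketch below assumes one nominal label column holding class indices 0..k-1 and uses a nested std::vector in place of GMatrix; EnsembleSketch and CountingEnsemble are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for GMatrix: rows of doubles.
using Matrix = std::vector<std::vector<double>>;

class EnsembleSketch
{
public:
	virtual ~EnsembleSketch() = default;

	// Public entry point (stands in for the learner's train method).
	void train(Matrix& features, Matrix& labels)
	{
		trainInner(features, labels);
	}

	std::size_t accumulatorSize() const { return m_accumulator.size(); }

protected:
	std::vector<double> m_accumulator; // the "ballot box"

	// Size the ballot box, then defer to the subclass.
	// Assumption: labels has one column of class indices 0..k-1,
	// so the ballot box needs k slots.
	virtual void trainInner(Matrix& features, Matrix& labels)
	{
		assert(features.size() == labels.size());
		std::size_t classes = 0;
		for(const auto& row : labels)
			classes = std::max(classes, static_cast<std::size_t>(row[0]) + 1);
		m_accumulator.assign(classes, 0.0);
		trainInnerInner(features, labels);
	}

	// Each concrete ensemble decides how to build and weight its members.
	virtual void trainInnerInner(Matrix& features, Matrix& labels) = 0;
};

// Hypothetical subclass that only records how often it was trained.
class CountingEnsemble : public EnsembleSketch
{
public:
	int trainCalls = 0;

protected:
	void trainInnerInner(Matrix&, Matrix&) override
	{
		trainCalls++;
	}
};
```

Training CountingEnsemble on labels 0, 2 and 1 sizes the ballot box to three slots and calls trainInnerInner exactly once. Keeping the buffer setup in the base class means a subclass only has to supply the part of training that differs.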


Member Data Documentation

std::vector<GWeightedModel*> GClasses::GEnsemble::m_models [protected]