RandomForestClassifier
An ensemble of decision trees. Each tree is grown on a bootstrap sample of the data, and a random subset of features is considered at every split. predict() returns the majority (hard) vote across trees. predict_proba() averages the trees' votes into per-class probabilities (soft voting).
Algorithm
Each tree is fit independently on its own bootstrap resample, and at every split only a random subset of max_features candidate features is considered. Decorrelating the trees this way reduces the variance of the averaged ensemble without materially increasing bias. When oob_score = true (which requires bootstrap = true), each tree's out-of-bag samples are the rows not drawn into its resample. They provide a nearly unbiased estimate of generalisation accuracy without a separate validation set.
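The per-tree randomisation described above can be sketched in plain C++. This is a minimal illustration, not Skigen's implementation. The helper names (bootstrap_sample, sqrt_max_features, candidate_features), the use of std::mt19937, and the reading of Sqrt as floor(sqrt(n_features)) with a minimum of 1 are all assumptions made for the example.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// One tree's bootstrap draw: n row indices sampled uniformly with
// replacement, plus a mask of the rows that were never drawn (out-of-bag).
struct BootstrapDraw {
    std::vector<std::size_t> in_bag;
    std::vector<bool> out_of_bag;
};

BootstrapDraw bootstrap_sample(std::size_t n, std::mt19937& rng) {
    assert(n > 0);
    std::uniform_int_distribution<std::size_t> pick(0, n - 1);
    BootstrapDraw draw{std::vector<std::size_t>(n), std::vector<bool>(n, true)};
    for (std::size_t k = 0; k < n; ++k) {
        draw.in_bag[k] = pick(rng);
        draw.out_of_bag[draw.in_bag[k]] = false;  // drawn at least once
    }
    return draw;
}

// Hypothetical resolution of max_features = Sqrt: floor(sqrt(n_features)), at least 1.
std::size_t sqrt_max_features(std::size_t n_features) {
    const auto k = static_cast<std::size_t>(std::sqrt(static_cast<double>(n_features)));
    return std::max<std::size_t>(k, 1);
}

// Draw max_features distinct candidate features for a single split
// (sampling without replacement via a partial Fisher-Yates shuffle).
std::vector<std::size_t> candidate_features(std::size_t n_features,
                                            std::size_t max_features,
                                            std::mt19937& rng) {
    assert(n_features > 0);
    std::vector<std::size_t> idx(n_features);
    std::iota(idx.begin(), idx.end(), std::size_t{0});
    max_features = std::min(max_features, n_features);
    for (std::size_t i = 0; i < max_features; ++i) {
        std::uniform_int_distribution<std::size_t> pick(i, n_features - 1);
        std::swap(idx[i], idx[pick(rng)]);
    }
    idx.resize(max_features);
    return idx;
}
```

Rows with out_of_bag[i] == true are the ones used for the OOB estimate. Each row's OOB prediction aggregates only the trees that did not train on it.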
Constructor
```cpp
Skigen::RandomForestClassifier<Scalar>(int n_estimators = 100,
                                       CriterionClf criterion = Gini,
                                       std::optional<int> max_depth = std::nullopt,
                                       ...);
```
Parameters
| Parameter | Default | Description |
|---|---|---|
| n_estimators | 100 | Number of trees in the forest. |
| criterion | Gini | Split quality measure: Gini or Entropy. |
| max_depth | nullopt | Maximum tree depth; unbounded if unset. |
| max_features | Sqrt | Number of candidate features considered at each split (Sqrt: about the square root of the feature count). |
| bootstrap | true | Draw each tree's training set by sampling with replacement. |
| oob_score | false | Compute an out-of-bag accuracy estimate; requires bootstrap = true. |
| n_jobs | 1 | Number of parallel jobs used to fit trees (via std::async). |
| random_state | nullopt | Seed for reproducible forests. |
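The criterion parameter chooses how split quality is measured. Below is a minimal sketch of both measures computed from a node's per-class sample counts, together with the weighted impurity decrease a CART-style tree maximises when choosing a split. The function names are illustrative and not part of the Skigen API. Entropy uses base 2 here as an assumption; the base only rescales the values and does not change which split wins.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Gini impurity of a node from its per-class sample counts: 1 - sum_k p_k^2.
double gini(const std::vector<std::size_t>& counts) {
    std::size_t total = 0;
    for (std::size_t c : counts) total += c;
    assert(total > 0);
    double sum_sq = 0.0;
    for (std::size_t c : counts) {
        const double p = static_cast<double>(c) / static_cast<double>(total);
        sum_sq += p * p;
    }
    return 1.0 - sum_sq;
}

// Shannon entropy in bits: -sum_k p_k log2(p_k); empty classes contribute 0.
double entropy(const std::vector<std::size_t>& counts) {
    std::size_t total = 0;
    for (std::size_t c : counts) total += c;
    assert(total > 0);
    double h = 0.0;
    for (std::size_t c : counts) {
        if (c == 0) continue;
        const double p = static_cast<double>(c) / static_cast<double>(total);
        h -= p * std::log2(p);
    }
    return h;
}

// Weighted impurity decrease of a binary split:
// impurity(parent) - (n_L / n) * impurity(left) - (n_R / n) * impurity(right).
// Both children must be non-empty.
template <typename Impurity>
double impurity_decrease(const std::vector<std::size_t>& left,
                         const std::vector<std::size_t>& right,
                         Impurity impurity) {
    assert(left.size() == right.size());
    std::vector<std::size_t> parent(left.size());
    std::size_t n_left = 0, n_right = 0;
    for (std::size_t k = 0; k < left.size(); ++k) {
        parent[k] = left[k] + right[k];
        n_left += left[k];
        n_right += right[k];
    }
    const double n = static_cast<double>(n_left + n_right);
    return impurity(parent) - (n_left / n) * impurity(left) - (n_right / n) * impurity(right);
}
```

A perfectly separating split of class counts {5, 5} into {5, 0} and {0, 5} gives a Gini decrease of 0.5 and an entropy decrease of 1 bit.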
Methods
| Method | Description |
|---|---|
| fit(X, y) | Grow the forest on labelled training data. |
| predict(X) | Class labels by majority vote across the trees. |
| predict_proba(X) | Class probabilities: the mean of the trees' votes. |
| score(X, y) | Mean accuracy of predict(X) against y. |
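To make the two aggregation modes concrete, here is a minimal sketch of how the trees' votes for a single sample could be combined. It assumes each tree casts one hard vote and that labels are encoded as 0..n_classes-1. The helpers vote_proba, majority_vote and accuracy are hypothetical and not Skigen API, and breaking ties toward the lowest label is also an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Soft aggregation (predict_proba): fraction of trees voting for each class.
std::vector<double> vote_proba(const std::vector<int>& votes, int n_classes) {
    assert(!votes.empty());
    std::vector<double> proba(n_classes, 0.0);
    for (int v : votes) {
        assert(v >= 0 && v < n_classes);
        proba[v] += 1.0;
    }
    for (double& p : proba) p /= static_cast<double>(votes.size());
    return proba;
}

// Hard aggregation (predict): the most frequent label.
// std::max_element returns the first maximum, so ties go to the lowest label.
int majority_vote(const std::vector<int>& votes, int n_classes) {
    const std::vector<double> proba = vote_proba(votes, n_classes);
    return static_cast<int>(std::max_element(proba.begin(), proba.end()) - proba.begin());
}

// score(): fraction of predictions that match the true labels.
double accuracy(const std::vector<int>& y_pred, const std::vector<int>& y_true) {
    assert(y_pred.size() == y_true.size() && !y_true.empty());
    std::size_t correct = 0;
    for (std::size_t i = 0; i < y_true.size(); ++i) {
        if (y_pred[i] == y_true[i]) ++correct;
    }
    return static_cast<double>(correct) / static_cast<double>(y_true.size());
}
```

For five trees voting {0, 2, 2, 1, 2}, vote_proba gives {0.2, 0.2, 0.6} and majority_vote picks class 2.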
Fitted Attributes
| Accessor | Description |
|---|---|
| estimators() | The fitted trees. |
| feature_importances() | Mean impurity decrease per feature, averaged over the trees. |
| oob_score() | Out-of-bag accuracy; available only when constructed with oob_score = true. |
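One plausible way to assemble a feature_importances() vector is sketched below. Each tree's accumulated impurity decreases are normalised to sum to 1 and then averaged across the trees that made at least one split. The per-tree normalisation and the skipping of single-leaf trees are assumptions, similar in spirit to scikit-learn's forests; Skigen's exact weighting may differ, and mean_feature_importances is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// per_tree[t][j] holds the total weighted impurity decrease contributed by
// splits on feature j in tree t (accumulated while the tree is grown).
std::vector<double> mean_feature_importances(
    const std::vector<std::vector<double>>& per_tree) {
    assert(!per_tree.empty());
    const std::size_t n_features = per_tree.front().size();
    std::vector<double> mean(n_features, 0.0);
    std::size_t used = 0;  // trees with at least one split
    for (const std::vector<double>& tree : per_tree) {
        assert(tree.size() == n_features);
        const double total = std::accumulate(tree.begin(), tree.end(), 0.0);
        if (total <= 0.0) continue;  // a single-leaf tree carries no information
        for (std::size_t j = 0; j < n_features; ++j) mean[j] += tree[j] / total;
        ++used;
    }
    if (used > 0) {
        for (double& m : mean) m /= static_cast<double>(used);
    }
    return mean;
}
```

With two trees whose raw decreases are {3, 1} and {0, 2}, the normalised vectors are {0.75, 0.25} and {0, 1}, so the averaged importances are {0.375, 0.625}.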
Example
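```cpp
// X, y (training data) and X_test (held-out features) are prepared elsewhere.
Skigen::RandomForestClassifier<double> rf(100);  // 100 trees
rf.fit(X, y);                                    // grow the forest
auto preds = rf.predict(X_test);                 // majority-vote labels
```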
This estimator is covered by the parity suite: the reference fixtures in tests/parity/data/random_forest_classifier/ are produced by the generator tests/parity/generate_ensemble_reference.py and checked by tests/parity/parity_ensemble.cpp.
For full signatures see the RandomForestClassifier API Reference.