
DecisionTreeClassifier

#include <Skigen/Tree>

template <typename Scalar = double>
class Skigen::DecisionTreeClassifier(max_depth=-1, min_samples_split=2, max_features_mode=0, max_features_value=0.0, random_state=std::nullopt)

A decision tree classifier.

A non-parametric supervised learning method used for classification. The model predicts the value of a target variable by learning simple decision rules inferred from the data features. Uses Gini impurity as the splitting criterion.
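
For reference, the Gini impurity of a node with class proportions p_k is 1 - sum_k p_k^2. The sketch below computes it from per-class sample counts. It is an illustration only: the function name `gini_impurity` and the empty-node convention are assumptions made here, not part of Skigen's documented API, and the library's internal implementation may differ.

```cpp
#include <cstddef>
#include <vector>

// Gini impurity of a node given per-class sample counts:
//   G = 1 - sum_k p_k^2, where p_k = counts[k] / total.
// Hypothetical helper for illustration; not a Skigen function.
// Returns 0.0 for an empty node (a convention chosen for this sketch).
double gini_impurity(const std::vector<std::size_t>& class_counts) {
    std::size_t total = 0;
    for (std::size_t c : class_counts) total += c;
    if (total == 0) return 0.0;

    double sum_sq = 0.0;
    for (std::size_t c : class_counts) {
        const double p = static_cast<double>(c) / static_cast<double>(total);
        sum_sq += p * p;
    }
    return 1.0 - sum_sq;
}
```

A pure node (all samples in one class) gives 0, and a balanced two-class node gives 0.5, the maximum for two classes. A split is preferred when it lowers the weighted impurity of the child nodes.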

Mirrors sklearn.tree.DecisionTreeClassifier.



Attributes:

  • classes : const Eigen::VectorXi

  • n_classes : int

  • feature_importances : RowVectorType


Methods

fit(X, y)

Build a decision tree classifier from the training set.


fit(X, y)

Fit natively on a sparse design matrix without densifying.

Split finding runs through a CSC column accessor that materialises one feature column at a time (implicit zeros filled as 0), so the full dense n × p matrix is never built. Results match a dense fit exactly (scikit-learn's sparse splitter treats implicit zeros as value 0 in the sorted order).
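
The idea of materialising one column at a time can be sketched as follows. The `CscMatrix` struct and `materialise_column` function are hypothetical stand-ins written for this illustration; they are not Skigen's or Eigen's sparse types, and the real accessor may be organised differently.

```cpp
#include <cstddef>
#include <vector>

// Minimal CSC (compressed sparse column) container.
// Hypothetical; independent of Skigen's and Eigen's sparse types.
struct CscMatrix {
    std::size_t n_rows;
    std::size_t n_cols;
    std::vector<std::size_t> col_ptr;  // size n_cols + 1; column j spans [col_ptr[j], col_ptr[j+1])
    std::vector<std::size_t> row_idx;  // row index of each stored value
    std::vector<double> values;        // stored (explicit) values
};

// Write column `col` into `out` as a dense vector of length n_rows.
// Entries that are not stored remain 0.0 (the implicit zeros).
void materialise_column(const CscMatrix& m, std::size_t col,
                        std::vector<double>& out) {
    out.assign(m.n_rows, 0.0);
    for (std::size_t k = m.col_ptr[col]; k < m.col_ptr[col + 1]; ++k) {
        out[m.row_idx[k]] = m.values[k];
    }
}
```

Reusing a single `out` buffer across columns keeps the extra memory at one vector of length n rather than a full n × p dense copy.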


predict(X)

Predict using a sparse design matrix (densifies internally).


fit_with_indices(X, y, sample_indices)

Fit using a specific row index subset (used by RandomForest bootstraps).

Parameters:

  • sample_indices : const std::vector< Eigen::Index > & If empty, uses all rows of X. Otherwise, builds the tree from the rows specified (with possible repetitions).
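
As an illustration of an index subset with repetitions, the sketch below draws a bootstrap sample of row indices. The helper `bootstrap_indices` is hypothetical and not part of Skigen; it uses `std::ptrdiff_t`, which is the default type behind `Eigen::Index`, so that the snippet compiles without Eigen. How RandomForest actually generates its indices is not shown in this reference.

```cpp
#include <cstddef>
#include <random>
#include <vector>

// Draw n_rows row indices uniformly with replacement from [0, n_rows).
// Hypothetical helper for illustration; not a Skigen function.
std::vector<std::ptrdiff_t> bootstrap_indices(std::ptrdiff_t n_rows,
                                              unsigned seed) {
    std::vector<std::ptrdiff_t> idx;
    if (n_rows <= 0) return idx;  // nothing to sample from
    idx.reserve(static_cast<std::size_t>(n_rows));

    std::mt19937 rng(seed);
    std::uniform_int_distribution<std::ptrdiff_t> pick(0, n_rows - 1);
    for (std::ptrdiff_t i = 0; i < n_rows; ++i) {
        idx.push_back(pick(rng));  // the same row may be picked more than once
    }
    return idx;
}
```

A vector produced this way could be passed as `sample_indices`, assuming the signature listed above. Some rows will typically appear several times and others not at all.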

predict(X)

Predict class labels for samples in X.


predict_proba(X)

Per-class probability estimates from leaf class distributions.
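
One common way to turn a leaf's class distribution into probabilities is to divide each class count by the leaf's total sample count. The sketch below shows that normalisation. The function `leaf_probabilities` is a hypothetical helper, and whether Skigen stores raw counts or pre-normalised fractions in its leaves is an assumption not confirmed by this reference.

```cpp
#include <cstddef>
#include <vector>

// Convert per-class sample counts in a leaf into probability estimates.
// Hypothetical helper for illustration; not a Skigen function.
// An empty leaf yields all zeros (a convention chosen for this sketch).
std::vector<double> leaf_probabilities(
        const std::vector<std::size_t>& class_counts) {
    std::size_t total = 0;
    for (std::size_t c : class_counts) total += c;

    std::vector<double> proba(class_counts.size(), 0.0);
    if (total == 0) return proba;

    for (std::size_t k = 0; k < class_counts.size(); ++k) {
        proba[k] = static_cast<double>(class_counts[k]) /
                   static_cast<double>(total);
    }
    return proba;
}
```

Under this scheme, a leaf holding three samples of one class and one of another would report 0.75 and 0.25, and the predicted label would be the class with the largest entry.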


fit_columns(cols, y, sample_indices, n_rows, n_cols)


build_tree(X, y, indices, depth)


fit(X, y)


predict(X)


Example

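// Fit a tree capped at depth 5 (the first constructor argument is max_depth).
// `split` is assumed to hold a train/test partition prepared earlier.
Skigen::DecisionTreeClassifier<double> best(5);
best.fit(split.X_train, split.y_train);
auto best_pred = best.predict(split.X_test);

// Compare the predicted labels with the true test labels.
std::cout << "\n=== Confusion Matrix (depth=5) ===\n";
auto cm = Skigen::Metrics::confusion_matrix(split.y_test, best_pred);
std::cout << cm << "\n";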