DecisionTreeClassifier
A CART (Classification and Regression Trees) decision tree classifier. The algorithm recursively partitions the feature space with binary splits, choosing at each node the feature and threshold that best separate the classes.
Splitting Criterion — Gini Impurity
Gini impurity measures the probability that a randomly chosen sample from node $t$ would be misclassified if labeled according to the class distribution at that node:

$$G(t) = \sum_{k} p_k (1 - p_k) = 1 - \sum_{k} p_k^2$$

where $p_k$ is the proportion of class-$k$ samples at node $t$. A pure node ($G(t) = 0$) contains only one class.
At each split, the algorithm searches over all features and thresholds for the partition that minimizes the weighted Gini impurity of the two child nodes:

$$G_{\text{split}} = \frac{n_L}{n} G(t_L) + \frac{n_R}{n} G(t_R)$$

where $n_L$ and $n_R$ are the numbers of samples sent to the left and right children. The information gain of a split is $\Delta G = G(t) - G_{\text{split}}$.
CART Algorithm
The tree is built recursively using the CART algorithm:
- At each node, evaluate every candidate split (feature $j$, threshold $\theta$) and select the one that minimizes $G_{\text{split}}$.
- Create a left child ($x_j \le \theta$) and a right child ($x_j > \theta$).
- Recurse until a stopping condition is met: `max_depth` reached, fewer than `min_samples_split` samples, or the node is pure.
Prediction assigns the majority class of the leaf node reached by the query sample.
When to Use
- Interpretability: Decision trees are easy to visualize and explain.
- No scaling needed: Trees split on thresholds, so they are invariant to monotonic (order-preserving) transformations of individual features; standardization and normalization are unnecessary.
- High variance: Single trees are prone to overfitting — consider ensemble methods (Random Forest, Gradient Boosting) for better generalization.
Mirrors `sklearn.tree.DecisionTreeClassifier`.
Constructor
```cpp
Skigen::DecisionTreeClassifier<Scalar> tree(int max_depth = -1,
                                            int min_samples_split = 2);
```
| Parameter | Default | Description |
|---|---|---|
| `max_depth` | -1 | Maximum tree depth (-1 = unlimited) |
| `min_samples_split` | 2 | Minimum samples required to split a node |
Methods
| Method | Description |
|---|---|
| `fit(X, y)` | Build the decision tree |
| `predict(X)` | Predict class labels |
| `score(X, y)` | Return classification accuracy |
Example
```cpp
#include <Skigen/Tree>
#include <iostream>

// Limit depth to 5 to curb overfitting; -1 would grow the tree fully.
Skigen::DecisionTreeClassifier tree(/*max_depth=*/5);
tree.fit(X_train, y_train);
std::cout << "Accuracy: " << tree.score(X_test, y_test) << "\n";
```