
KNeighborsRegressor

#include <Skigen/Neighbors>

template <typename Scalar = double>
class Skigen::KNeighborsRegressor(n_neighbors=5)

Regression based on k-nearest neighbors.

The target is predicted by local interpolation of the targets associated with the nearest neighbors in the training set: the prediction for a sample is the mean of the target values of its k nearest neighbors.
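To make the prediction rule concrete, here is a minimal brute-force sketch in standard C++. It is not Skigen's implementation. The helper name `knn_predict_one` is hypothetical, and the sketch assumes Euclidean distance and an unweighted mean (scikit-learn's defaults). This page does not state either choice for Skigen.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper, not part of Skigen: brute-force k-NN regression for one query.
// Assumes Euclidean distance and uniform (unweighted) averaging of neighbor targets.
// Every row of X_train must have the same length as `query`.
double knn_predict_one(const std::vector<std::vector<double>>& X_train,
                       const std::vector<double>& y_train,
                       const std::vector<double>& query,
                       std::size_t k) {
    assert(X_train.size() == y_train.size());
    assert(k >= 1 && k <= X_train.size());

    // (squared distance, training index) for every training sample.
    std::vector<std::pair<double, std::size_t>> dist;
    dist.reserve(X_train.size());
    for (std::size_t i = 0; i < X_train.size(); ++i) {
        double d2 = 0.0;
        for (std::size_t j = 0; j < query.size(); ++j) {
            const double diff = X_train[i][j] - query[j];
            d2 += diff * diff;
        }
        dist.emplace_back(d2, i);
    }

    // Move the k smallest distances to the front (equal distances ordered by index).
    std::partial_sort(dist.begin(),
                      dist.begin() + static_cast<std::ptrdiff_t>(k),
                      dist.end());

    // Prediction = mean of the k nearest neighbors' targets.
    double sum = 0.0;
    for (std::size_t n = 0; n < k; ++n) {
        sum += y_train[dist[n].second];
    }
    return sum / static_cast<double>(k);
}
```

The sketch compares squared distances instead of distances. This keeps the neighbor ordering and skips the square root. `std::partial_sort` orders only the first k entries, so each query costs one O(n·d) distance pass plus roughly O(n log k) for the selection.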

Mirrors sklearn.neighbors.KNeighborsRegressor.


Parameters:

  • n_neighbors : int, default=5 Number of neighbors averaged when predicting a target value.

Attributes:

  • is_fitted : bool Whether the estimator has been fitted.

Methods

fit(X, y)

Fit the k-nearest neighbors regressor.

Parameters:

  • X : MatrixType Training data of shape (n_samples, n_features).

  • y : VectorType Target values of shape (n_samples,).

Returns:

  • result : KNeighborsRegressor Reference to the fitted estimator (*this).

Throws:

  • std::invalid_argument — if X and y have inconsistent lengths.
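`fit` documents a length check. It can be pictured as the following guard. This is an illustrative sketch only: `check_consistent_length` is a name chosen for illustration, and the exact message Skigen produces is not documented here.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical helper illustrating the documented fit() precondition:
// X must have exactly one row per entry of y.
void check_consistent_length(std::size_t n_rows_X, std::size_t n_y) {
    if (n_rows_X != n_y) {
        throw std::invalid_argument(
            "X has " + std::to_string(n_rows_X) + " rows but y has " +
            std::to_string(n_y) + " entries");
    }
}
```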

predict(X)

Predict target values for the provided data.

Returns the mean of the target values of the k nearest neighbors.

Parameters:

  • X : MatrixType Test samples of shape (n_samples, n_features).

Returns:

  • result : VectorType Predicted values of shape (n_samples,).

Throws:

  • std::runtime_error — if the model has not been fitted.
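The fitted-state contract works like this: `fit` sets `is_fitted` and returns `*this`, and `predict` throws `std::runtime_error` if called before `fit`. It can be sketched with a toy estimator. `ToyMeanRegressor` is hypothetical and deliberately trivial. It predicts the training mean for every sample, so only its state handling mirrors the documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical toy estimator (not Skigen's class) showing the fitted-state contract.
class ToyMeanRegressor {
public:
    bool is_fitted = false;

    // Stores the mean of y, marks the estimator as fitted, and returns *this.
    ToyMeanRegressor& fit(const std::vector<double>& y) {
        double sum = 0.0;
        for (double v : y) sum += v;
        mean_ = y.empty() ? 0.0 : sum / static_cast<double>(y.size());
        is_fitted = true;
        return *this;
    }

    // Refuses to run before fit(); otherwise returns the stored mean n_samples times.
    std::vector<double> predict(std::size_t n_samples) const {
        if (!is_fitted) {
            throw std::runtime_error("predict() called before fit()");
        }
        return std::vector<double>(n_samples, mean_);
    }

private:
    double mean_ = 0.0;
};
```

Because `fit` returns a reference to the estimator, calls can be chained, as in `m.fit(y).predict(3)`.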

score(X, y)

Return the R² coefficient of determination.
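Assuming Skigen follows the scikit-learn definition it mirrors, the score compares two quantities. The first is the residual sum of squares of the predictions $\hat{y}_i$ returned by `predict(X)`. The second is the total sum of squares of `y` around its mean $\bar{y}$:

```latex
R^2 = 1 - \frac{\sum_{i=1}^{n} \left(y_i - \hat{y}_i\right)^2}{\sum_{i=1}^{n} \left(y_i - \bar{y}\right)^2},
\qquad
\bar{y} = \frac{1}{n} \sum_{i=1}^{n} y_i
```

A perfect fit gives 1. A model that always predicts $\bar{y}$ gives 0. The score is negative for predictions worse than that constant.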

Parameters:

  • X : MatrixType Test samples of shape (n_samples, n_features).

  • y : VectorType True values of shape (n_samples,).

Returns:

  • result : Scalar R² score.
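For reference, here is a standalone computation of the same quantity. `r2_score` is a hypothetical free function and not part of this page's API. This page does not document how Skigen handles a constant `y` (zero total variance). This sketch simply throws `std::domain_error` in that case, whereas scikit-learn's `r2_score` returns a finite fallback value.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper: coefficient of determination, assuming the
// scikit-learn definition R^2 = 1 - SS_res / SS_tot.
double r2_score(const std::vector<double>& y_true, const std::vector<double>& y_pred) {
    if (y_true.empty() || y_true.size() != y_pred.size()) {
        throw std::invalid_argument("y_true and y_pred must be non-empty and of equal length");
    }

    double mean = 0.0;
    for (double v : y_true) mean += v;
    mean /= static_cast<double>(y_true.size());

    double ss_res = 0.0;  // residual sum of squares
    double ss_tot = 0.0;  // total sum of squares around the mean of y_true
    for (std::size_t i = 0; i < y_true.size(); ++i) {
        const double r = y_true[i] - y_pred[i];
        const double d = y_true[i] - mean;
        ss_res += r * r;
        ss_tot += d * d;
    }

    if (ss_tot == 0.0) {
        throw std::domain_error("R^2 is undefined for a constant y_true");
    }
    return 1.0 - ss_res / ss_tot;
}
```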

Example

#include <Eigen/Dense>
#include <Skigen/Neighbors>
#include <iostream>

int main() {
    // Synthetic stand-in for a train/test split (assumption: 2 features,
    // 80 training rows, 20 test rows); any Eigen::MatrixXd with at least
    // two columns works here.
    Eigen::MatrixXd X_train = Eigen::MatrixXd::Random(80, 2);
    Eigen::MatrixXd X_test = Eigen::MatrixXd::Random(20, 2);

    // KNN for regression: the target is a linear function of the first two features.
    Eigen::VectorXd y_reg = X_train.col(0) + 0.5 * X_train.col(1);

    Skigen::KNeighborsRegressor<double> knn_reg(5);
    knn_reg.fit(X_train, y_reg);

    Eigen::VectorXd y_reg_test = X_test.col(0) + 0.5 * X_test.col(1);

    std::cout << "\n=== KNeighborsRegressor (k=5) ===\n";
    std::cout << "R²: " << knn_reg.score(X_test, y_reg_test) << "\n";
    return 0;
}