Predict with XGBoost's Native API

While XGBoost integrates seamlessly with scikit-learn’s API, the library also provides its own native interface for training and making predictions.

Calling predict() directly on a native xgb.Booster object can offer more flexibility and efficiency than going through the scikit-learn wrapper, especially when working with large datasets or in production environments.

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import xgboost as xgb

# Generate synthetic data
X, y = make_classification(n_samples=10000, n_features=5, n_informative=3, n_redundant=1, random_state=42)
# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# Create DMatrix objects for training and testing sets
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# Set XGBoost parameters
params = {
    'objective': 'binary:logistic',
    'learning_rate': 0.1,
    'random_state': 42
}

# Train the model (xgb.train runs 10 boosting rounds unless num_boost_round says otherwise)
model = xgb.train(params, dtrain, num_boost_round=10)

# Make predictions: with 'binary:logistic', each value is the probability of the positive class
predictions = model.predict(dtest)

print("Predicted probabilities:\n", predictions[:5])  # Print the first 5 predictions

To use the native API for prediction:

  1. Create an xgb.DMatrix object from your test data. DMatrix is XGBoost’s internal data structure; it accepts dense inputs such as NumPy arrays and pandas DataFrames as well as sparse inputs such as SciPy CSR matrices.

  2. Train your model on the training data with xgb.train(), or load a previously saved model by creating an xgb.Booster() and calling its load_model() method.

  3. Call the predict() method on the trained or loaded Booster, passing in the DMatrix that contains your test data.

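The predict() method returns a NumPy array with one prediction per row of the DMatrix. With the 'binary:logistic' objective used above, each value is the predicted probability of the positive class, so you apply a threshold (commonly 0.5) to obtain class labels.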

Quick Summary

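Using XGBoost’s native API for making predictions gives you direct control over how input data is loaded into a DMatrix, plus access to every option of Booster.predict(), such as per-feature contributions, leaf indices, and predictions from a subset of boosting rounds.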

Consider using the native API when you need to make predictions on large datasets or in production scenarios where performance is critical. The scikit-learn API is still a great choice for many tasks, especially during the model development and evaluation phases.

Prediction speed matters in real-world applications where models need to generate results quickly and at scale. XGBoost’s native API lets you tune each stage of the prediction pipeline, from how input data is passed to the model to which outputs predict() returns.


