
XGBoost get_booster()

XGBoost is a powerful library that can be used from Python either through its native API or through its Scikit-Learn-compatible API.

While the Scikit-Learn interface provides a convenient and familiar way to work with XGBoost models, there may be situations where you need access to the underlying Booster object to utilize some of the advanced native features of XGBoost.

Fortunately, the Scikit-Learn API provides a get_booster() method that allows you to retrieve the Booster object from a fitted XGBClassifier or XGBRegressor model.

This enables you to work with the native XGBoost API when needed, while still benefiting from the simplicity and consistency of the Scikit-Learn interface.

Here’s an example of how to access the Booster object from a fitted XGBClassifier model:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

# Generate a small synthetic classification dataset
X, y = make_classification(n_samples=100, n_features=5, random_state=42)

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize and fit an XGBClassifier using the Scikit-Learn API
model = XGBClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Access the underlying Booster object
booster = model.get_booster()

# Print the type of the returned object to confirm it's a Booster
print(type(booster))  # Output: <class 'xgboost.core.Booster'>

In this example, we first generate a small synthetic classification dataset using make_classification() from Scikit-Learn. We then split the data into training and test sets using train_test_split().

Next, we initialize an XGBClassifier with a specified number of estimators and a random state for reproducibility. We fit the model using the training data through the familiar Scikit-Learn fit() method.

After the model is trained, we call the get_booster() method on the fitted model to retrieve the underlying Booster object. We print the type of the returned object to confirm that it is indeed an instance of xgboost.core.Booster.

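Once you have the Booster object, you can use native XGBoost methods to inspect and work with the model in more detail, for example get_score() for feature importance under different importance types, get_dump() or trees_to_dataframe() to examine individual trees, and save_config() to view the full training configuration.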

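Keep in mind that the native API has its own conventions. Booster.predict() expects a DMatrix rather than a NumPy array (Booster.inplace_predict() accepts arrays directly). Also, get_booster() returns the same Booster object the wrapper uses internally, not a copy, so any changes you make to it (for example with set_param()) also affect the fitted Scikit-Learn model’s behavior and performance.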

By leveraging the get_booster() method, you can enjoy the best of both worlds: the simplicity of the Scikit-Learn API and the power of the native XGBoost API when needed.


