The `booster` parameter in XGBoost determines the type of base learner used in the model. The three options are `gbtree` (gradient boosted trees), `gblinear` (gradient boosted linear models), and `dart` (dropout-enabled trees).
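Selecting the right booster can significantly affect the performance of your XGBoost model. `gblinear` fits a linear model, so it cannot capture nonlinear effects or feature interactions the way trees can. `dart` randomly drops trees during boosting to counter overfitting, at the cost of slower training than `gbtree`.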
This example demonstrates how to compare different boosters using cross-validation to find the best one for your dataset.
```python
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Create a synthetic dataset
X, y = make_classification(n_samples=1000, n_classes=2, n_features=20, n_informative=10, random_state=42)

# Define booster types to compare
booster_types = ['gbtree', 'gblinear', 'dart']

# Set up cross-validation
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Function to train and evaluate XGBoost models with different boosters
def evaluate_booster(booster):
    model = xgb.XGBClassifier(n_estimators=100, learning_rate=0.1, booster=booster, random_state=42)
    scores = cross_val_score(model, X, y, cv=cv, scoring='accuracy', n_jobs=-1)
    return scores.mean()

# Compare boosters
results = {}
for booster in booster_types:
    avg_score = evaluate_booster(booster)
    results[booster] = avg_score
    print(f"Booster: {booster}, Average CV Accuracy: {avg_score:.4f}")

# Select the best booster
best_booster = max(results, key=results.get)
print(f"\nBest Booster: {best_booster}")

# Train a final model with the best booster
final_model = xgb.XGBClassifier(n_estimators=100, learning_rate=0.1, booster=best_booster, random_state=42)
final_model.fit(X, y)
```
In this example, we create a synthetic binary classification dataset using scikit-learn's `make_classification` function. We define a list of the booster types we want to compare (`gbtree`, `gblinear`, and `dart`).
We set up a `StratifiedKFold` cross-validation object so that the class distribution is preserved in each fold. We then define a function, `evaluate_booster`, that takes a booster type as input, creates an `XGBClassifier` with that booster, and runs cross-validation using scikit-learn's `cross_val_score` function. The function returns the mean cross-validation accuracy.
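To see what stratification does in practice, the sketch below rebuilds the same dataset and splitter and prints the positive-class rate of each test fold next to the overall rate. It uses only the `split` method of `StratifiedKFold`. The per-fold rates should match the overall rate to within about one sample.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold

# Same dataset and splitter settings as in the example above
X, y = make_classification(n_samples=1000, n_classes=2, n_features=20, n_informative=10, random_state=42)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

overall_rate = y.mean()  # fraction of samples in class 1
fold_rates = []
for fold, (train_idx, test_idx) in enumerate(cv.split(X, y)):
    rate = y[test_idx].mean()
    fold_rates.append(rate)
    print(f"fold {fold}: {len(test_idx)} test rows, positive rate {rate:.3f} (overall {overall_rate:.3f})")
```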
We iterate through the booster types, call `evaluate_booster` for each one, store the results in a dictionary, and print the average cross-validation accuracy for each booster type.
To select the best booster, we call `max` on the dictionary with the `key` parameter set to `results.get`, which returns the booster with the highest average score. We then print the best booster type.
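The `max(results, key=results.get)` idiom is easy to misread, so here it is in isolation. The accuracy values are made up purely for illustration. `max` iterates over the dictionary's keys and ranks each one by its value. On a tie, it returns the first maximal key in insertion order.

```python
# Made-up accuracies for illustration only (not results from the example above)
results = {"gbtree": 0.91, "gblinear": 0.85, "dart": 0.91}

# max() compares keys by results.get(key); ties go to the first key encountered
best_booster = max(results, key=results.get)
print(best_booster)  # gbtree

# A full ranking; sorted() is stable, so tied boosters keep their insertion order
ranking = sorted(results, key=results.get, reverse=True)
print(ranking)  # ['gbtree', 'dart', 'gblinear']
```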
Finally, we train a final model using the best booster and the full dataset. This model can be used for making predictions on new data.
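By comparing boosters with cross-validation, you can see which one performs best on your specific dataset under the hyperparameters used here. Treat the averages as estimates. Small differences may fall within fold-to-fold variation, and the ranking can change once other hyperparameters, such as `max_depth` or `learning_rate`, are tuned.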