The `EarlyStopping` callback in XGBoost provides a simple way to stop training early if a specified performance metric stops improving on a validation set. This helps avoid overfitting and reduces training time by terminating the training process once the model's performance on unseen data stops improving.

Implementing early stopping is straightforward: create an `EarlyStopping` callback object and pass it to the `callbacks` parameter of `xgb.train()`.

Here's a complete example demonstrating how to use the `EarlyStopping` callback:
```python
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error
import xgboost as xgb

# Load example dataset
data = fetch_california_housing()
X, y = data.data, data.target

# Split data into train and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Create DMatrix objects
dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)

# Define XGBoost parameters
params = {
    'objective': 'reg:squarederror',
    'max_depth': 3,
    'learning_rate': 0.1,
    'subsample': 0.8,
    'colsample_bytree': 0.8,
}

# Create EarlyStopping callback
early_stop = xgb.callback.EarlyStopping(
    rounds=10,
    metric_name='rmse',
    data_name="validation_0",
    save_best=True,
)

# Train model with early stopping
model = xgb.train(
    params,
    dtrain,
    num_boost_round=1000,
    evals=[(dtrain, "train"), (dval, "validation_0")],
    callbacks=[early_stop],
)

# Make predictions and evaluate performance
y_pred = model.predict(dval)
rmse = root_mean_squared_error(y_val, y_pred)
print(f"Final RMSE: {rmse:.3f}")
print(f"Best iteration: {model.best_iteration}")
print(f"Best score: {model.best_score}")
```
In this example:

- We load the California Housing dataset and split it into train and validation sets.
- We create `DMatrix` objects for XGBoost and define the model parameters.
- We create an `EarlyStopping` callback object, specifying:
  - `rounds`: number of rounds to wait for improvement before stopping (10 in this case)
  - `metric_name`: metric to monitor for improvement (`'rmse'`)
  - `data_name`: name of the evaluation set to monitor (`"validation_0"`, matching the name given in `evals`)
  - `save_best`: whether to keep only the best model (`True`)
- We train the model with `xgb.train()`, passing the `EarlyStopping` callback.
- We make predictions on the validation set and evaluate the model's performance.
- Finally, we print the best iteration and score, confirming that the model stopped training early based on the `EarlyStopping` criteria.
The `EarlyStopping` callback provides a convenient way to automate finding the optimal number of boosting rounds, helping to prevent overfitting and reduce training time. By monitoring a specified metric on a validation set, it stops training once the model starts to overfit, preserving the best performance on unseen data.