
How to Use XGBoost TrainingCheckPoint Callback

The TrainingCheckPoint callback in XGBoost allows you to save the model during training at specified intervals.

This is useful for recovering from interrupted training runs and for inspecting how the model performs at intermediate stages of training.

In this example, we’ll demonstrate how to use the TrainingCheckPoint callback to save the model every 10 iterations.

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error  # requires scikit-learn >= 1.4
import xgboost as xgb
import os

# Load example dataset
data = fetch_california_housing()
X, y = data.data, data.target

# Split data into train and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Create DMatrix objects
dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)

# Define XGBoost parameters
params = {
    'objective': 'reg:squarederror',
    'max_depth': 3,
    'learning_rate': 0.1,
    'subsample': 0.8,
    'colsample_bytree': 0.8,
}

# Create TrainingCheckPoint callback
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
# Note: `iterations` is the keyword in XGBoost 1.3 through 2.0; from 2.1 it is named `interval`
checkpoint = xgb.callback.TrainingCheckPoint(directory=checkpoint_dir, iterations=10)

# Train model with TrainingCheckPoint callback
model = xgb.train(
    params,
    dtrain,
    num_boost_round=100,
    evals=[(dtrain, "train"), (dval, "validation")],
    callbacks=[checkpoint]
)

# Make predictions and evaluate performance
y_pred = model.predict(dval)
rmse = root_mean_squared_error(y_val, y_pred)
print(f"Final RMSE: {rmse:.3f}")

# Load a saved checkpoint model
# (XGBoost 2.1+ writes .ubj files by default, so adjust the extension if needed)
loaded_model = xgb.Booster()
loaded_model.load_model(os.path.join(checkpoint_dir, 'model_90.json'))

# Make predictions with the loaded model
y_pred_loaded = loaded_model.predict(dval)
rmse_loaded = root_mean_squared_error(y_val, y_pred_loaded)
print(f"RMSE from loaded model: {rmse_loaded:.3f}")

In this example, we load the California Housing dataset and split it into train and validation sets. We create DMatrix objects for the train and validation data and set the XGBoost parameters.

Next, we create a TrainingCheckPoint callback, specifying the directory to save the checkpoints (./checkpoints) and the interval at which to save the model (every 10 iterations). We pass the callback to the xgb.train function along with the other parameters.

The model is trained for 100 boosting rounds, and a checkpoint is written to the specified directory every 10 rounds. After training, we make predictions on the validation set and evaluate the model’s performance using root mean squared error (RMSE).

To demonstrate loading a saved checkpoint, we create a new Booster object and load the model from the checkpoint saved at iteration 90 using load_model. We then make predictions with the loaded model and calculate its RMSE. Because this checkpoint was written before training finished, it contains fewer trees than the final model, so its RMSE should be close to the final RMSE but will not necessarily match it.
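When recovering from an interruption, you usually do not know in advance which checkpoint was written last. The helper below is a small sketch using only the standard library. The function name `latest_checkpoint` is hypothetical, and it assumes checkpoint files follow a `<name>_<iteration>.<extension>` naming pattern, as `model_90.json` does in the example above.

```python
import re
import tempfile
from pathlib import Path


def latest_checkpoint(directory, name="model"):
    """Return the checkpoint path with the highest iteration number, or None."""
    # Accept any extension (for example .json, .ubj, .pkl) after "<name>_<iteration>"
    pattern = re.compile(rf"^{re.escape(name)}_(\d+)\.\w+$")
    candidates = []
    for path in Path(directory).iterdir():
        match = pattern.match(path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    # Compare on the integer iteration so that model_100 sorts after model_90
    return max(candidates, key=lambda item: item[0])[1]


# Quick self-check with empty placeholder files
with tempfile.TemporaryDirectory() as tmp_dir:
    for file_name in ("model_10.json", "model_90.json", "model_100.json"):
        (Path(tmp_dir) / file_name).touch()
    newest = latest_checkpoint(tmp_dir).name

print(newest)
```

The returned path can be passed to `load_model` on a new `xgb.Booster()`, in the same way the example above loads `model_90.json`.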

By using the TrainingCheckPoint callback, you can easily save your model during training, allowing for interruption recovery and performance monitoring. The saved checkpoints can also be used for inference or resuming training at a later time.


