Save XGBoost Model Hyperparameters

Saving the hyperparameters of a trained XGBoost model separately from the model itself can be beneficial for documentation, reproducibility, and model comparison purposes.

This example demonstrates how to extract the hyperparameters of an XGBoost model as a Python dictionary, save them to a JSON file, and load them back into a new model.

from sklearn.datasets import fetch_california_housing
from xgboost import XGBRegressor
import json

# Load the dataset
housing = fetch_california_housing()
X, y = housing.data, housing.target

# Define the hyperparameters for the XGBoost model
params = {
    'n_estimators': 100,
    'max_depth': 3,
    'learning_rate': 0.1,
    'objective': 'reg:squarederror',
    'subsample': 0.8,
    'colsample_bytree': 0.8,
    'random_state': 42
}

# Train the XGBoost model
model = XGBRegressor(**params)
model.fit(X, y)

# Retrieve a dict of the hyperparameters from the trained model
hyperparams = model.get_params()

# Optionally, save the dictionary to a JSON file
with open('xgb_params.json', 'w') as f:
    json.dump(hyperparams, f)

# Load the saved hyperparameters back from the JSON file
with open('xgb_params.json') as f:
    loaded_params = json.load(f)

# Create a new XGBoost model and apply the loaded hyperparameters
new_model = XGBRegressor()
new_model.set_params(**loaded_params)

# Retrieve the parameters from the new model
hyperparams2 = new_model.get_params()

# Verify that the new model has the same set of hyperparameters
assert new_model.get_params().keys() == model.get_params().keys()

# Verify all values match, skipping 'missing' (its default is NaN, which never compares equal)
for key in hyperparams.keys():
    if key != 'missing':
        assert hyperparams[key] == hyperparams2[key]

In this example, we train an XGBoost model on the California Housing dataset using a set of predefined hyperparameters.

After training, we extract the hyperparameters from the model using the get_params() method, which returns a Python dictionary mapping parameter names to values. We then save the dictionary to a JSON file for future reference.
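One detail worth knowing when saving this dictionary: the default value of the `missing` parameter is NaN, and strict JSON has no NaN value. The short sketch below uses only Python's standard `json` module and a small hand-written stand-in dictionary (the real `get_params()` output has many more keys). It shows that `json.dumps` writes NaN as the non-standard token `NaN` by default, that `json.loads` reads it back as a float NaN, and that `allow_nan=False` rejects it.

```python
import json
import math

# Hypothetical stand-in for the dict returned by get_params()
params = {"max_depth": 3, "learning_rate": 0.1, "missing": float("nan"), "callbacks": None}

# Default behaviour: NaN is written as the bare token NaN (not strict JSON)
text = json.dumps(params)
print(text)  # {"max_depth": 3, "learning_rate": 0.1, "missing": NaN, "callbacks": null}

# Python's json module reads the NaN token back as a float NaN
restored = json.loads(text)
print(math.isnan(restored["missing"]))  # True

# With allow_nan=False, json refuses to emit non-compliant float values
try:
    json.dumps(params, allow_nan=False)
    strict_ok = True
except ValueError as err:
    strict_ok = False
    print("Strict JSON rejected the value:", err)
```

A round trip through Python works as-is, but a file written this way may be rejected by stricter JSON parsers in other tools.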

To demonstrate loading the hyperparameters, we read the JSON file back into a dictionary, create a new XGBoost model, and set its hyperparameters using the set_params() method with the loaded values.

Finally, we use assertions to verify that the new model has the same hyperparameter names and values as the original model. We skip checking the value of the missing parameter because its default is NaN, and NaN never compares equal to itself, so a direct == comparison would fail.
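Instead of skipping `missing`, one option is a NaN-aware comparison. The sketch below is plain Python; the helper names `values_equal` and `params_match` are hypothetical, and the two dictionaries are small stand-ins for real `get_params()` output. It treats two float NaNs as equal and otherwise falls back to ordinary `==`.

```python
import math


def values_equal(a, b):
    """Treat two float NaNs as equal; otherwise use ordinary ==."""
    if isinstance(a, float) and isinstance(b, float) and math.isnan(a) and math.isnan(b):
        return True
    return a == b


def params_match(p1, p2):
    """True when both dicts have the same keys and every value matches."""
    return p1.keys() == p2.keys() and all(values_equal(p1[k], p2[k]) for k in p1)


# Hypothetical stand-ins: two separate NaN objects, as after a JSON round trip
original = {"max_depth": 3, "missing": float("nan")}
reloaded = {"max_depth": 3, "missing": float("nan")}

print(original == reloaded)              # False, because nan != nan
print(params_match(original, reloaded))  # True
```

With a helper like this, the verification loop could compare every key, including `missing`.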

By saving the hyperparameters separately, we can easily keep track of the settings used for a particular model, which can be helpful for reproducing results or comparing different models.
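For comparing models, a small diff of two saved hyperparameter dictionaries can show at a glance which settings changed between runs. This is a minimal sketch in plain Python: `diff_params` is a hypothetical helper, the two dictionaries stand in for configurations loaded from two JSON files, and the `"<absent>"` marker for keys present in only one dictionary is an arbitrary choice.

```python
import math


def diff_params(a, b):
    """Return {name: (value_in_a, value_in_b)} for every hyperparameter that differs."""

    def same(x, y):
        # Treat NaN (e.g. the default `missing`) as equal to NaN
        if isinstance(x, float) and isinstance(y, float) and math.isnan(x) and math.isnan(y):
            return True
        return x == y

    absent = "<absent>"  # marker for a key that exists in only one dict
    diffs = {}
    for key in sorted(a.keys() | b.keys()):
        x, y = a.get(key, absent), b.get(key, absent)
        if not same(x, y):
            diffs[key] = (x, y)
    return diffs


# Two hypothetical saved configurations
run_a = {"max_depth": 3, "learning_rate": 0.1, "subsample": 0.8, "missing": float("nan")}
run_b = {"max_depth": 6, "learning_rate": 0.1, "subsample": 0.8, "missing": float("nan")}

print(diff_params(run_a, run_b))  # {'max_depth': (3, 6)}
```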

See Also