
Configure XGBoost Objective "survival:cox" vs "survival:aft"

This example examines two survival analysis objectives within XGBoost: "survival:cox" for a proportional hazards model and "survival:aft" for an accelerated failure time model.
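For reference, the two model forms can be written as follows. Here $f(x)$ is the boosted ensemble's raw output (margin), $h_0(t)$ is an unspecified baseline hazard, and $Z$ is a noise term whose distribution is chosen with `aft_loss_distribution`, with $\sigma$ set by `aft_loss_distribution_scale`. This follows the model descriptions in the XGBoost documentation and is a summary, not a derivation:

```latex
% Cox proportional hazards (survival:cox); predict() returns the hazard ratio exp(f(x))
h(t \mid x) = h_0(t)\,\exp\bigl(f(x)\bigr)

% Accelerated failure time (survival:aft); predict() returns a survival time on the original scale
\ln T = f(x) + \sigma Z
```

A larger Cox output means higher risk (an earlier expected event), while a larger AFT output means a longer predicted survival time.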

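We’ll look at when each objective fits, show how to prepare the labels each one expects, and compare their predictions in a practical code example. In brief, "survival:cox" handles right-censored data and predicts a relative risk (hazard ratio), which suits ranking subjects by risk; "survival:aft" also accepts left- and interval-censored labels and predicts survival times directly, at the cost of assuming a distribution for the error term.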
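Before the full example, here is a minimal sketch of how a typical `(time, event)` pair per subject maps onto the label format each objective expects. The helper names `to_cox_label` and `to_aft_bounds` are hypothetical, and the sketch assumes all times are strictly positive and that the data are right-censored only (event = 1 means the event was observed, 0 means censored):

```python
import numpy as np

def to_cox_label(time, event):
    """Hypothetical helper: encode (time, event) for survival:cox.

    Observed events keep their positive time; right-censored rows are
    encoded as the negative of their censoring time. Assumes time > 0.
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event).astype(bool)
    return np.where(event, time, -time)

def to_aft_bounds(time, event):
    """Hypothetical helper: encode (time, event) as bounds for survival:aft.

    Observed event: lower == upper == time.
    Right-censored: lower = censoring time, upper = +inf.
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event).astype(bool)
    lower = time.copy()
    upper = np.where(event, time, np.inf)
    return lower, upper

# Toy data: the second subject is censored at t = 5
time = np.array([2.0, 5.0, 3.5])
event = np.array([1, 0, 1])
print("Cox labels:", to_cox_label(time, event))
lower, upper = to_aft_bounds(time, event)
print("AFT lower bounds:", lower)
print("AFT upper bounds:", upper)
```

The Cox label goes into `xgb.DMatrix(..., label=...)`, while the AFT bounds are attached with `set_float_info('label_lower_bound', ...)` and `set_float_info('label_upper_bound', ...)`.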

import xgboost as xgb
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

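# Generate synthetic survival data (features are unrelated to the outcome; for illustration only)
X, _ = make_classification(n_samples=1000, n_features=10, random_state=42)
rng = np.random.default_rng(42)  # Seeded so the example is reproducible
time = rng.exponential(scale=1.0, size=1000)  # Observed times (event or censoring time)
event = rng.binomial(1, p=0.5, size=1000)  # Event indicator: 1 = event observed, 0 = right-censored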

X_train, X_test, time_train, time_test, event_train, event_test = train_test_split(X, time, event, test_size=0.2, random_state=42)

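# Prepare DMatrix for Cox Proportional Hazards.
# survival:cox expects the survival time as the label, with right-censored
# observations encoded as negative times (not as sample weights, which would
# simply discard the censored rows).
cox_label_train = np.where(event_train == 1, time_train, -time_train)
dtrain_cox = xgb.DMatrix(X_train, label=cox_label_train)
dtest_cox = xgb.DMatrix(X_test)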

# XGBoost Cox Proportional Hazards Model
params_cox = {'objective': 'survival:cox'}
model_cox = xgb.train(params_cox, dtrain_cox, num_boost_round=50)
predictions_cox = model_cox.predict(dtest_cox)

# Prepare DMatrix for AFT model using interval labels:
# an observed event has lower = upper = time;
# a right-censored observation has lower = censoring time, upper = +inf.
lower_bound = time_train.copy()
upper_bound = np.where(event_train == 1, time_train, np.inf)

dtrain_aft = xgb.DMatrix(X_train)  # Labels are supplied via the bounds below
dtrain_aft.set_float_info('label_lower_bound', lower_bound)
dtrain_aft.set_float_info('label_upper_bound', upper_bound)
dtest_aft = xgb.DMatrix(X_test)

# XGBoost AFT Model
params_aft = {'objective': 'survival:aft', 'aft_loss_distribution': 'normal'}
model_aft = xgb.train(params_aft, dtrain_aft, num_boost_round=50)
predictions_aft = model_aft.predict(dtest_aft)

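# Output predictions for comparison. The scales differ:
# survival:cox returns a hazard ratio (higher = earlier expected event),
# survival:aft returns a predicted survival time (higher = later expected event).
print(f"Cox Proportional Hazards Predictions: {predictions_cox[:5]}")
print(f"Accelerated Failure Time Predictions: {predictions_aft[:5]}")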
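Printing raw predictions does not tell us which model ranks subjects better. One common way to compare survival models is Harrell's concordance index (C-index). The sketch below is a plain NumPy implementation written for illustration (the function name `concordance_index` is our own; libraries such as lifelines and scikit-survival provide tested versions). It treats a pair as comparable when the subject with the shorter time had an observed event, and assumes a higher score means higher risk:

```python
import numpy as np

def concordance_index(time, event, risk):
    """Harrell's C-index: fraction of comparable pairs ranked correctly.

    A pair (i, j) is comparable when subject i has an observed event and
    time[i] < time[j]. It is concordant when risk[i] > risk[j]; tied risk
    scores count as half-concordant. Higher risk should mean an earlier event.
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event).astype(bool)
    risk = np.asarray(risk, dtype=float)
    concordant = 0.0
    comparable = 0
    for i in range(len(time)):
        if not event[i]:
            continue  # A censored subject cannot anchor a comparable pair
        later = time > time[i]  # Subjects who outlived subject i
        comparable += int(later.sum())
        concordant += float((risk[i] > risk[later]).sum())
        concordant += 0.5 * float((risk[i] == risk[later]).sum())
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable
```

Applied to the example above, `concordance_index(time_test, event_test, predictions_cox)` scores the Cox model directly, while the AFT model needs `concordance_index(time_test, event_test, -predictions_aft)` because a longer predicted survival time means lower risk. Since the synthetic features carry no information about the outcome, both scores should land near 0.5 here.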

