This example examines two survival analysis objectives within XGBoost: `"survival:cox"` for a proportional hazards model and `"survival:aft"` for an accelerated failure time model.
We’ll explore suitable scenarios for each, demonstrating how to set up the data and compare their performance through a practical code implementation.
```python
import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Generate synthetic survival data
np.random.seed(42)
X, _ = make_classification(n_samples=1000, n_features=10, random_state=42)
time = np.random.exponential(scale=1, size=1000)  # Observed times (event or censoring)
event = np.random.binomial(1, p=0.5, size=1000)   # Event indicator: 1 = observed, 0 = right-censored

X_train, X_test, time_train, time_test, event_train, event_test = train_test_split(
    X, time, event, test_size=0.2, random_state=42
)

# Prepare DMatrix for Cox Proportional Hazards.
# survival:cox reads censoring from the sign of the label:
# positive = observed event time, negative = right-censored time.
label_cox = np.where(event_train == 1, time_train, -time_train)
dtrain_cox = xgb.DMatrix(X_train, label=label_cox)
dtest_cox = xgb.DMatrix(X_test)

# XGBoost Cox Proportional Hazards Model
params_cox = {'objective': 'survival:cox'}
model_cox = xgb.train(params_cox, dtrain_cox, num_boost_round=50)
predictions_cox = model_cox.predict(dtest_cox)  # Hazard ratios: higher = higher risk

# Prepare DMatrix for AFT model.
# survival:aft reads censoring from label bounds instead of the label:
# uncensored rows use [time, time], right-censored rows use [time, +inf).
lower_bound = time_train.copy()
upper_bound = np.where(event_train == 1, time_train, np.inf)
dtrain_aft = xgb.DMatrix(X_train)
dtrain_aft.set_float_info('label_lower_bound', lower_bound)
dtrain_aft.set_float_info('label_upper_bound', upper_bound)
dtest_aft = xgb.DMatrix(X_test)

# XGBoost AFT Model
params_aft = {'objective': 'survival:aft', 'aft_loss_distribution': 'normal'}
model_aft = xgb.train(params_aft, dtrain_aft, num_boost_round=50)
predictions_aft = model_aft.predict(dtest_aft)  # Predicted survival times: higher = longer survival

# Output the first few predictions (note the two models predict on different scales)
print(f"Cox Proportional Hazards Predictions: {predictions_cox[:5]}")
print(f"Accelerated Failure Time Predictions: {predictions_aft[:5]}")
```
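The two sets of printed predictions are not directly comparable: `"survival:cox"` returns values on the hazard-ratio scale (higher means higher risk), while `"survival:aft"` returns predicted survival times (higher means longer survival). One common way to put both on equal footing is Harrell's concordance index, which only looks at how well the predictions rank subjects. The helper below is a minimal sketch, not part of XGBoost: the name `concordance_index` is hypothetical, it runs in quadratic time, and it skips tied times for simplicity.

```python
from itertools import combinations


def concordance_index(time, event, risk):
    """Harrell's C-index for right-censored data (hypothetical helper, O(n^2)).

    time  : observed times (event time or censoring time)
    event : 1 if the event was observed, 0 if right-censored
    risk  : scores where a higher value means a shorter expected survival
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(time)), 2):
        if time[i] == time[j]:
            continue  # skip tied times to keep the sketch simple
        # Order the pair so that `a` has the shorter observed time.
        a, b = (i, j) if time[i] < time[j] else (j, i)
        if not event[a]:
            continue  # the earlier time is censored, so the true order is unknown
        comparable += 1
        if risk[a] > risk[b]:
            concordant += 1.0
        elif risk[a] == risk[b]:
            concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy data: four subjects, the third one is right-censored.
toy_time = [2.0, 4.0, 5.0, 7.0]
toy_event = [1, 1, 0, 1]
toy_risk = [0.9, 0.6, 0.4, 0.1]  # shorter time always has the higher risk
print(concordance_index(toy_time, toy_event, toy_risk))  # prints 1.0
```

With the arrays from the example above, this could be called as `concordance_index(time_test, event_test, predictions_cox)` for the Cox model and `concordance_index(time_test, event_test, -predictions_aft)` for the AFT model, negating the AFT output so that higher means riskier. Because the synthetic survival times are drawn independently of the features, both scores should land near 0.5, the value expected from random ranking.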
- The `"survival:cox"` objective models the log hazard ratio as a function of the predictors (in XGBoost, the output of the boosted tree ensemble rather than a linear combination) and is suitable for cases where the hazard functions of different individuals are proportional. Its predictions are returned on the hazard-ratio scale.
- The `"survival:aft"` objective assumes that the effect of predictors is to accelerate or decelerate the life course of an individual by a constant factor. It’s particularly useful when the underlying distribution of survival times is assumed to follow a parametric form.
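Written out, the two assumptions look roughly as follows. The notation is an assumption made for illustration: $f(x)$ stands for the raw output of the boosted model, $h_0(t)$ for an unspecified baseline hazard, $\sigma$ for the scale set by `aft_loss_distribution_scale`, and $Z$ for a noise term whose distribution is chosen by `aft_loss_distribution`.

```latex
% Cox proportional hazards: covariates scale a shared baseline hazard.
h(t \mid x) = h_0(t)\, \exp\bigl(f(x)\bigr)

% Accelerated failure time: covariates shift the log of the survival time T.
\ln T = f(x) + \sigma Z
```

Under the Cox form, the ratio of hazards for two individuals, $\exp\bigl(f(x_1) - f(x_2)\bigr)$, does not depend on $t$, which is the proportionality assumption. Under the AFT form, a change in $f(x)$ multiplies the survival time by a constant factor.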
Best Practices and Tips:
- For `"survival:cox"`, ensure the proportionality assumption is plausible with your data; otherwise, model performance might degrade.
- For `"survival:aft"`, choose the appropriate loss distribution based on the nature of the survival data. The `aft_loss_distribution` parameter accepts `normal`, `logistic`, and `extreme`.
- When preparing data for survival analysis, handle right-censored data correctly by encoding the event indicator in the form each objective expects: negative labels for censored rows with `"survival:cox"`, and an infinite `label_upper_bound` for censored rows with `"survival:aft"`.
- Tune hyperparameters like `learning_rate` and `max_depth` thoughtfully to avoid overfitting while ensuring the model captures relevant patterns in the data.