
XGBoost for Survival Analysis (Accelerated Failure Time)

XGBoost can be used for survival analysis tasks, such as predicting the time until an event occurs.

One approach to survival analysis is the Accelerated Failure Time (AFT) model, which assumes that the effect of covariates is to accelerate or decelerate the time to event.
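To make the "accelerate or decelerate" idea concrete, the sketch below assumes the standard AFT form ln T = w * x + sigma * Z with a single covariate. The coefficient `w`, the scale `sigma`, and the helper name `aft_time` are hypothetical values chosen for illustration, not anything XGBoost exposes; in XGBoost's AFT objective the linear term is replaced by the output of the tree ensemble.

```python
import math

def aft_time(x, z, w=0.7, sigma=1.0):
    """Survival time under a one-covariate AFT model: ln T = w * x + sigma * z."""
    return math.exp(w * x + sigma * z)

# Use the same noise draw z for both subjects, so only the covariate differs.
z = 0.25
t_base = aft_time(x=0.0, z=z)
t_shifted = aft_time(x=1.0, z=z)

# Raising x by one unit multiplies the time to event by exp(w),
# whatever the noise term is: the covariate rescales time itself.
acceleration_factor = t_shifted / t_base
print(acceleration_factor, math.exp(0.7))
```

A positive `w` stretches the time to event (deceleration), while a negative `w` shrinks it (acceleration).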

Here’s a quick guide on how to train an XGBoost model for survival analysis using the AFT objective and XGBoost's native API (DMatrix and xgb.train()), with a synthetic dataset generated using NumPy.

# XGBoosting.com
# Train an XGBoost Model for Survival Analysis using the AFT Objective and the native XGBoost API
import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split

# Generate a synthetic dataset with features and survival times
np.random.seed(42)  # fix the seed so the synthetic data is reproducible
n_samples = 1000
n_features = 10
X = np.random.rand(n_samples, n_features)
# Generate survival times from a Weibull distribution
scale = np.exp(-X[:, 0])
shape = 1.5
y = np.random.weibull(shape, n_samples) * scale

# Create lower and upper bounds for each survival time.
# There is no censoring here, so both bounds equal the observed time y.
y_lower = y_upper = y

# Split the data into training and testing sets
X_train, X_test, y_train, y_test, y_lower_train, y_lower_test, y_upper_train, y_upper_test = train_test_split(X, y, y_lower, y_upper, test_size=0.2, random_state=42)

# Convert data into DMatrix, specifying the label, label_lower_bound, and label_upper_bound
dtrain = xgb.DMatrix(X_train, label=y_train, label_lower_bound=y_lower_train, label_upper_bound=y_upper_train)
dtest = xgb.DMatrix(X_test, label=y_test, label_lower_bound=y_lower_test, label_upper_bound=y_upper_test)

# Set the training parameters, including the "survival:aft" objective
params = {
    'objective': 'survival:aft',
    'eval_metric': 'aft-nloglik',
    'aft_loss_distribution': 'normal',
    'aft_loss_distribution_scale': 1.0,
    'learning_rate': 0.1
}

# Fit the model on the training data
bst = xgb.train(params, dtrain, num_boost_round=100)

# Make predictions on the test set
y_pred = bst.predict(dtest)

# Output the predicted survival times for demonstration purposes
print("Predicted survival times:", y_pred[:5])

In this example:

  1. We generate a synthetic dataset using NumPy, with features X and survival times y drawn from a Weibull distribution.
  2. We initialize configuration parameters with the 'survival:aft' objective for survival analysis using the AFT model. The evaluation metric is set to 'aft-nloglik' (negative log-likelihood for AFT).
  3. We fit the model with the xgb.train() function, passing the training DMatrix, which holds the input features together with the lower and upper bounds on each survival time.
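
The example above has no censoring, so both bounds equal the observed time. With right-censored data, the convention for the 'survival:aft' objective is to set the lower bound to the last time the subject was observed and the upper bound to +inf. The sketch below builds such bounds with NumPy only; the follow-up cutoff `censor_time` and the helper `make_aft_bounds` are hypothetical names introduced for illustration.

```python
import numpy as np

def make_aft_bounds(event_time, censor_time):
    """Return (lower, upper) interval bounds for right-censored survival times."""
    event_time = np.asarray(event_time, dtype=float)
    # A subject is right-censored if the event happens after follow-up ends.
    censored = event_time > censor_time
    # Lower bound: the event time if observed, else the end of follow-up.
    lower = np.where(censored, censor_time, event_time)
    # Upper bound: the event time if observed, +inf if right-censored.
    upper = np.where(censored, np.inf, event_time)
    return lower, upper

rng = np.random.default_rng(0)
event_time = rng.weibull(1.5, size=8)
lower, upper = make_aft_bounds(event_time, censor_time=1.0)
print(lower)
print(upper)
```

The resulting arrays could then be passed as label_lower_bound and label_upper_bound when building the DMatrix, in place of y_lower and y_upper.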

