
Configure XGBoost "multi:softmax" Objective

The "multi:softmax" objective in XGBoost is used for multi-class classification tasks, where the target variable has more than two distinct classes.

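This objective extends binary classification to multiple classes. The model produces one raw score per class, the softmax function converts those scores into class probabilities, and "multi:softmax" returns the class with the highest probability as the prediction.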
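The softmax-then-argmax step can be illustrated with plain NumPy. This is a minimal sketch, not XGBoost's internal code: the raw scores are made-up numbers and the `softmax` helper is hypothetical. Because softmax preserves the ordering of the scores, the most probable class is also the class with the largest raw score.

```python
import numpy as np


def softmax(scores):
    """Row-wise softmax; subtracting the row max avoids overflow in np.exp."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Hypothetical raw scores (margins) for 2 samples and 3 classes
raw_scores = np.array([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3]])

probs = softmax(raw_scores)      # one probability per class, each row sums to 1
labels = probs.argmax(axis=1)    # what "multi:softmax" reports: the most probable class

print(probs.round(3))
print(labels)  # [0 1]
```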

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score

# Generate a synthetic dataset for multi-class classification
X, y = make_classification(n_samples=1000, n_classes=3, n_informative=5, random_state=42)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize an XGBClassifier with the "multi:softmax" objective
# (num_class is optional here: the scikit-learn wrapper sets it from the labels seen in fit)
model = XGBClassifier(objective="multi:softmax", num_class=3, n_estimators=100, learning_rate=0.1)

# Fit the model on the training data
model.fit(X_train, y_train)

# Make predictions on the test set
y_pred = model.predict(X_test)

# Calculate the accuracy of the predictions
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.4f}")

Use the "multi:softmax" objective when the target variable is multi-class and the classes are mutually exclusive, meaning each sample belongs to exactly one class.

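The softmax function ensures that the class probabilities sum to 1, but "multi:softmax" only outputs the most probable class label. If you need the probabilities themselves, use the "multi:softprob" objective instead.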

When using the "multi:softmax" objective, keep the following tips in mind:

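  1. Ensure that the target variable is encoded as consecutive integers from 0 to num_class - 1 (for example, with scikit-learn's LabelEncoder).
  2. Set the num_class parameter to the number of unique class labels in your dataset. It is required with the native xgb.train API; the scikit-learn XGBClassifier infers it from the training labels.
  3. Tune learning_rate (eta) together with n_estimators. Multi-class problems may need more boosting rounds to converge; a higher learning_rate reduces the number of rounds needed, at the risk of overfitting.
  4. Use an appropriate evaluation for multi-class classification, such as accuracy, a macro-averaged F1-score, or a confusion matrix, to assess the model’s performance.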
