In imbalanced classification problems, the default probability threshold of 0.5 may not always yield the best performance.
Threshold moving is a technique that involves adjusting the probability threshold used to assign class labels, allowing you to find the optimal threshold that maximizes a chosen evaluation metric, such as F1-score or balanced accuracy.
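As a quick illustration of the idea, the sketch below applies two different thresholds to the same set of made-up probabilities. The values and the helper name apply_threshold are hypothetical and only serve to show how the assigned labels change.

```python
# Hypothetical predicted probabilities for the positive class
probabilities = [0.05, 0.20, 0.35, 0.48, 0.62, 0.91]


def apply_threshold(probs, threshold):
    """Label a sample positive (1) when its probability reaches the threshold."""
    return [int(p >= threshold) for p in probs]


labels_default = apply_threshold(probabilities, 0.5)
labels_lowered = apply_threshold(probabilities, 0.3)

print(labels_default)  # [0, 0, 0, 0, 1, 1]
print(labels_lowered)  # [0, 0, 1, 1, 1, 1]
```

Lowering the threshold from 0.5 to 0.3 turns two borderline samples into positive predictions, which typically raises recall at the cost of precision.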
This example demonstrates how to apply threshold moving to an XGBoost model trained on an imbalanced binary classification dataset: it evaluates the model’s precision, recall, F1-score, and balanced accuracy at a range of probability thresholds, then uses a threshold optimization function to select the best one.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_curve
import numpy as np
# Generate an imbalanced synthetic binary classification dataset
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.9, 0.1], random_state=42)
# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train an XGBClassifier with scale_pos_weight to handle class imbalance
model = XGBClassifier(scale_pos_weight=9, random_state=42)
model.fit(X_train, y_train)
# Get predicted probabilities for the test set
y_pred_proba = model.predict_proba(X_test)[:, 1]
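# Define a function to calculate evaluation metrics for a given threshold
def evaluate_threshold(y_true, y_pred_proba, threshold):
    # Convert probabilities to hard class labels at the given threshold
    y_pred = (y_pred_proba >= threshold).astype(int)
    # Unpack the 2x2 confusion matrix into its four counts
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    # The small epsilon avoids division by zero when a count is empty
    precision = tp / (tp + fp + 1e-7)
    recall = tp / (tp + fn + 1e-7)
    specificity = tn / (tn + fp + 1e-7)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-7)
    # Balanced accuracy is the mean of recall on each class
    balanced_accuracy = (recall + specificity) / 2
    return precision, recall, f1, balanced_accuracy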
# Evaluate the model's performance at different probability thresholds
thresholds = np.arange(0.1, 1.0, 0.1)
metrics = [evaluate_threshold(y_test, y_pred_proba, t) for t in thresholds]
precision, recall, f1, balanced_accuracy = zip(*metrics)
# Find the optimal probability threshold
def find_optimal_threshold(y_true, y_pred_proba, metric='f1'):
    # Search a finer grid of candidate thresholds (step 0.01)
    thresholds = np.arange(0.1, 1.0, 0.01)
    metrics = [evaluate_threshold(y_true, y_pred_proba, t) for t in thresholds]
    if metric == 'f1':
        idx = np.argmax([f1 for _, _, f1, _ in metrics])
    elif metric == 'balanced_accuracy':
        idx = np.argmax([ba for _, _, _, ba in metrics])
    else:
        raise ValueError(f"Unsupported metric: {metric}")
    return thresholds[idx]
optimal_threshold = find_optimal_threshold(y_test, y_pred_proba, metric='f1')
# Generate final predictions using the optimal threshold
y_pred = (y_pred_proba >= optimal_threshold).astype(int)
# Evaluate the model with the optimal threshold
print(f"Optimal Threshold: {optimal_threshold:.2f}")
print("Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))
print("\nClassification Report:")
print(classification_report(y_test, y_pred))
The example starts by generating an imbalanced synthetic binary classification dataset and splitting it into training and testing sets. An XGBClassifier is trained on the imbalanced dataset, using the scale_pos_weight parameter to handle class imbalance.
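The value 9 in the example mirrors the roughly 9:1 class ratio of the generated data. A common starting point is the ratio of negative to positive samples in the training labels, which can be computed rather than hardcoded. The helper name estimate_scale_pos_weight and the sample labels below are hypothetical, and the ratio is a heuristic rather than a guaranteed optimum.

```python
import numpy as np


def estimate_scale_pos_weight(y):
    """Return the negative-to-positive sample ratio for binary labels."""
    y = np.asarray(y)
    n_positive = int(np.sum(y == 1))
    n_negative = int(np.sum(y == 0))
    if n_positive == 0:
        raise ValueError("y contains no positive samples")
    return n_negative / n_positive


# Hypothetical labels with the same 9:1 imbalance as the example dataset
y_example = np.array([0] * 90 + [1] * 10)
print(estimate_scale_pos_weight(y_example))  # 9.0
```

In the main example, the result of calling this helper on y_train could be passed to XGBClassifier as scale_pos_weight.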
Next, the evaluate_threshold function is defined to calculate evaluation metrics (precision, recall, F1-score, balanced accuracy) for a given set of true labels, predicted probabilities, and probability threshold.
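To make the four metrics concrete, here is a small worked example that derives them from confusion-matrix counts. The counts and the helper name metrics_from_counts are hypothetical, chosen to resemble a 200-sample test set with 20 positives.

```python
def metrics_from_counts(tp, fp, fn, tn):
    """Compute precision, recall, F1 and balanced accuracy from raw counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    # Balanced accuracy averages recall on the positive and negative class
    balanced_accuracy = (recall + specificity) / 2
    return precision, recall, f1, balanced_accuracy


# Hypothetical counts: 20 actual positives, 180 actual negatives
precision, recall, f1, balanced_accuracy = metrics_from_counts(tp=15, fp=5, fn=5, tn=175)
print(f"precision={precision:.3f}")                  # precision=0.750
print(f"recall={recall:.3f}")                        # recall=0.750
print(f"f1={f1:.3f}")                                # f1=0.750
print(f"balanced_accuracy={balanced_accuracy:.3f}")  # balanced_accuracy=0.861
```

Plain accuracy for these counts would be 0.95, which shows why metrics that account for the minority class are the more informative choice here.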
The model’s performance is evaluated at thresholds from 0.1 to 0.9 in steps of 0.1 using the evaluate_threshold function, and the per-threshold precision, recall, F1-score, and balanced accuracy are stored so they can be compared or plotted.
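As an alternative to a fixed grid, scikit-learn's precision_recall_curve can return precision and recall at every distinct score when it is given the raw probabilities. The sketch below uses a tiny set of hypothetical labels and scores. It is one possible way to get curve data for plotting, not the only one.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Hypothetical true labels and predicted positive-class probabilities
y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])

# Pass the raw probabilities, not thresholded labels
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

# precision and recall carry one extra trailing point (precision=1, recall=0)
# that has no matching threshold, so drop it before computing F1
p, r = precision[:-1], recall[:-1]
f1 = 2 * p * r / (p + r + 1e-7)

best_threshold = thresholds[np.argmax(f1)]
print(f"Best F1 {f1.max():.2f} at threshold {best_threshold:.2f}")
```

The precision and recall arrays can be passed directly to a plotting library to draw the precision-recall curve.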
The find_optimal_threshold function is defined to find the optimal probability threshold that maximizes a chosen evaluation metric (e.g., F1-score or balanced accuracy). This function is then applied to find the best threshold for the model.
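When the threshold is tuned on the same data used for the final report, the reported scores tend to be optimistic. One way to avoid this, sketched below under the assumption that enough held-out data is available, is to choose the threshold on a validation portion and report on a separate test portion. The synthetic scores and the helper names f1_at_threshold and pick_threshold are hypothetical.

```python
import numpy as np


def f1_at_threshold(y_true, scores, threshold):
    """F1-score of the positive class when scores are cut at the threshold."""
    y_pred = scores >= threshold
    tp = np.sum(y_pred & (y_true == 1))
    fp = np.sum(y_pred & (y_true == 0))
    fn = np.sum(~y_pred & (y_true == 1))
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom else 0.0


def pick_threshold(y_true, scores, candidates):
    """Return the candidate threshold with the highest F1-score."""
    f1_values = [f1_at_threshold(y_true, scores, t) for t in candidates]
    return float(candidates[int(np.argmax(f1_values))])


# Hypothetical held-out scores: positives tend to score higher than negatives
rng = np.random.default_rng(0)
y_holdout = (rng.random(400) < 0.1).astype(int)
scores = np.clip(rng.normal(0.25 + 0.4 * y_holdout, 0.15), 0.0, 1.0)

# Tune on the first half (validation), report on the second half (test)
y_val, s_val = y_holdout[:200], scores[:200]
y_test, s_test = y_holdout[200:], scores[200:]

candidates = np.arange(0.1, 1.0, 0.01)
best = pick_threshold(y_val, s_val, candidates)
print(f"Threshold chosen on validation data: {best:.2f}")
print(f"F1 on untouched test data: {f1_at_threshold(y_test, s_test, best):.3f}")
```

In the main example, the same effect could be achieved by carving a validation set out of the training data and passing it to find_optimal_threshold instead of the test set.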
Finally, the model generates predictions using the optimal threshold and evaluates its performance using the confusion matrix and classification report.
By applying threshold moving, you can find the optimal probability threshold that maximizes the chosen evaluation metric, helping to improve the model’s performance on imbalanced classification problems.