
XGBoost for the Handwritten Digits Dataset

The handwritten digits dataset is a classic dataset for image classification tasks. It consists of 1,797 8x8 grayscale images of the digits 0 to 9, and each image is flattened into 64 pixel features with integer intensities from 0 to 16.
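The sketch below shows how the 8x8 images relate to the 64-column feature matrix that XGBoost trains on by comparing `digits.images` with `digits.data`. It uses only scikit-learn and NumPy and is an illustration, not part of the original example:

```python
from sklearn.datasets import load_digits
import numpy as np

digits = load_digits()

# digits.images holds the raw 8x8 pixel grids; digits.data holds the same
# pixels flattened row by row into 64 features per sample
print(digits.images.shape)  # (1797, 8, 8)
print(digits.data.shape)    # (1797, 64)

# Flattening each 8x8 image reproduces the feature matrix exactly
flattened = digits.images.reshape(len(digits.images), -1)
print(np.array_equal(flattened, digits.data))  # True

# Pixel intensities are small integers on a 0-16 scale
print(digits.data.min(), digits.data.max())  # 0.0 16.0
```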

In this example, we’ll load the dataset from scikit-learn, perform hyperparameter tuning using GridSearchCV with common XGBoost parameters, save the best model, load it, and use it to make predictions.

from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV, train_test_split
from xgboost import XGBClassifier
import numpy as np

# Load the digits dataset
digits = load_digits()
X, y = digits.data, digits.target

# Print key information about the dataset
print(f"Dataset shape: {X.shape}")
print(f"Number of classes: {len(np.unique(y))}")

# Display the first image along with its label
import matplotlib.pyplot as plt
plt.imshow(digits.images[0], cmap=plt.cm.binary)
plt.title(f"Label: {digits.target[0]}")
plt.show()

# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define parameter grid
param_grid = {
    'max_depth': [3, 4, 5],
    'learning_rate': [0.1, 0.01, 0.05],
    'n_estimators': [50, 100, 200],
    'subsample': [0.8, 1.0],
    'colsample_bytree': [0.8, 1.0]
}

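# Create XGBClassifier; n_jobs=1 keeps each model single-threaded because
# GridSearchCV below already evaluates candidates in parallel
model = XGBClassifier(objective='multi:softmax', random_state=42, n_jobs=1)

# Perform grid search with 3-fold cross-validation, using all CPU cores
grid_search = GridSearchCV(estimator=model, param_grid=param_grid, cv=3, n_jobs=-1)
grid_search.fit(X_train, y_train)

# Print the best mean cross-validated accuracy and the parameters that achieved it
print(f"Best score: {grid_search.best_score_:.3f}")
print(f"Best parameters: {grid_search.best_params_}")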

# Access best model
best_model = grid_search.best_estimator_

# Save best model; the .ubj extension selects XGBoost's binary UBJSON format
best_model.save_model('best_model_digits.ubj')

# Load saved model into a fresh XGBClassifier instance
loaded_model = XGBClassifier()
loaded_model.load_model('best_model_digits.ubj')

# Use loaded model for predictions
predictions = loaded_model.predict(X_test)

# Compute accuracy as the fraction of test predictions that match the true labels
accuracy = np.mean(predictions == y_test)
print(f"Accuracy: {accuracy:.3f}")

Running this example, you will see output similar to:

Dataset shape: (1797, 64)
Number of classes: 10
Best score: 0.962
Best parameters: {'colsample_bytree': 0.8, 'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 200, 'subsample': 1.0}
Accuracy: 0.972

The example first loads the digits dataset using load_digits() from scikit-learn and prints the dataset shape and number of unique classes. It also displays a sample image from the dataset.
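Because the example reports plain accuracy, it can be worth checking that the ten classes are roughly balanced before relying on that metric. Here is a minimal sketch that counts the labels with `np.bincount`:

```python
from sklearn.datasets import load_digits
import numpy as np

X, y = load_digits(return_X_y=True)

# Count how many samples belong to each digit class 0-9
counts = np.bincount(y, minlength=10)
for digit, count in enumerate(counts):
    print(f"digit {digit}: {count} samples")

# Similar counts across classes suggest plain accuracy is a reasonable summary metric
print(f"smallest class: {counts.min()}, largest class: {counts.max()}")
```

If the counts were very uneven, passing `stratify=y` to `train_test_split` would keep the class proportions the same in the train and test sets.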

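The data is then split into train and test sets (80/20), and a parameter grid is defined over common XGBoost hyperparameters. GridSearchCV fits an XGBClassifier for each of the 108 combinations in the grid using 3-fold cross-validation on the training set and keeps the combination with the highest mean validation accuracy. That mean accuracy is the reported best score.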
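The size of the grid determines how long the search takes. As a rough sketch, scikit-learn's `ParameterGrid` can enumerate the same combinations that GridSearchCV will try, so the cost can be estimated before the search runs:

```python
from sklearn.model_selection import ParameterGrid

# Same grid as in the example above
param_grid = {
    'max_depth': [3, 4, 5],
    'learning_rate': [0.1, 0.01, 0.05],
    'n_estimators': [50, 100, 200],
    'subsample': [0.8, 1.0],
    'colsample_bytree': [0.8, 1.0]
}

# ParameterGrid enumerates every combination GridSearchCV will evaluate
n_candidates = len(ParameterGrid(param_grid))
n_folds = 3  # matches cv=3 in the example

print(f"Candidate combinations: {n_candidates}")               # 108
print(f"Model fits during search: {n_candidates * n_folds}")   # 324
```

With the default `refit=True`, GridSearchCV also performs one extra fit of the best combination on the full training set after cross-validation, bringing the total to 325 fits.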

The best model, which GridSearchCV refits on the full training set, is saved to best_model_digits.ubj (XGBoost's binary UBJSON format) and then loaded into a new XGBClassifier. The loaded model makes predictions on the held-out test set, and its test accuracy is printed.

This example demonstrates how to use XGBoost for handwritten digit classification, perform hyperparameter tuning, save and load the best model, and evaluate its performance on a test set.


