
XGBoost for the Kaggle MNIST Handwritten Digit Recognizer Dataset

MNIST (“Modified National Institute of Standards and Technology”) is the de facto “hello world” dataset of computer vision.

The dataset contains gray-scale images of hand-drawn digits, from zero through nine.

XGBoost is not well suited to computer vision problems. Nevertheless, it can be applied to this dataset and achieve modest performance.

Download the Training Dataset

The first step is to download the train.csv training dataset from the competition website.

This requires you to create an account and sign in before you can access the dataset.

You may also have to accept the competition rules.

Finally, you can download the train.csv.zip file that contains the train.csv dataset.
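The archive has to be unpacked before train.csv can be read by file name. The following is a minimal sketch using Python's standard zipfile module. The helper name extract_member is hypothetical, and the sketch assumes train.csv.zip sits in the current working directory.

```python
import zipfile
from pathlib import Path


def extract_member(zip_path, member="train.csv", dest_dir="."):
    """Extract a single file from a zip archive and return its path.

    `extract_member` is an illustrative helper, not part of any library.
    """
    with zipfile.ZipFile(zip_path) as archive:
        return Path(archive.extract(member, path=dest_dir))


if __name__ == "__main__":
    # Assumed location of the downloaded archive
    archive_path = Path("train.csv.zip")
    if archive_path.exists():
        print(f"Extracted: {extract_member(archive_path)}")
    else:
        print(f"{archive_path} not found; download it first")
```

Alternatively, pandas can usually read a single-file zip archive directly, for example pd.read_csv('train.csv.zip'), which skips the extraction step.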

XGBoost Example

Next, we can model the dataset with XGBoost.

This example will demonstrate how to load the training dataset, explore the data, perform hyperparameter tuning with XGBoost, save the best model, load it, and use it to make predictions.

import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import accuracy_score
from xgboost import XGBClassifier
from collections import Counter

# Load the training dataset
dataset = pd.read_csv('train.csv')

# Split into input and output elements
X, y = dataset.values[:,1:], dataset.values[:,0]

# Print key information about the dataset
print(f"Dataset shape: {dataset.shape}")
print(f"Number of features: {X.shape[1]}")
print(f"Target variable: {dataset.columns[0]}")
print(f"Class distributions: {Counter(y)}")

# Split into train and validation sets
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Define parameter grid
param_grid = {
    'max_depth': [3, 5, 7, 19],
    'learning_rate': [0.1, 0.01],
    'n_estimators': [1000],
    'colsample_bytree': [0.8, 1],
    'reg_alpha': [0, 0.5, 1],
    'reg_lambda': [0, 0.001, 0.1, 0.5, 1],
    'subsample': [0.8, 1],
}

# Create XGBClassifier
model = XGBClassifier(objective='multi:softprob', tree_method='hist', random_state=42, n_jobs=2)

# Perform grid search
grid_search = GridSearchCV(estimator=model, param_grid=param_grid, cv=3, n_jobs=4, pre_dispatch=4, verbose=1)
grid_search.fit(X_train, y_train)

# Print best score and parameters
print(f"Best score: {grid_search.best_score_:.3f}")
print(f"Best parameters: {grid_search.best_params_}")

# Access best model
best_model = grid_search.best_estimator_

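# Save best model (the filename is an arbitrary choice)
best_model.save_model('xgb_mnist_model.json')

# Load saved model
loaded_model = XGBClassifier()
loaded_model.load_model('xgb_mnist_model.json')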

# Use loaded model for predictions
predictions = loaded_model.predict(X_valid)

# Print accuracy score
accuracy = accuracy_score(y_valid, predictions)
print(f"Accuracy: {accuracy:.3f}")

First, we load the dataset using pandas. The dataset contains 42,000 samples and 784 pixel input features.
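Each sample is a flattened image: 784 = 28 × 28 pixels. The sketch below shows how one row of features could be reshaped back into a 2D image for inspection. It uses a synthetic stand-in row rather than the real data, and it assumes the pixels are stored row by row, which is how the competition describes the layout.

```python
import numpy as np

IMAGE_SIDE = 28  # 28 * 28 = 784 pixel features per image

# Synthetic stand-in for one row of X (real rows hold intensities 0-255)
row = np.arange(IMAGE_SIDE * IMAGE_SIDE) % 256

# Reshape the flat vector into a 28x28 grid, assuming row-major pixel order
image = row.reshape(IMAGE_SIDE, IMAGE_SIDE)

print(f"Flat shape: {row.shape}")
print(f"Image shape: {image.shape}")
```

With the real dataset, row would be replaced by something like X[0]. XGBoost itself still consumes the flat 784-column representation.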

The target variable is the digit label, which has 10 possible classes (handwritten digits 0 through 9). We split the data into train and validation sets, stratifying by label so that both sets keep the same class proportions.
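To see what the stratify argument of train_test_split does, here is a small sketch on synthetic labels (not the MNIST data). The class sizes are made up for illustration. With stratification, each class keeps the same share in both splits.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic, deliberately imbalanced labels: 100 zeros, 50 ones, 50 twos
y = np.array([0] * 100 + [1] * 50 + [2] * 50)
X = np.arange(len(y)).reshape(-1, 1)  # placeholder features

# Same split settings as the main example
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Count how many samples of each class landed in each split
train_counts = np.bincount(y_train)
valid_counts = np.bincount(y_valid)

print(f"Train class counts: {train_counts}")
print(f"Validation class counts: {valid_counts}")
```

Without stratify=y, the class shares in the validation set would vary with the random seed, which matters more for rarer classes.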

Next, we define a parameter grid with common XGBoost hyperparameters, create an XGBClassifier with the multi:softprob objective for multiclass classification and the ‘hist’ tree method for faster training, and perform a grid search with 3-fold cross-validation to find the best parameters.
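A grid search fits one model per parameter combination per fold, so the cost grows quickly with the size of the grid. The sketch below counts the combinations for the grid used in this example with the standard library only. It should agree with the "Fitting 3 folds for each of ... candidates" line that GridSearchCV prints when verbose=1.

```python
from itertools import product
from math import prod

# Same grid as in the main example
param_grid = {
    'max_depth': [3, 5, 7, 19],
    'learning_rate': [0.1, 0.01],
    'n_estimators': [1000],
    'colsample_bytree': [0.8, 1],
    'reg_alpha': [0, 0.5, 1],
    'reg_lambda': [0, 0.001, 0.1, 0.5, 1],
    'subsample': [0.8, 1],
}
cv_folds = 3

# Number of candidates is the product of the list lengths
n_candidates = prod(len(values) for values in param_grid.values())

# Cross-check by enumerating every combination explicitly
combinations = list(product(*param_grid.values()))
assert len(combinations) == n_candidates

total_fits = n_candidates * cv_folds
print(f"Candidates: {n_candidates}, total fits: {total_fits}")
```

Each of these fits trains 1000 trees on tens of thousands of images, so expect a long run time. Trimming the grid is the simplest way to shorten it.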

After printing the best score and parameters, we access the best model, save it to disk, load the saved model, use it to make predictions on the validation set, and print the multiclass classification accuracy.

Running the complete example will output results similar to:

Dataset shape: (42000, 785)
Number of features: 784
Target variable: label
Class distributions: Counter({1: 4684, 7: 4401, 3: 4351, 9: 4188, 2: 4177, 6: 4137, 0: 4132, 4: 4072, 8: 4063, 5: 3795})
Fitting 3 folds for each of 480 candidates, totalling 1440 fits
Best score: 0.851
Best parameters: {'colsample_bytree': 0.8, 'learning_rate': 0.1, 'max_depth': 19, 'n_estimators': 1000, 'reg_alpha': 0.5, 'reg_lambda': 0.001, 'subsample': 0.8}
Accuracy: 0.835
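The accuracy on the last line is the fraction of validation images whose predicted label matches the true label. As a rough illustration of that calculation, here is a plain-Python sketch with made-up labels. The accuracy function below is a simplified stand-in for illustration, not scikit-learn's accuracy_score implementation.

```python
def accuracy(y_true, y_pred):
    """Return the fraction of positions where prediction equals truth."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("cannot compute accuracy of empty inputs")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


# Made-up labels: 6 of the 8 predictions match
y_true = [3, 1, 4, 1, 5, 9, 2, 6]
y_pred = [3, 1, 4, 7, 5, 9, 2, 0]

print(f"Accuracy: {accuracy(y_true, y_pred):.3f}")
```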

This example demonstrates the end-to-end process of using XGBoost on a Kaggle dataset, from downloading the data to making predictions with a tuned model.
