XGBoost provides built-in support for categorical features, so you do not have to one-hot or label encode them by hand.
By converting the categorical columns of your DataFrame to the pandas 'category' data type and setting the enable_categorical parameter to True when initializing the XGBoost model, you can skip the manual encoding step and keep your data preparation code shorter.
import pandas as pd
from xgboost import XGBClassifier

# Load data into a DataFrame
data = pd.DataFrame({
    'color': ['red', 'blue', 'green', 'red', 'blue'],
    'size': ['small', 'medium', 'large', 'medium', 'small'],
    'price': [10.0, 15.0, 20.0, 12.0, 8.0],
    'label': [0, 1, 1, 0, 0]
})

# Specify categorical columns
categorical_columns = ['color', 'size']

# Convert categorical columns to 'category' data type
for col in categorical_columns:
    data[col] = data[col].astype('category')

# Split into features and target
X = data.drop('label', axis=1)
y = data['label']

# Initialize and train the model
model = XGBClassifier(enable_categorical=True)
model.fit(X, y)

# New data for prediction
new_data = pd.DataFrame({
    'color': ['green', 'red'],
    'size': ['medium', 'large'],
    'price': [18.0, 25.0]
})

# Convert categorical columns to 'category', reusing the training dtype
# so each value keeps the same category set (and integer code) as in X
for col in categorical_columns:
    new_data[col] = new_data[col].astype(X[col].dtype)

# Make predictions
predictions = model.predict(new_data)
print("Predictions:", predictions)
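
The reason category consistency matters at prediction time can be seen with pandas alone. The sketch below uses illustrative variable names (train_color, new_color, aligned) that are not part of the example above. It shows that calling astype('category') separately on training and new data produces different integer codes for the same value, because pandas builds the category list from whatever values are present. Depending on the XGBoost version, the model may rely on those codes, so reusing the training categories is the safer habit.

```python
import pandas as pd

# Same feature, converted independently for training and for prediction
train_color = pd.Series(['red', 'blue', 'green', 'red', 'blue']).astype('category')
new_color = pd.Series(['green', 'red']).astype('category')

# Each Series gets its own sorted category list, so the codes disagree
train_categories = train_color.cat.categories.tolist()  # ['blue', 'green', 'red']
new_categories = new_color.cat.categories.tolist()      # ['green', 'red']
train_codes = train_color.cat.codes.tolist()            # [2, 0, 1, 2, 0]
new_codes = new_color.cat.codes.tolist()                # [0, 1]

# Reusing the training categories restores the training-time codes
aligned = pd.Series(
    pd.Categorical(['green', 'red'], categories=train_color.cat.categories)
)
aligned_codes = aligned.cat.codes.tolist()              # [1, 2]

print("independent:", new_codes, "aligned:", aligned_codes)
```

Here 'green' is code 1 in the training data but code 0 in the independently converted new data; after alignment it is code 1 again.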
To leverage XGBoost’s native categorical feature support:
1. Load your data into a pandas DataFrame.
2. Identify the columns that contain categorical data.
3. Convert these categorical columns to the 'category' data type using astype('category'). XGBoost uses this data type to recognize which columns are categorical.
4. Split your data into features (X) and the target variable (y).
5. Initialize your XGBoost model with enable_categorical=True. This flag tells XGBoost to handle categorical features natively.
6. Train your model using the fit() method, passing in your feature matrix (X) and target variable (y).
7. When making predictions on new data, ensure that the categorical columns in the new data are also converted to the 'category' data type, with the same set of categories as the training data.
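
One way to keep training and prediction data consistent is to record the categorical dtypes once and reapply them to every new DataFrame. The following is a minimal sketch under that assumption. The helper names capture_category_dtypes and apply_category_dtypes are hypothetical and are not part of pandas or XGBoost.

```python
import pandas as pd
from pandas.api.types import CategoricalDtype


def capture_category_dtypes(df):
    """Hypothetical helper: map each categorical column to its dtype."""
    return {
        col: dtype
        for col, dtype in df.dtypes.items()
        if isinstance(dtype, CategoricalDtype)
    }


def apply_category_dtypes(df, dtypes):
    """Hypothetical helper: cast columns to the recorded categorical dtypes.

    Values that were not seen in the recorded categories become NaN.
    """
    out = df.copy()
    for col, dtype in dtypes.items():
        out[col] = out[col].astype(dtype)
    return out


# Training frame with one categorical column
train = pd.DataFrame({
    'color': pd.Series(['red', 'blue', 'green']).astype('category'),
    'price': [10.0, 15.0, 20.0],
})
dtypes = capture_category_dtypes(train)

# New frame containing a value ('purple') that never appeared in training
new = pd.DataFrame({'color': ['green', 'purple'], 'price': [18.0, 25.0]})
aligned = apply_category_dtypes(new, dtypes)

print(aligned['color'].cat.categories.tolist())
print(aligned['color'].isna().tolist())
```

A value that was never seen during training, such as 'purple' here, becomes a missing value after the cast, which makes unexpected categories easy to detect before calling predict().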