Train an XGBoost Model on a CSV File

When your dataset is stored in a CSV file, you can easily load it using Pandas and then convert it to a DMatrix for training an XGBoost model.

Here’s how you can do it:

import pandas as pd
from xgboost import DMatrix, train

# content of data.csv:
# "A","B","C","target"
# 1,2,3,0
# 4,5,6,1
# 7,8,9,1

# Load data from CSV file
data = pd.read_csv('data.csv')

# Separate features and target
X = data.drop('target', axis=1)
y = data['target']

# Create DMatrix from X and y
dmatrix = DMatrix(data=X, label=y)

# Set XGBoost parameters
params = {
    'objective': 'binary:logistic',
    'learning_rate': 0.1,
    'random_state': 42
}

# Train the model; num_boost_round sets the number of trees
model = train(params, dmatrix, num_boost_round=100)

Here’s what’s happening:

  1. We use Pandas’ read_csv() function to load the data from a CSV file named 'data.csv' into a DataFrame called data. Pandas automatically infers the data types of each column.

  2. We separate the features and target from the data DataFrame. Here, we assume that the target variable is in a column named 'target'. We use drop() to select all columns except 'target' for our features X, and directly index the 'target' column for our target variable y.

  3. We create a DMatrix object called dmatrix from our features X and target y. This converts our Pandas DataFrame into the optimized data structure used by XGBoost.

  4. We set the XGBoost parameters in a dictionary called params. Here, we specify the objective function (binary:logistic for binary classification), the learning rate, and the random seed. The number of trees is not part of params in the native API. It is set by the num_boost_round argument of train(), which defaults to 10. These parameters can be tuned for your specific use case.

  5. We train the model by passing the params dictionary and dmatrix to the train() function. This function is part of XGBoost's native API and returns a trained Booster object that you can use for prediction.

By following these steps, you can quickly load your data from a CSV file, convert it to the appropriate format for XGBoost, and train your model.

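Check your CSV for missing values and data type issues before creating the DMatrix. XGBoost treats NaN as a missing value by default, so empty cells do not have to be filled, although you can impute them with fillna() if you prefer. Text (object dtype) columns, however, must be converted to numeric values (for example with astype()) or otherwise encoded, because DMatrix rejects them.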
