
XGBoost AFT model

Authors
Affiliations
M2 MIASHS - Université de Lyon

Imports

import numpy as np
import polars as pl

Loading the data

from sksurv.util import Surv

df_train = pl.read_parquet("../../data/df_study_L18_w6_train.parquet")
df_test = pl.read_parquet("../../data/df_study_L18_w6_test.parquet")

X_train = df_train.drop(["event", "time"]).to_pandas()
X_test = df_test.drop(["event", "time"]).to_pandas()

y_train = Surv.from_dataframe("event", "time", df_train.to_pandas())
y_test = Surv.from_dataframe("event", "time", df_test.to_pandas())

Training the model

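The model is configured in AFT mode (survival:aft) with a normal distribution. It therefore models $\log(T) = f(X) + \sigma Z$ with $Z \sim \mathcal{N}(0, 1)$, i.e. $\log(T) \sim \mathcal{N}(f(X), \sigma^2)$, where $f$ is the boosted tree ensemble and $\sigma$ is fixed by aft_loss_distribution_scale rather than estimated.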
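Censoring enters the model through the likelihood. Writing $\varphi$ and $\Phi$ for the standard normal density and CDF, the textbook AFT likelihood contribution of subject $i$ with recorded time $t_i$ is given below. It is stated here for reference; the boosting objective then minimizes $-\sum_i \log L_i$.

```math
L_i =
\begin{cases}
\dfrac{1}{\sigma t_i}\,\varphi(z_i) & \text{if the event is observed at } t_i,\\[6pt]
1 - \Phi(z_i) & \text{if subject } i \text{ is right-censored at } t_i,
\end{cases}
\qquad
z_i = \frac{\log t_i - f(x_i)}{\sigma}.
```

The factor $1/(\sigma t_i)$ is the Jacobian of the change of variables from $\log T$ to $T$. A censored subject only contributes the probability of surviving beyond $t_i$.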

import xgboost as xgb

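# Extract time and event from the sksurv structured array
lower = y_train["time"]
# Observed event: the survival time is known exactly, interval [t, t].
# Right-censored: only a lower bound is known, interval [t, +inf).
upper = np.where(y_train["event"], y_train["time"], np.inf)

# Create DMatrix with AFT bounds; the test matrix needs no labels because it is only used for prediction
dtrain = xgb.DMatrix(X_train, label_lower_bound=lower, label_upper_bound=upper)
dtest = xgb.DMatrix(X_test)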

params = {
    "objective": "survival:aft",
    "eval_metric": "aft-nloglik",  # only reported if evals=... is passed to xgb.train
    "aft_loss_distribution": "normal",
    "aft_loss_distribution_scale": 1.0,  # sigma in log(T) = f(X) + sigma * Z, held fixed
    "tree_method": "hist",
    "learning_rate": 0.05,
    "max_depth": 3,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
}

bst = xgb.train(params, dtrain, num_boost_round=300)
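The sketch below makes the interval bounds and the normal AFT loss concrete on hypothetical toy data. The structured array only imitates the one returned by Surv.from_dataframe, and the predicted log-times `f_hat` are made up. The helper `normal_aft_nll` is hypothetical and follows the standard AFT likelihood with $\sigma = 1.0$, as set by aft_loss_distribution_scale. Its mean should roughly correspond to what XGBoost reports as aft-nloglik, but this is an assumption that has not been checked against XGBoost's implementation.

```python
import math

import numpy as np

# Hypothetical toy data imitating the sksurv structured array (fields "event", "time")
y = np.array(
    [(True, 5.0), (False, 8.0), (True, 3.0), (False, 10.0)],
    dtype=[("event", "?"), ("time", "<f8")],
)

# Same interval encoding as in the training cell:
# observed event -> [t, t], right-censored -> [t, +inf)
lower = y["time"]
upper = np.where(y["event"], y["time"], np.inf)


def std_normal_cdf(z):
    # Phi(z); math.erf(inf) == 1.0, so z = +inf yields 1.0
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def normal_aft_nll(lo, up, f, sigma=1.0):
    """Negative log-likelihood of one interval [lo, up] under log(T) = f + sigma * Z.

    Hypothetical helper; assumes lo > 0 (the log of a zero time is undefined).
    """
    z_lo = (math.log(lo) - f) / sigma
    if lo == up:
        # Uncensored: density of T at t, including the 1 / (sigma * t) Jacobian
        pdf = math.exp(-0.5 * z_lo**2) / math.sqrt(2.0 * math.pi)
        return -math.log(pdf / (sigma * lo))
    # Censored: probability that T lies in [lo, up]; math.log(inf) == inf
    z_up = (math.log(up) - f) / sigma
    return -math.log(std_normal_cdf(z_up) - std_normal_cdf(z_lo))


# Hypothetical predicted log-times f(x_i), with sigma = 1.0
f_hat = np.log(np.array([4.0, 9.0, 3.0, 12.0]))
losses = [normal_aft_nll(lo, up, f) for lo, up, f in zip(lower, upper, f_hat)]
print("mean NLL:", np.mean(losses))
```

For a right-censored subject, raising the predicted log-time always lowers the loss. For an observed event, the loss is smallest when $f(x_i)$ is close to $\log t_i$.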

Evaluating the final model

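from utils import evaluate_survival_model  # project-local helper (not shown here)

# With survival:aft, predict() returns predicted survival times on the original time scale.
# Negate them so that a higher score means a higher risk (shorter expected survival).
risk_xgb = -bst.predict(dtest)

evaluate_survival_model(df_train, df_test, risk_xgb)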
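The metrics computed by `evaluate_survival_model` are not shown here. A common metric for a risk score such as `risk_xgb` is Harrell's concordance index. The minimal NumPy sketch below is a simplified version: it is not the helper's implementation and not sksurv's `concordance_index_censored`, and the function `harrell_c_index` and the toy arrays are hypothetical. It illustrates why the predicted times are negated before evaluation.

```python
import numpy as np


def harrell_c_index(event, time, risk):
    """Minimal O(n^2) Harrell's concordance index (sketch; ties in time are not handled)."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    concordant, comparable = 0.0, 0
    for i in range(len(time)):
        if not event[i]:
            continue  # a pair is comparable only if the shorter time is an observed event
        for j in range(len(time)):
            if time[j] > time[i]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0  # the shorter survival received the higher risk
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied risk scores count as half
    return concordant / comparable


# Hypothetical toy test set and predicted survival times
event = np.array([True, True, False, True])
time = np.array([2.0, 4.0, 6.0, 8.0])
pred_time = np.array([1.5, 3.0, 7.0, 9.0])

print(harrell_c_index(event, time, -pred_time))  # → 1.0 (risk = -time, fully concordant)
print(harrell_c_index(event, time, pred_time))   # → 0.0 (sign forgotten, fully discordant)
```

Without the minus sign, a model that ranks survival times perfectly would score 0.0 instead of 1.0.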