# Imports
import numpy as np
import polars as pl
from sksurv.util import Surv

# Data loading
# The survival target is stored in two columns: the event indicator and the
# observed time. Every other column is used as a model feature.
EVENT_COL = "event"
TIME_COL = "time"

df_train = pl.read_parquet("../../data/df_study_L18_w6_train.parquet")
df_test = pl.read_parquet("../../data/df_study_L18_w6_test.parquet")


def _split_features_target(df: pl.DataFrame):
    """Split a polars frame into pandas features and a sksurv structured target.

    Returns ``(X, y)`` where ``X`` is a pandas DataFrame of all non-target
    columns and ``y`` is the structured array built by ``Surv.from_dataframe``
    from the event and time columns.
    """
    # Convert to pandas once and reuse the result for both X and y.
    pdf = df.to_pandas()
    X = pdf.drop(columns=[EVENT_COL, TIME_COL])
    y = Surv.from_dataframe(EVENT_COL, TIME_COL, pdf)
    return X, y


X_train, y_train = _split_features_target(df_train)
X_test, y_test = _split_features_target(df_test)

# Model training
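Le modèle est configuré en mode AFT (`survival:aft`) avec une distribution normale, ce qui signifie qu'il modélise le logarithme du temps de survie comme $\ln T = f(\mathbf{x}) + \sigma Z$, où $f(\mathbf{x})$ est la sortie de l'ensemble d'arbres, $Z \sim \mathcal{N}(0, 1)$ et $\sigma$ le paramètre `aft_loss_distribution_scale` (ici $\sigma = 1$).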
import xgboost as xgb

# Build AFT interval labels from the sksurv structured array:
# - an observed event gives the degenerate interval [time, time];
# - a right-censored sample gives [time, +inf), since the true event time is
#   only known to be larger than the censoring time.
lower = y_train["time"]
upper = np.where(y_train["event"], y_train["time"], np.inf)

# Create DMatrix with AFT bounds; the test matrix carries no labels because it
# is only used for prediction below.
dtrain = xgb.DMatrix(X_train, label_lower_bound=lower, label_upper_bound=upper)
dtest = xgb.DMatrix(X_test)

params = {
    "objective": "survival:aft",
    "eval_metric": "aft-nloglik",
    "aft_loss_distribution": "normal",
    "aft_loss_distribution_scale": 1.0,
    "tree_method": "hist",
    "learning_rate": 0.05,
    "max_depth": 3,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    # Row/column subsampling is random; pin the seed so reruns are reproducible.
    "seed": 0,
}

# ``eval_metric`` is only computed on the matrices passed through ``evals``;
# without it the configured metric is never reported. Monitor the training
# AFT negative log-likelihood every 50 boosting rounds.
bst = xgb.train(
    params,
    dtrain,
    num_boost_round=300,
    evals=[(dtrain, "train")],
    verbose_eval=50,
)

# Final model evaluation
from utils import evaluate_survival_model

# The AFT booster predicts a survival time on the original time scale
# (larger = longer expected survival). Negate it so that larger values mean
# higher risk. NOTE(review): presumably this is the orientation
# evaluate_survival_model expects for its risk argument — confirm in utils.
risk_xgb = -bst.predict(dtest)
evaluate_survival_model(df_train, df_test, risk_xgb)