# Imports
import polars as pl
from sksurv.util import Surv

# Data loading
# The train/test splits are read from parquet files; the paths are relative
# to the directory the script/notebook is run from.
df_train = pl.read_parquet("../../data/df_study_L18_w6_train.parquet")
df_test = pl.read_parquet("../../data/df_study_L18_w6_test.parquet")

# Convert each split to pandas once and derive both the feature matrix and
# the target from that single copy. The polars frames are kept untouched
# because the evaluation step further down receives them directly.
pdf_train = df_train.to_pandas()
pdf_test = df_test.to_pandas()

# Features: every column except the two survival-target columns.
X_train = pdf_train.drop(columns=["event", "time"])
X_test = pdf_test.drop(columns=["event", "time"])

# Targets: built from the "event" and "time" columns with scikit-survival's
# Surv helper, in the form passed to the estimator's fit().
y_train = Surv.from_dataframe("event", "time", pdf_train)
y_test = Surv.from_dataframe("event", "time", pdf_test)

# Model training
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

# Gradient-boosted survival model trained with the "coxph" loss.
# Hyper-parameters are spelled out explicitly so the run is reproducible
# from this file alone.
# NOTE(review): these values look like the library defaults - confirm
# against the scikit-survival documentation before tuning.
gbs = GradientBoostingSurvivalAnalysis(
    loss="coxph",
    learning_rate=0.1,
    n_estimators=100,
)

# Fit on the training features and the survival target built above.
gbs.fit(X_train, y_train)
# Final model evaluation
from utils import evaluate_survival_model

# Two kinds of test-set predictions are handed to the evaluation helper:
# the output of predict() and the per-sample survival functions.
risk_gbs = gbs.predict(X_test)
surv_gbs = gbs.predict_survival_function(X_test)

# The helper receives the original polars frames (not the pandas copies)
# together with both prediction sets.
# NOTE(review): in the notebook export this call was followed by a
# "Loading..." output placeholder, so it presumably renders a result there;
# when run as a plain script nothing is displayed unless the helper prints.
evaluate_survival_model(df_train, df_test, risk_gbs, surv_gbs)