# Imports
import polars as pl
from sksurv.util import Surv

# Data loading
# Train/test splits are stored as separate parquet files; features are every
# column except the survival targets "event" and "time".
df_train = pl.read_parquet("../../data/df_study_L18_w6_train.parquet")
df_test = pl.read_parquet("../../data/df_study_L18_w6_test.parquet")
X_train = df_train.drop(["event", "time"]).to_pandas()
X_test = df_test.drop(["event", "time"]).to_pandas()
# Structured survival targets (event indicator + observed time) as expected by
# scikit-survival estimators.
y_train = Surv.from_dataframe("event", "time", df_train.to_pandas())
y_test = Surv.from_dataframe("event", "time", df_test.to_pandas())

# Model training
from sksurv.ensemble import RandomSurvivalForest

# Random survival forest. random_state is fixed so that bootstrap sampling and
# per-split feature subsampling (max_features="sqrt") are reproducible across
# notebook runs; without it every run yields a different forest and metrics.
rsf = RandomSurvivalForest(
    n_estimators=400,
    min_samples_split=10,
    min_samples_leaf=5,
    max_features="sqrt",
    n_jobs=-1,  # use all available cores
    random_state=42,
)
rsf.fit(X_train, y_train)
# Final model evaluation
from utils import evaluate_survival_model

# Per the scikit-survival docs, predict() returns a risk score per sample
# (higher = higher predicted risk), and predict_survival_function() returns one
# survival step function per sample.
risk_rsf = rsf.predict(X_test)
surv_rsf = rsf.predict_survival_function(X_test)
# NOTE(review): evaluate_survival_model is defined in the project-local utils
# module; presumably it reads "event"/"time" from the raw frames — confirm there.
evaluate_survival_model(df_train, df_test, risk_rsf, surv_rsf)