Random Survival Forest Model

Authors
Affiliations
M2 MIASHS - Université de Lyon

Imports

import polars as pl

Loading the data

from sksurv.util import Surv

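# Read the pre-split train and test sets.
df_train = pl.read_parquet("../../data/df_study_L18_w6_train.parquet")
df_test = pl.read_parquet("../../data/df_study_L18_w6_test.parquet")

# Features: every column except the survival outcome ("event", "time"),
# converted to pandas for scikit-survival.
X_train = df_train.drop(["event", "time"]).to_pandas()
X_test = df_test.drop(["event", "time"]).to_pandas()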

y_train = Surv.from_dataframe("event", "time", df_train.to_pandas())
y_test = Surv.from_dataframe("event", "time", df_test.to_pandas())
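
For readers unfamiliar with the target format: `Surv.from_dataframe` packs the outcome into a NumPy structured array with a boolean event field and a float time field. The sketch below rebuilds an array with the same layout by hand. The toy values and the `censoring_rate` check are illustrative assumptions, not values from the study data.

```python
import numpy as np

# Hypothetical toy outcome for 4 subjects (not taken from the study data).
event = [True, False, True, False]   # True = event observed, False = censored
time = [5.0, 8.0, 3.0, 10.0]         # follow-up time for each subject

# Structured array with a boolean field and a float field, the same
# layout scikit-survival uses for the survival target y.
y = np.empty(len(event), dtype=[("event", bool), ("time", float)])
y["event"] = event
y["time"] = time

# Quick sanity check that is often worth running on y_train and y_test:
# how many events were observed, and what share of subjects is censored.
n_events = int(y["event"].sum())
censoring_rate = 1.0 - n_events / len(y)
print(y.dtype.names, n_events, censoring_rate)
```

The same two fields can be read from `y_train` and `y_test` as `y_train["event"]` and `y_train["time"]`.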

Training the model

from sksurv.ensemble import RandomSurvivalForest

rsf = RandomSurvivalForest(
    n_estimators=400,
    min_samples_split=10,
    min_samples_leaf=5,
    max_features="sqrt",
    n_jobs=-1,
)

rsf.fit(X_train, y_train)

Evaluating the final model

from utils import evaluate_survival_model

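# Risk scores: a higher value means a higher predicted risk.
risk_rsf = rsf.predict(X_test)
# One survival step function per test row.
surv_rsf = rsf.predict_survival_function(X_test)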
evaluate_survival_model(df_train, df_test, risk_rsf, surv_rsf)
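
The body of `evaluate_survival_model` is not shown here, so the following is only a sketch of one metric commonly applied to risk scores such as `risk_rsf`: Harrell's concordance index. It is not presented as what the helper actually computes. The function name `harrell_c_index` and the toy data are hypothetical, and ties in time are ignored for simplicity.

```python
def harrell_c_index(event, time, risk):
    """Share of comparable pairs that the risk score orders correctly.

    A pair (i, j) is comparable when subject i had an observed event
    strictly before subject j's time. It is concordant when i also has
    the higher risk score. Tied risk scores count as half.
    """
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # censored subjects cannot anchor a comparable pair
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs in the data")
    return concordant / comparable


# Hypothetical toy data: 4 subjects, the third one is censored.
event = [True, True, False, True]
time = [2.0, 4.0, 5.0, 7.0]
risk = [0.9, 0.4, 0.6, 0.1]

c_index = harrell_c_index(event, time, risk)
print(c_index)
```

A value of 0.5 corresponds to random ordering and 1.0 to a perfect ranking. For real use, scikit-survival provides `sksurv.metrics.concordance_index_censored`, which also handles tied times.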