Skip to content

Commit b2e8c32

Browse files
Jupytercon tutorial WIP
Co-authored-by: Vincent M <maladiere.vincent@yahoo.fr>
1 parent 1599ea2 commit b2e8c32

14 files changed

+27585
-4
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,4 @@ dmypy.json
140140
*.7z
141141
*.csv
142142
*.parquet
143+
*.npz

model_selection/wrappers.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,31 @@ def predict_survival_function(self, X_test, times):
2222

2323
class PipelineWrapper(BaseWrapper):
2424

25-
def fit(self, X_train, y_train, times):
25+
def fit(self, X_train, y_train, times=None):
2626
last_est_name = self.estimator.steps[-1][0]
2727
times_kwargs = {f"{last_est_name}__times": times}
2828
self.estimator.fit(X_train, y_train, **times_kwargs)
29-
29+
3030
def predict_survival_function(self, X_test, times=None):
31-
return self.estimator.predict_survival_function(X_test)
31+
return self.estimator.predict_survival_function(X_test, times=times)
32+
33+
def predict_cumulative_incidence(self, X_test, times=None):
34+
transformers = self.estimator[:-1]
35+
X_test = transformers.transform(X_test)
36+
estimator = self.estimator[-1]
37+
return estimator.predict_cumulative_incidence(X_test, times=times)
38+
39+
def predict_quantile(self, X_test, quantile=0.5, times=None):
40+
transformers = self.estimator[:-1]
41+
X_test = transformers.transform(X_test)
42+
estimator = self.estimator[-1]
43+
return estimator.predict_quantile(X_test, quantile=quantile, times=times)
44+
45+
def predict_proba(self, X_test, time_horizon=None):
46+
transformers = self.estimator[:-1]
47+
X_test = transformers.transform(X_test)
48+
estimator = self.estimator[-1]
49+
return estimator.predict_proba(X_test, time_horizon=time_horizon)
3250

3351

3452
class SkurvWrapper(BaseWrapper):

0 commit comments

Comments
 (0)