@@ -22,13 +22,31 @@ def predict_survival_function(self, X_test, times):
22
22
23
23
class PipelineWrapper (BaseWrapper ):
24
24
25
- def fit (self , X_train , y_train , times ):
25
+ def fit (self , X_train , y_train , times = None ):
26
26
last_est_name = self .estimator .steps [- 1 ][0 ]
27
27
times_kwargs = {f"{ last_est_name } __times" : times }
28
28
self .estimator .fit (X_train , y_train , ** times_kwargs )
29
-
29
+
30
30
def predict_survival_function (self , X_test , times = None ):
31
- return self .estimator .predict_survival_function (X_test )
31
+ return self .estimator .predict_survival_function (X_test , times = times )
32
+
33
+ def predict_cumulative_incidence (self , X_test , times = None ):
34
+ transformers = self .estimator [:- 1 ]
35
+ X_test = transformers .transform (X_test )
36
+ estimator = self .estimator [- 1 ]
37
+ return estimator .predict_cumulative_incidence (X_test , times = times )
38
+
39
+ def predict_quantile (self , X_test , quantile = 0.5 , times = None ):
40
+ transformers = self .estimator [:- 1 ]
41
+ X_test = transformers .transform (X_test )
42
+ estimator = self .estimator [- 1 ]
43
+ return estimator .predict_quantile (X_test , quantile = quantile , times = times )
44
+
45
+ def predict_proba (self , X_test , time_horizon = None ):
46
+ transformers = self .estimator [:- 1 ]
47
+ X_test = transformers .transform (X_test )
48
+ estimator = self .estimator [- 1 ]
49
+ return estimator .predict_proba (X_test , time_horizon = time_horizon )
32
50
33
51
34
52
class SkurvWrapper (BaseWrapper ):
0 commit comments