Skip to content

Commit 2e158fd

Browse files
Merge pull request #265 from tidymodels/naming-insensitivity
2 parents a1639e1 + b45cc16 commit 2e158fd

14 files changed

+397
-333
lines changed

tests/testthat/helper-objects.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,10 @@ spark_not_installed <- function() {
2828
}
2929
need_install
3030
}
31+
32+
# ------------------------------------------------------------------------------
33+
34+
expect_ptype <- function(x, ptype) {
35+
expect_equal(x[0, names(ptype)], ptype)
36+
}
37+

tests/testthat/test-survival-augment.R

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,17 @@ test_that("augmenting survival models", {
3232

3333
sr_aug <- augment(sr_fit, new_data = sim_tr, eval_time = time_points)
3434
expect_equal(nrow(sr_aug), nrow(sim_tr))
35-
expect_equal(names(sr_aug), c(".pred", ".pred_time", "event_time", "X1", "X2"))
35+
expect_named(
36+
sr_aug,
37+
c(".pred", ".pred_time", "event_time", "X1", "X2"),
38+
ignore.order = TRUE
39+
)
3640
expect_true(is.list(sr_aug$.pred))
37-
expect_equal(
38-
names(sr_aug$.pred[[1]]),
41+
expect_named(
42+
sr_aug$.pred[[1]],
3943
c(".eval_time", ".pred_survival", ".weight_time", ".pred_censored",
40-
".weight_censored")
44+
".weight_censored"),
45+
ignore.order = TRUE
4146
)
4247

4348
# proportional_hazards() -----------------------------------------------------
@@ -49,12 +54,17 @@ test_that("augmenting survival models", {
4954

5055
glmn_aug <- augment(glmn_fit, new_data = sim_tr, eval_time = time_points)
5156
expect_equal(nrow(glmn_aug), nrow(sim_tr))
52-
expect_equal(names(glmn_aug), c(".pred", ".pred_time", "event_time", "X1", "X2"))
57+
expect_named(
58+
glmn_aug,
59+
c(".pred", ".pred_time", "event_time", "X1", "X2"),
60+
ignore.order = TRUE
61+
)
5362
expect_true(is.list(glmn_aug$.pred))
54-
expect_equal(
55-
names(glmn_aug$.pred[[1]]),
63+
expect_named(
64+
glmn_aug$.pred[[1]],
5665
c(".eval_time", ".pred_survival", ".weight_time", ".pred_censored",
57-
".weight_censored")
66+
".weight_censored"),
67+
ignore.order = TRUE
5868
)
5969
})
6070

@@ -118,11 +128,16 @@ test_that("augment() works for tune_results", {
118128
)
119129

120130
expect_equal(nrow(aug_res), nrow(sim_tr))
121-
expect_equal(names(aug_res), c(".pred", ".pred_time", "event_time", "X1", "X2"))
131+
expect_named(
132+
aug_res,
133+
c(".pred", ".pred_time", "event_time", "X1", "X2"),
134+
ignore.order = TRUE
135+
)
122136
expect_true(is.list(aug_res$.pred))
123-
expect_equal(
124-
names(aug_res$.pred[[1]]),
125-
c(".eval_time", ".pred_survival", ".weight_censored")
137+
expect_named(
138+
aug_res$.pred[[1]],
139+
c(".eval_time", ".pred_survival", ".weight_censored"),
140+
ignore.order = TRUE
126141
)
127142

128143
expect_no_warning(
@@ -172,11 +187,16 @@ test_that("augment() works for resample_results", {
172187
aug_res <- augment(rs_mixed_res)
173188

174189
expect_equal(nrow(aug_res), nrow(sim_tr))
175-
expect_equal(names(aug_res), c(".pred", ".pred_time", "event_time", "X1", "X2"))
190+
expect_named(
191+
aug_res,
192+
c(".pred", ".pred_time", "event_time", "X1", "X2"),
193+
ignore.order = TRUE
194+
)
176195
expect_true(is.list(aug_res$.pred))
177-
expect_equal(
178-
names(aug_res$.pred[[1]]),
179-
c(".eval_time", ".pred_survival", ".weight_censored")
196+
expect_named(
197+
aug_res$.pred[[1]],
198+
c(".eval_time", ".pred_survival", ".weight_censored"),
199+
ignore.order = TRUE
180200
)
181201
})
182202

@@ -215,10 +235,15 @@ test_that("augment() works for last fit", {
215235
aug_res <- augment(rs_mixed_res)
216236

217237
expect_equal(nrow(aug_res), nrow(sim_te))
218-
expect_equal(names(aug_res), c(".pred", ".pred_time", "event_time", "X1", "X2"))
238+
expect_named(
239+
aug_res,
240+
c(".pred", ".pred_time", "event_time", "X1", "X2"),
241+
ignore.order = TRUE
242+
)
219243
expect_true(is.list(aug_res$.pred))
220-
expect_equal(
221-
names(aug_res$.pred[[1]]),
222-
c(".eval_time", ".pred_survival", ".weight_censored")
244+
expect_named(
245+
aug_res$.pred[[1]],
246+
c(".eval_time", ".pred_survival", ".weight_censored"),
247+
ignore.order = TRUE
223248
)
224249
})

tests/testthat/test-survival-fit-resamples.R

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ test_that("resampling survival models with static metric", {
6363
# test structure of results --------------------------------------------------
6464

6565
expect_false(".eval_time" %in% names(rs_static_res$.metrics[[1]]))
66-
expect_equal(
67-
names(rs_static_res$.predictions[[1]]),
68-
c(".pred_time", ".row", "event_time", ".config")
66+
expect_named(
67+
rs_static_res$.predictions[[1]],
68+
c(".pred_time", ".row", "event_time", ".config"),
69+
ignore.order = TRUE
6970
)
7071

7172
# test metric collection -----------------------------------------------------
@@ -82,7 +83,7 @@ test_that("resampling survival models with static metric", {
8283
)
8384

8485
expect_true(nrow(metric_sum) == 1)
85-
expect_equal(metric_sum[0,], exp_metric_sum)
86+
expect_ptype(metric_sum, exp_metric_sum)
8687
expect_true(all(metric_sum$.metric == "concordance_survival"))
8788

8889
metric_all <- collect_metrics(rs_static_res, summarize = FALSE)
@@ -96,7 +97,8 @@ test_that("resampling survival models with static metric", {
9697
)
9798

9899
expect_true(nrow(metric_all) == 10)
99-
expect_equal(metric_all[0,], exp_metric_all)
100+
expect_ptype(metric_all, exp_metric_all)
101+
100102
expect_true(all(metric_all$.metric == "concordance_survival"))
101103

102104
# test prediction collection -------------------------------------------------
@@ -110,11 +112,11 @@ test_that("resampling survival models with static metric", {
110112
)
111113

112114
unsum_pred <- collect_predictions(rs_static_res)
113-
expect_equal(unsum_pred[0,], static_ptype)
115+
expect_ptype(unsum_pred, static_ptype)
114116
expect_equal(nrow(unsum_pred), nrow(sim_tr))
115117

116118
sum_pred <- collect_predictions(rs_static_res, summarize = TRUE)
117-
expect_equal(sum_pred[0,], static_ptype[, names(static_ptype) != "id"])
119+
expect_ptype(sum_pred, static_ptype[, names(static_ptype) != "id"])
118120
expect_equal(nrow(sum_pred), nrow(sim_tr))
119121

120122
})
@@ -161,14 +163,17 @@ test_that("resampling survival models with integrated metric", {
161163
# test structure of results --------------------------------------------------
162164

163165
expect_false(".eval_time" %in% names(rs_integrated_res$.metrics[[1]]))
164-
expect_equal(
165-
names(rs_integrated_res$.predictions[[1]]),
166-
c(".pred", ".row", "event_time", ".config")
166+
expect_named(
167+
rs_integrated_res$.predictions[[1]],
168+
c(".pred", ".row", "event_time", ".config"),
169+
ignore.order = TRUE
167170
)
171+
168172
expect_true(is.list(rs_integrated_res$.predictions[[1]]$.pred))
169-
expect_equal(
170-
names(rs_integrated_res$.predictions[[1]]$.pred[[1]]),
171-
c(".eval_time", ".pred_survival", ".weight_censored")
173+
expect_named(
174+
rs_integrated_res$.predictions[[1]]$.pred[[1]],
175+
c(".eval_time", ".pred_survival", ".weight_censored"),
176+
ignore.order = TRUE
172177
)
173178
expect_equal(
174179
rs_integrated_res$.predictions[[1]]$.pred[[1]]$.eval_time,
@@ -189,7 +194,7 @@ test_that("resampling survival models with integrated metric", {
189194
)
190195

191196
expect_true(nrow(metric_sum) == 1)
192-
expect_equal(metric_sum[0,], exp_metric_sum)
197+
expect_ptype(metric_sum, exp_metric_sum)
193198
expect_true(all(metric_sum$.metric == "brier_survival_integrated"))
194199

195200
metric_all <- collect_metrics(rs_integrated_res, summarize = FALSE)
@@ -203,7 +208,7 @@ test_that("resampling survival models with integrated metric", {
203208
)
204209

205210
expect_true(nrow(metric_all) == 10)
206-
expect_equal(metric_all[0,], exp_metric_all)
211+
expect_ptype(metric_all, exp_metric_all)
207212
expect_true(all(metric_all$.metric == "brier_survival_integrated"))
208213

209214
# test prediction collection -------------------------------------------------
@@ -224,17 +229,17 @@ test_that("resampling survival models with integrated metric", {
224229
)
225230

226231
unsum_pred <- collect_predictions(rs_integrated_res)
227-
expect_equal(unsum_pred[0,], integrated_ptype)
232+
expect_ptype(unsum_pred, integrated_ptype)
228233
expect_equal(nrow(unsum_pred), nrow(sim_tr))
229234

230-
expect_equal(unsum_pred$.pred[[1]][0,], integrated_list_ptype)
235+
expect_ptype(unsum_pred$.pred[[1]], integrated_list_ptype)
231236
expect_equal(nrow(unsum_pred$.pred[[1]]), length(time_points))
232237

233238
sum_pred <- collect_predictions(rs_integrated_res, summarize = TRUE)
234-
expect_equal(sum_pred[0,], integrated_ptype[, names(integrated_ptype) != "id"])
239+
expect_ptype(sum_pred, integrated_ptype[, names(integrated_ptype) != "id"])
235240
expect_equal(nrow(sum_pred), nrow(sim_tr))
236241

237-
expect_equal(sum_pred$.pred[[1]][0,], integrated_list_ptype)
242+
expect_ptype(sum_pred$.pred[[1]], integrated_list_ptype)
238243
expect_equal(nrow(sum_pred$.pred[[1]]), length(time_points))
239244

240245
})
@@ -281,14 +286,18 @@ test_that("resampling survival models with dynamic metric", {
281286
# test structure of results --------------------------------------------------
282287

283288
expect_true(".eval_time" %in% names(rs_dynamic_res$.metrics[[1]]))
284-
expect_equal(
285-
names(rs_dynamic_res$.predictions[[1]]),
286-
c(".pred", ".row", "event_time", ".config")
289+
290+
expect_named(
291+
rs_dynamic_res$.predictions[[1]],
292+
c(".pred", ".row", "event_time", ".config"),
293+
ignore.order = TRUE
287294
)
288295
expect_true(is.list(rs_dynamic_res$.predictions[[1]]$.pred))
289-
expect_equal(
290-
names(rs_dynamic_res$.predictions[[1]]$.pred[[1]]),
291-
c(".eval_time", ".pred_survival", ".weight_censored")
296+
297+
expect_named(
298+
rs_dynamic_res$.predictions[[1]]$.pred[[1]],
299+
c(".eval_time", ".pred_survival", ".weight_censored"),
300+
ignore.order = TRUE
292301
)
293302
expect_equal(
294303
rs_dynamic_res$.predictions[[1]]$.pred[[1]]$.eval_time,
@@ -310,7 +319,7 @@ test_that("resampling survival models with dynamic metric", {
310319
)
311320

312321
expect_true(nrow(metric_sum) == length(time_points))
313-
expect_equal(metric_sum[0,], exp_metric_sum)
322+
expect_ptype(metric_sum, exp_metric_sum)
314323
expect_true(all(metric_sum$.metric == "brier_survival"))
315324

316325
metric_all <- collect_metrics(rs_dynamic_res, summarize = FALSE)
@@ -325,7 +334,7 @@ test_that("resampling survival models with dynamic metric", {
325334
)
326335

327336
expect_true(nrow(metric_all) == length(time_points) * nrow(sim_rs))
328-
expect_equal(metric_all[0,], exp_metric_all)
337+
expect_ptype(metric_all, exp_metric_all)
329338
expect_true(all(metric_all$.metric == "brier_survival"))
330339

331340
# test prediction collection -------------------------------------------------
@@ -346,17 +355,17 @@ test_that("resampling survival models with dynamic metric", {
346355
)
347356

348357
unsum_pred <- collect_predictions(rs_dynamic_res)
349-
expect_equal(unsum_pred[0,], dynamic_ptype)
358+
expect_ptype(unsum_pred, dynamic_ptype)
350359
expect_equal(nrow(unsum_pred), nrow(sim_tr))
351360

352-
expect_equal(unsum_pred$.pred[[1]][0,], dynamic_list_ptype)
361+
expect_ptype(unsum_pred$.pred[[1]], dynamic_list_ptype)
353362
expect_equal(nrow(unsum_pred$.pred[[1]]), length(time_points))
354363

355364
sum_pred <- collect_predictions(rs_dynamic_res, summarize = TRUE)
356-
expect_equal(sum_pred[0,], dynamic_ptype[, names(dynamic_ptype) != "id"])
365+
expect_ptype(sum_pred, dynamic_ptype[, names(dynamic_ptype) != "id"])
357366
expect_equal(nrow(sum_pred), nrow(sim_tr))
358367

359-
expect_equal(sum_pred$.pred[[1]][0,], dynamic_list_ptype)
368+
expect_ptype(sum_pred$.pred[[1]], dynamic_list_ptype)
360369
expect_equal(nrow(sum_pred$.pred[[1]]), length(time_points))
361370

362371
})
@@ -404,14 +413,16 @@ test_that("resampling survival models mixture of metric types", {
404413
# test structure of results --------------------------------------------------
405414

406415
expect_true(".eval_time" %in% names(rs_mixed_res$.metrics[[1]]))
407-
expect_equal(
408-
names(rs_mixed_res$.predictions[[1]]),
409-
c(".pred", ".row", ".pred_time", "event_time", ".config")
416+
expect_named(
417+
rs_mixed_res$.predictions[[1]],
418+
c(".pred", ".row", ".pred_time", "event_time", ".config"),
419+
ignore.order = TRUE
410420
)
411421
expect_true(is.list(rs_mixed_res$.predictions[[1]]$.pred))
412-
expect_equal(
413-
names(rs_mixed_res$.predictions[[1]]$.pred[[1]]),
414-
c(".eval_time", ".pred_survival", ".weight_censored")
422+
expect_named(
423+
rs_mixed_res$.predictions[[1]]$.pred[[1]],
424+
c(".eval_time", ".pred_survival", ".weight_censored"),
425+
ignore.order = TRUE
415426
)
416427
expect_equal(
417428
rs_mixed_res$.predictions[[1]]$.pred[[1]]$.eval_time,
@@ -433,7 +444,7 @@ test_that("resampling survival models mixture of metric types", {
433444
)
434445

435446
expect_true(nrow(metric_sum) == length(time_points) + 2)
436-
expect_equal(metric_sum[0,], exp_metric_sum)
447+
expect_ptype(metric_sum, exp_metric_sum)
437448
expect_true(sum(is.na(metric_sum$.eval_time)) == 2)
438449
expect_equal(as.vector(table(metric_sum$.metric)), c(length(time_points), 1L, 1L))
439450

@@ -449,7 +460,7 @@ test_that("resampling survival models mixture of metric types", {
449460
)
450461

451462
expect_true(nrow(metric_all) == (length(time_points) + 2) * nrow(sim_rs))
452-
expect_equal(metric_all[0,], exp_metric_all)
463+
expect_ptype(metric_all, exp_metric_all)
453464
expect_true(sum(is.na(metric_all$.eval_time)) == 2* nrow(sim_rs))
454465
expect_equal(as.vector(table(metric_all$.metric)), c(length(time_points), 1L, 1L) * nrow(sim_rs))
455466

@@ -472,17 +483,17 @@ test_that("resampling survival models mixture of metric types", {
472483
)
473484

474485
unsum_pred <- collect_predictions(rs_mixed_res)
475-
expect_equal(unsum_pred[0,], mixed_ptype)
486+
expect_ptype(unsum_pred, mixed_ptype)
476487
expect_equal(nrow(unsum_pred), nrow(sim_tr))
477488

478-
expect_equal(unsum_pred$.pred[[1]][0,], mixed_list_ptype)
489+
expect_ptype(unsum_pred$.pred[[1]], mixed_list_ptype)
479490
expect_equal(nrow(unsum_pred$.pred[[1]]), length(time_points))
480491

481492
sum_pred <- collect_predictions(rs_mixed_res, summarize = TRUE)
482-
expect_equal(sum_pred[0,], mixed_ptype[, names(mixed_ptype) != "id"])
493+
expect_ptype(sum_pred, mixed_ptype[, names(mixed_ptype) != "id"])
483494
expect_equal(nrow(sum_pred), nrow(sim_tr))
484495

485-
expect_equal(sum_pred$.pred[[1]][0,], mixed_list_ptype)
496+
expect_ptype(sum_pred$.pred[[1]], mixed_list_ptype)
486497
expect_equal(nrow(sum_pred$.pred[[1]]), length(time_points))
487498

488499
# test show_best() -----------------------------------------------------------

0 commit comments

Comments
 (0)