---
title: "Visualization and interpretation of individualized treatment rule results"
authors:
  - name: Xiaotong Jiang
    orcid: 0000-0003-3698-4526
    affiliations:
      - ref: biogen
affiliations:
  - id: smartdas
    name: Smart Data Analysis and Statistics B.V.
    city: Utrecht
  - id: biogen
    name: Biogen
    city: Cambridge, MA, US
format:
  html:
    toc: true
    number-sections: true
execute:
  cache: true
bibliography: 'https://api.citedrive.com/bib/0d25b38b-db8f-43c4-b934-f4e2f3bd655a/references.bib?x=eyJpZCI6ICIwZDI1YjM4Yi1kYjhmLTQzYzQtYjkzNC1mNGUyZjNiZDY1NWEiLCAidXNlciI6ICIyNTA2IiwgInNpZ25hdHVyZSI6ICI0MGFkYjZhMzYyYWE5Y2U0MjQ2NWE2ZTQzNjlhMWY3NTk5MzhhNzUxZDNjYWIxNDlmYjM4NDgwOTYzMzY5YzFlIn0=/bibliography.bib'
---
In this tutorial, we walk you through the code that implements the precision medicine methods and generates the visualization results discussed in Chapter 18 of the book. The focus is on helping you understand the code; we do not provide a detailed interpretation of the results, as this is already covered in the chapter.
## Introduction
We first load all relevant functions for this chapter.
```{r constant}
source("resources/chapter 18/functions.r")
```
```{r setup}
#| echo: false
#| include: false
#| warning: false
library(tidyverse)
library(magrittr)
library(DTRreg)
library(precmed)
library(rpart)
library(rpart.plot)
library(tableone)
library(gbm)
library(caret)
library(corrplot)
library(MASS)
library(ggpubr)
library(knitr)
library(kableExtra)
library(table1)
```
Subsequently, we use the function `simcountdata()` to generate an example dataset with a sample size of N=2000. In this example, we have two disease-modifying therapies (DMT1 and DMT0), and the outcome is the number of post-treatment multiple sclerosis relapses during follow-up.
```{r readdata}
# Randomization seed
base.seed <- 999
set.seed(base.seed)
df.ori <- simcountdata(n = 2000,
seed = 63,
beta = c(log(0.4), log(0.5), log(1), log(1.1), log(1.2)),
beta.x = c(-1.54, -0.01, 0.06, 0.25, 0.5, 0.13, 0.0000003)
)$data
```
The dataset looks as follows:
```{r}
head(df.ori)
```
Below is a summary table of the baseline characteristics by treatment group.
```{r}
#| echo: false
#| warning: false
df.tbl <- df.ori %>% mutate(age = ageatindex_centered + 48,
gender = ifelse(female == 1, "female", "male"),
prerelapse_numf = as.factor(prerelapse_num),
numSymptoms = factor(numSymptoms, labels = c("0", "1", ">=2"),
levels = c("0", "1" , ">=2")))
label(df.tbl$age) <- "Age"
units(df.tbl$age) <- "years"
label(df.tbl$gender) <- "Gender"
label(df.tbl$prerelapse_numf) <- "Previous number of relapses"
label(df.tbl$prevDMTefficacy) <- "Efficacy of previous disease modifying therapy"
label(df.tbl$premedicalcost) <- "Previous medical cost"
units(df.tbl$premedicalcost) <- "\\$"
label(df.tbl$numSymptoms) <- "Previous number of symptoms"
table1(~ age + gender + prerelapse_numf + prevDMTefficacy + premedicalcost + numSymptoms | trt, data = df.tbl, caption = "Baseline characteristics of the case study data")
```
We now define key constants for the case study.
```{r}
# Baseline characteristics
covars <- c("age.z", "female", "prevtrtB", "prevtrtC", "prevnumsymp1",
"prevnumsymp2p", "previous_cost.z", "previous_number_relapses")
# Precision medicine methods to be used
pm.methods <- c("all1", "all0", "poisson", "dWOLS", "listDTR2",
"contrastReg")
# Precision medicine method labels
method.vec <- c("All 0", "All 1", "Poisson", "dWOLS",
"Contrast\n Regression", "List DTR\n (2 branches)")
# Number of folds in each CV iteration
n.fold <- 5
# Number of CV iterations
n.cv <- 10
# Sample size of the large independent test set to get true value
big.n <- 100000
# Define formula for the CATE model
cate.formula <- as.formula(paste0("y ~", paste0(covars, collapse = "+"),
"+ offset(log(years))"))
# Define formula for the propensity score model
ps.formula <- trt ~ age.z + prevtrtB + prevtrtC
# Color
myblue <- rgb(37, 15, 186, maxColorValue = 255)
mygreen <- rgb(109, 173, 70, maxColorValue = 255)
mygrey <- rgb(124, 135, 142, maxColorValue = 255)
```
The data need to be preprocessed to be more analyzable. We recategorized treatment, previous treatment, and number of symptoms; scaled medical cost and age; and standardized the data.
```{r preprocess, eval = T, echo = T}
df <- df.ori %>%
rename(previous_treatment = prevDMTefficacy,
age = ageatindex_centered,
y = postrelapse_num,
previous_number_relapses = prerelapse_num,
previous_number_symptoms = numSymptoms,
previous_cost = premedicalcost) %>%
mutate(previous_treatment = factor(previous_treatment,
levels = c("None",
"Low efficacy",
"Medium and high efficacy"),
labels = c("drugA", "drugB", "drugC")),
previous_number_symptoms = factor(previous_number_symptoms,
levels = c("0", "1", ">=2"),
labels = c("0", "1", ">=2")),
trt = factor(trt, levels = c(0, 1), labels = c("drug0", "drug1")),
previous_cost.z = scale(log(previous_cost), scale = TRUE),
age.z = age + 48,
age.z = scale(age.z, scale = TRUE),
years = finalpostdayscount / 365.25,
mlogarr0001 = -log(y / years + 0.001),
drug1 = as.numeric(trt == "drug1"),
prevtrtB = as.numeric(previous_treatment == "drugB"),
prevtrtC = as.numeric(previous_treatment == "drugC"),
prevnumsymp1 = as.numeric(previous_number_symptoms == "1"),
prevnumsymp2p = as.numeric(previous_number_symptoms == ">=2")) %>%
dplyr::select(age.z, female, contains("prevtrt"), previous_cost.z,
contains("prevnumsymp"),
previous_number_relapses, trt, drug1, y,
mlogarr0001, years, Iscore)
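# Data set used for the models on the scaled data: age.z and previous_cost.z
# were standardized above; the remaining covariates are copied unchanged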
z.labels <- c("age.z", "previous_cost.z")
df.s <- df
df.s[, setdiff(covars, z.labels)] <- df[, setdiff(covars, z.labels)]
```
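As a minimal illustration of the scaling step, the sketch below log-transforms and standardizes a made-up cost vector with `scale()`. The values and the names `toy.cost` and `toy.cost.z` are hypothetical and are not part of the case study data.

```r
# Hypothetical cost values (not from the case study data)
toy.cost <- c(1200, 3500, 800, 15000, 4200)

# Log-transform, then center and scale to mean 0 and standard deviation 1.
# scale() returns a one-column matrix, so we convert it back to a vector.
toy.cost.z <- as.numeric(scale(log(toy.cost), scale = TRUE))

round(mean(toy.cost.z), 10)  # centered at 0
sd(toy.cost.z)               # unit standard deviation
```

Because `scale()` returns a matrix, columns created this way may need `as.numeric()` before they are passed to functions that expect plain numeric vectors.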
## Estimation of individualized treatment rules
The following code provides details of how to implement the precision medicine methods in the example data. Please feel free to jump to the next section if you want to focus on the results. The model results are available online for you to load and save time.
We used the function `listdtr()` in the **listdtr** package to estimate individualized treatment rules (ITRs) based on the listDTR method, and the function `catefit()` in the **precmed** package to estimate ITRs based on the Poisson and contrast regression methods. These are the methods used in Section 3 of the book, where we discuss directly visualizing the ITR before bringing in the outcomes. The methods are discussed in further detail by @zhao_effectively_2013 and @yadlowsky_estimation_2021.
```{r apply.s3}
#| eval: false
#| echo: true
library(listdtr)
# Estimated ITR based on the listDTR method with 2 branches
modlist2 <- listdtr(y = df$mlogarr0001, # larger is more favorable
a = df$drug1,
x = df[, c("age.z", "female", "prevtrtB",
"prevtrtC", "previous_cost.z",
"prevnumsymp1", "prevnumsymp2p",
"previous_number_relapses")],
stage.x = rep(1, 8), maxlen = 2L) # somewhat slow
# Estimated ITR based on the listDTR method with 3 branches
modlist3 <- listdtr(y = df$mlogarr0001,
a = df$drug1,
x = df[, c("age.z", "female", "prevtrtB",
"prevtrtC", "previous_cost.z",
"prevnumsymp1", "prevnumsymp2p",
"previous_number_relapses")],
stage.x = rep(1, 8), maxlen = 3L) # somewhat slow
# Estimated CATE score based on the Poisson and contrast regression
modpm <- catefit(response = "count",
cate.model = cate.formula,
ps.model = ps.formula,
data = df,
higher.y = FALSE,
score.method = c("poisson", "contrastReg"),
initial.predictor.method = "poisson",
seed = 999)
# Estimated CATE score based on the Poisson and contrast regression
# (based on the scaled data so the coefficients are easier to compare)
modpm.s <- catefit(response = "count",
cate.model = cate.formula,
ps.model = ps.formula,
data = df.s,
higher.y = FALSE,
score.method = c("poisson", "contrastReg"),
initial.predictor.method = "poisson",
seed = 999)
```
For results in Sections 4 and 5, we applied cross validation to mitigate over-fitting. For this chapter, we created our own customized function `cvvalue()` to estimate the ITR and calculate the estimated value function via cross validation for all methods, including the fixed rules that assign every patient to drug 0 (`all0`) or to drug 1 (`all1`). The results were all saved under the prefix `cvmod`. The **precmed** package has a built-in cross validation procedure for CATE estimation, so we used its function `catecv()`.
```{r apply.s4s5}
#| eval: false
#| echo: true
# Run cross validation for each method (used for Sections 4 & 5)
## Estimated CATE scores based on the Poisson and contrast regression with cross-validation
modcv <- catecv(response = "count",
cate.model = cate.formula,
ps.model = ps.formula,
data = df,
higher.y = FALSE,
score.method = c("poisson", "contrastReg"),
initial.predictor.method = "poisson",
cv.n = n.cv,
plot.gbmperf = FALSE,
seed = 999) # somewhat slow
## Estimated value function for each method
cvmodall0 <- cvvalue(data = df, xvar = covars,
method = "all0", n.fold = n.fold, n.cv = n.cv,
seed = base.seed)
cvmodall1 <- cvvalue(data = df, xvar = covars,
method = "all1", n.fold = n.fold, n.cv = n.cv,
seed = base.seed)
cvmoddwols <- cvvalue(data = df, xvar = covars,
method = "dWOLS", n.fold = n.fold, n.cv = n.cv,
seed = base.seed)
cvmodpois <- cvvalue(data = df, xvar = covars,
method = "poisson", n.fold = n.fold, n.cv = n.cv,
seed = base.seed)
cvmodlist2 <- cvvalue(data = df, xvar = covars,
method = "listDTR2", n.fold = n.fold, n.cv = n.cv,
seed = base.seed) # very slow
cvmodcontrastreg <- cvvalue(data = df, xvar = covars,
method = "contrastReg", n.fold = n.fold,
n.cv = n.cv,
seed = base.seed) # very slow
```
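The repeated cross validation above relies on reproducible fold assignments. The following is a minimal base R sketch of that idea, not the implementation inside `cvvalue()`: it uses `sample()` instead of `caret::createFolds()`, and all `toy.*` names and sizes are hypothetical. The seed rule (`base.seed * 100` plus the iteration number) mirrors the one used in the accuracy chunk later in this tutorial.

```r
# Hypothetical sizes, much smaller than in the case study
toy.n         <- 20   # number of patients
toy.n.fold    <- 5    # folds per CV iteration
toy.n.cv      <- 2    # number of CV iterations
toy.base.seed <- 999

toy.folds <- lapply(seq_len(toy.n.cv), function(i) {
  # One seed per CV iteration, so the same split can be regenerated later
  set.seed(toy.base.seed * 100 + i)
  # Shuffle balanced fold labels to assign each patient to a fold
  sample(rep(seq_len(toy.n.fold), length.out = toy.n))
})

# Fold sizes per CV iteration (rows = folds, columns = iterations)
sapply(toy.folds, table)
```

Re-running the same seed reproduces the same split, which is what allows the true optimal ITR to be matched to the estimated ITR fold by fold.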
```{r load}
#| echo: false
modlist2 <- readRDS("resources/chapter 18/modlist2.RData")
modlist3 <- readRDS("resources/chapter 18/modlist3.RData")
modpm <- readRDS("resources/chapter 18/modpm.RData")
modpm.s <- readRDS("resources/chapter 18/modpm.s.RData")
modcv <- readRDS("resources/chapter 18/modcv.RData")
cvmoddwols <- readRDS("resources/chapter 18/cvmoddwols.RData")
cvmodpois <- readRDS("resources/chapter 18/cvmodpois.RData")
cvmodall1 <- readRDS("resources/chapter 18/cvmodall1.RData")
cvmodall0 <- readRDS("resources/chapter 18/cvmodall0.RData")
cvmodlist2 <- readRDS("resources/chapter 18/cvmodlist2.RData")
cvmodcontrastreg <- readRDS("resources/chapter 18/cvmodcontrastreg.RData")
```
As a next step, we need to combine all estimated ITRs and value functions:
```{r combine}
# Combine CV results
# Read in each CV result in a loop
vhats.dhat <- dhats <- NULL
mod_names <- c("cvmodall1", "cvmodall0", "cvmoddwols", "cvmodpois", "cvmodcontrastreg", "cvmodlist2")
for (mod in mod_names) {
thismod <- get(mod)
for (name in names(thismod)) {
# Get estimated values, vhat.dhat
vhats.dhat <- rbind(vhats.dhat,
thismod[[name]] %>%
map_df(~bind_rows(names(.x) %>% str_detect("vhat.dhat") %>% keep(.x, .)), .id = "fold") %>%
mutate(method = mod, cv.i = name))
# Get estimated rule from CV test fold, dhat
dhats <- rbind(dhats,
thismod[[name]] %>%
map_df(~bind_rows(names(.x) %>% str_detect("^dhat$") %>% keep(.x, .)), .id = "fold") %>%
mutate(method = mod, cv.i = name))
}
}
# One time run to get true optimal and worst value
# Simulated data only
trueV <- getTrueOptimalValue(n = big.n, seed = base.seed)
trueWorstV <- getTrueWorstValue(n = big.n, seed = base.seed)
# Preprocess
vhats.dhat %<>%
mutate(V = U/W,
VR = (U/W - trueWorstV) / (trueV - trueWorstV)) %>%
group_by(method) %>%
summarize(n.batches = n(),
n.nonnaU = sum(!is.na(U)),
n.nonnaW = sum(!is.na(W)),
meanV = mean(V, na.rm = T),
sdV = sd(V, na.rm = T),
meanVR = mean(VR, na.rm = T),
sdVR = sd(VR, na.rm = T),
.groups = "keep") %>%
ungroup %>%
arrange(desc(meanV)) %>%
mutate(method = case_when(
method == "cvmodcontrastreg" ~ "Contrast\n Regression",
method == "cvmodall0" ~ "All 0",
method == "cvmodall1" ~ "All 1",
method == "cvmodlist2" ~ "List DTR\n (2 branches)",
method == "cvmoddwols" ~ "dWOLS",
method == "cvmodpois" ~ "Poisson"),
method = factor(method,
levels = method.vec,
labels = method.vec)
)
dhats %<>%
mutate(method = case_when(
method == "cvmodcontrastreg" ~ "Contrast\n Regression",
method == "cvmodall0" ~ "All 0",
method == "cvmodall1" ~ "All 1",
method == "cvmodlist2" ~ "List DTR\n (2 branches)",
method == "cvmoddwols" ~ "dWOLS",
method == "cvmodpois" ~ "Poisson"),
method = factor(method,
levels = method.vec,
labels = method.vec)
)
```
## Visualization of individualized treatment rules
### Direct visualization
#### listDTR
If the precision medicine (PM) method already has built-in visualization (especially for tree-based methods), we can visualize the ITR directly. For example, we can simply use the function `plot()` to visualize the ITR estimated with the listDTR method.
```{r itr.list, echo = T, eval = T}
#modlist3 %>% plot()
```
We can also create our own visualization like Figure 1A in the chapter.
```{r itr.list.scatter, echo = T, eval = T}
df.list3 <- df %>%
mutate(d.list = ifelse(age.z > 0.599 | prevtrtB > 0.5, "Recommend 1", "Recommend 0"), # based on modlist3
Rule = factor(as.character(d.list), levels = c("Recommend 0", "Recommend 1")),
prevtrtB = ifelse(prevtrtB == 1, "Previous treatment is drug B", "Previous treatment is not drug B")
)
## Figure 1A
df.list3 %>%
ggplot(aes(x = age.z, fill = Rule))+
geom_histogram(position = position_dodge2(preserve = 'single'), binwidth = 0.1)+
facet_wrap(~ prevtrtB, nrow = 2) +
scale_fill_brewer(palette = "Set1") +
labs(x = "Standardized age", y = "Count") +
theme_classic() +
theme(legend.position = 'top', text = element_text(size = 20))
```
The subgroup-level annualized relapse rate (ARR) can be calculated based on the listDTR ITR:
```{r itr.list.arr, echo = T, eval = T}
df.list3 %>%
group_by(trt, d.list) %>%
summarise(ARR = round(sum(y) / sum(years), 2),
n = n(),
`prop%` = round(n / nrow(df), 2)*100, .groups = "drop") %>%
rename("listDTR ITR" = d.list,
"Observed treatment" = trt) %>%
kable() %>%
kable_styling(full_width = F)
```
Patients who received drug 0 and were recommended drug 0 by listDTR had a similar ARR on average to those who received drug 0 but were recommended drug 1 (0.32 vs 0.31). Patients who received drug 1 and were recommended drug 1 by listDTR had a much lower ARR on average than those who received drug 1 but were recommended drug 0 (0.16 vs 0.39).
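To make the ARR calculation concrete, here is a small sketch with made-up data. The data frame `toy.arr` and its values are hypothetical; only the formula (total relapses divided by total follow-up years) comes from the chunk above.

```r
# Hypothetical patients: relapse counts and follow-up time in years
toy.arr <- data.frame(
  trt   = c("drug0", "drug0", "drug1", "drug1", "drug1"),
  y     = c(1, 0, 0, 1, 0),
  years = c(2.0, 1.5, 1.0, 2.5, 0.5)
)

# ARR per arm = total relapses / total person-years
toy.sum <- aggregate(cbind(y, years) ~ trt, data = toy.arr, FUN = sum)
toy.sum$ARR <- toy.sum$y / toy.sum$years
toy.sum
```

Pooling events over person-time in this way weights patients by their follow-up, which differs from averaging each patient's individual rate.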
#### Score-based method
Although some PM methods do not have built-in visualization or are not as "white-box" as more interpretable methods, there may still be ways to visualize the ITR. For example, score-based methods (such as Poisson and contrast regression) produce an estimate of the CATE score for each patient, and a decision tree can be fitted on these scores and visualized (a regression tree in this case, because the scores are continuous and `rpart()` is called with `method = "anova"`). Below is a histogram-density plot of the CATE scores estimated from the Poisson regression, followed by the tree fitted on the estimated CATE scores. We pruned the tree so that it had only three terminal nodes for simplicity. The `rpart.plot` package has a built-in visualization function for `rpart` models, `rpart.plot()`, which is how Figure 1B in the chapter was generated.
```{r itr.tree, echo = T, eval = T, results = "asis"}
df["score.poisson"] <- modpm$score.poisson
ggplot(df, aes(x = score.poisson)) +
  geom_histogram(aes(y = after_stat(density)), colour = "black", fill = "lightblue") +
geom_density(alpha = .2, fill = "white") +
labs(x = "Estimated CATE score from the Poisson regression", y = "Density") +
theme_classic()
modtree <- rpart(as.formula(paste0("score.poisson ~", paste0(covars, collapse = "+"))),
                 method = "anova", data = df, control = rpart.control(minsplit = 100, cp = 0.01)) # Fit a regression tree to the Poisson CATE scores
modtree.pr <- prune(modtree, cp = 0.09) # A higher cp value was chosen to keep only 3 subgroups
# print(rpart.rules(modtree.pr, cover = TRUE))
## Figure 1B
rpart.plot(modtree.pr, box.palette = "RdBu", type = 5, under = TRUE, tweak = 1.2, compress = TRUE)
```
The CATE scores are now summarized by a tree. Previous treatment with drug B and age seemed to be important in determining the CATE score values, which also showed up in the ITR estimated from the listDTR method. Patients with previous treatment of drug B had the lowest CATE score on average (-0.76) and took up 41% of the samples (dark orange). Patients whose previous treatment was not drug B and whose standardized age was >= -0.22 also had a negative CATE score on average (-0.23) and took up 36% of the samples (light orange), but not as low as the dark orange group. Negative CATE scores mean that the expected number of relapses is lower under drug 1 than under drug 0, so drug 1 is favored for these patients. For the blue group, the average CATE score was 0.13, taking up 23% of the samples, and they were expected to benefit from drug 0 based on the Poisson CATE scores.
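The sign of the CATE score can be translated into a treatment recommendation. The sketch below assumes that the scores are on the log rate ratio scale (drug 1 vs drug 0), which is our reading of the count-outcome setup and should be checked against the **precmed** documentation. The three averages are those reported for the tree leaves; the `toy.*` names are hypothetical.

```r
# Average CATE scores of the three tree leaves reported in the text
toy.score <- c(dark_orange = -0.76, light_orange = -0.23, blue = 0.13)

# Assumption: scores are log rate ratios (drug 1 vs drug 0)
toy.rr <- exp(toy.score)

# Fewer relapses are better, so a negative score favors drug 1
toy.rec <- ifelse(toy.score < 0, "drug1", "drug0")

data.frame(score = toy.score,
           rate.ratio = round(toy.rr, 2),
           recommend = toy.rec)
```

Under this assumption, a rate ratio below 1 means fewer expected relapses on drug 1, and a rate ratio above 1 means fewer expected relapses on drug 0.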
### ITR accuracy
The accuracy of an ITR is the proportion of patients whose estimated ITR is the same as the true optimal ITR. The estimated ITRs have been obtained from the PM methods, but we still need to calculate the true optimal ITR. This is only possible for simulated data, where the decision boundary is known. Based on the data generating mechanism in `simcountdata()`, `Iscore` is a score generated from a linear combination of baseline covariates, where lower scores indicate that drug 1 is better and higher scores indicate that drug 0 is better. We then classified patients into 5 equal-size subgroups based on the `Iscore`, where groups 1 and 2 have drug 1 as their true optimal ITR and groups 4 and 5 have drug 0 as their true optimal ITR. Group 3 is considered the neutral group, where patients are indifferent to either drug, so we assign their observed treatment as their true optimal ITR. Thus, we identify the true optimal ITR for every patient based on this subgrouping, which was derived from their true score `Iscore`. Since we used cross validation in estimating the ITR, we need to apply the exact same cross validation to the true optimal ITR. This is achieved by specifying the same randomization seed in the cross validation loop (see `seed`).
```{r optd}
#| echo: true
#| eval: true
## Create new columns
dhats$d <- rep(NA, nrow(dhats)) # true d
# Identify the true optimal treatment
# See simcountdata() in the function script to learn more about Iscore
sim <- df %>%
mutate(trueT = ifelse(as.numeric(Iscore) < 3, 1, 0),
trueT = ifelse(Iscore == 3, drug1, trueT)) # neutral group
# Format data
input <- data.frame(y = sim$y,
trt = sim$drug1,
time = log(sim$years),
sim[covars])
# Cross validation loop
for (i in unique(dhats$cv.i)) {
seed <- base.seed*100 + as.numeric(str_extract(i, "[0-9]+"))
set.seed(seed)
# Create CV folds
folds <- createFolds(input$trt, k = n.fold, list = TRUE)
for (fold.i in seq(n.fold)) {
testdata <- sim[folds[[fold.i]],]
# number of methods which succeeded for the given fold/batch.
ind_result <- which(dhats$fold == paste0("fold", fold.i) &
dhats$cv.i == i &
!is.na(dhats$dhat))
nr <- length(ind_result)
    # The matched rows must be a whole multiple of the test fold size
    stopifnot(nr %% nrow(testdata) == 0)
    dhats$d[ind_result] <- rep(testdata$trueT, nr/nrow(testdata))
}
} # end of all cv iterations
```
Once we identified the true optimal ITR ($d^{opt}$), we can calculate the accuracy in each validation fold for each PM method ($\hat{d}_{pm}$). Mathematically, accuracy can be expressed as
$$Accuracy_{pm}(\boldsymbol{x}^{val}) = \frac{1}{n^{val}}\sum_{i = 1}^{n^{val}} I\big(\hat{d}_{pm}(\boldsymbol{x}_i^{val}) == d^{opt}(\boldsymbol{x}_i^{val})\big),$$
where $n^{val}$ is the sample size in the validation fold, $\boldsymbol{x}_i^{val}$ are the baseline characteristics of the $i$th patient in the validation fold, and $pm$ stands for one PM method.
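As a small worked instance of the accuracy formula, consider one hypothetical validation fold with six patients. The vectors `toy.dhat` and `toy.dopt` are made up for illustration.

```r
# Hypothetical validation fold with six patients (1 = drug 1, 0 = drug 0)
toy.dhat <- c(1, 0, 1, 1, 0, 0)  # ITR estimated by one PM method
toy.dopt <- c(1, 0, 0, 1, 0, 1)  # true optimal ITR

# Accuracy = share of patients whose estimated and optimal rules coincide
toy.accuracy <- mean(toy.dhat == toy.dopt)
toy.accuracy
```

Here four of the six recommendations match, so the accuracy in this fold is 4/6.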
Below is how Figure 2 in the chapter was generated. It summarized the accuracy across all validation folds as a box plot so we can also learn the variability of accuracy across folds.
```{r accuracy, echo = T, eval = T}
##### Accuracy #####
## Calculate % accuracy for each iteration & summary statistics
dhats.accuracy <- dhats %>%
group_by(method, cv.i, fold) %>%
summarise(accuracy = sum(dhat == d)/n(), .groups = "drop") %>%
ungroup
## Make the accuracy plot, Figure 2
dhats.accuracy %>%
ggplot(aes(x = method, y = accuracy)) +
geom_boxplot() +
geom_hline(yintercept = 1, linetype = 2, linewidth = 1, color = "gray") +
geom_hline(yintercept = 0.5, linetype = 2, linewidth = 1, color = "gray") +
theme_classic() +
labs(x = "Method", y = "Accuracy") +
theme(axis.text = element_text(size = 15),
axis.title.y = element_text(size = 15),
axis.title.x = element_text(size = 15),
axis.text.x = element_text(angle = 0, size = 15),
strip.text.x = element_text(size = 15))
```
### ITR agreement
When we do not know the true data generating mechanism, e.g., real-world data, we cannot compare the estimated ITR with the true optimal ITR. However, we can compare the estimated ITR with another estimated ITR, and this is called agreement. Agreement is the proportion of patients whose estimated ITR of a method is the same as the estimated ITR of another method. Thus, agreement is between two methods. Mathematically,
$$Agreement_{1, 2}(\boldsymbol{x}^{val}) = \frac{1}{n^{val}} \sum_{i = 1}^{n^{val}} I\big( \hat{d}_{1}(\boldsymbol{x}_i^{val}) == \hat{d}_{2} (\boldsymbol{x}_i^{val}) \big), $$
where $n^{val}$ is the sample size in the validation fold, $\boldsymbol{x}_i^{val}$ are the baseline characteristics of the $i$th patient in the validation fold, and the subscripts $1$ and $2$ stand for method 1 and method 2.
```{r agreement, echo = T, eval = T, fig.height = 8, fig.width = 8}
##### Agreement #####
dhats.concat <- dhats %>%
arrange(cv.i, fold, method) %>%
mutate(iteration.fold = (as.numeric(str_extract(cv.i, "[0-9]+")) - 1) * 10 + as.numeric(str_extract(fold, "[0-9]+"))) %>%
dplyr::select(method, iteration.fold, dhat) %>%
group_by(method, iteration.fold) %>%
mutate(i = 1:n()) %>%
ungroup
m <- length(method.vec)
dhats.agreement <- matrix(nrow = m, ncol = m)
colnames(dhats.agreement) <- method.vec
rownames(dhats.agreement) <- method.vec
for(k in seq_len(m)){
for(j in seq(k, m)){
data.k <- dhats.concat %>% filter(method == method.vec[k])
data.j <- dhats.concat %>% filter(method == method.vec[j])
data.jk <- data.k %>% full_join(data.j, by = c("iteration.fold", "i"))
dhats.agreement[k, j] <- dhats.agreement[j, k] <- sum(data.jk$dhat.x == data.jk$dhat.y, na.rm = T) / sum(is.na(data.jk$dhat.x) == FALSE & is.na(data.jk$dhat.y) == FALSE)
}
}
# Make the agreement plot, Figure 3
corrplot(dhats.agreement, method = "color", type = "lower",
addCoef.col = "orange", number.cex = 1.5,
tl.cex = 1.2, cl.cex = 1.2, tl.col = "black", tl.srt = 0, tl.offset = 1.5)
```
We used the `corrplot` package to generate Figure 3 in the chapter but agreement can be visualized in other creative ways that you prefer.
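To make the agreement definition concrete, the sketch below computes a pairwise agreement matrix for three hypothetical rules applied to the same five patients. The matrix `toy.rules` and its values are made up, and the loop is a simplified version of the idea rather than the code used for Figure 3.

```r
# Hypothetical rules from three methods for the same five patients
toy.rules <- cbind(
  poisson = c(1, 1, 0, 0, 1),
  dwols   = c(1, 0, 0, 0, 1),
  all1    = c(1, 1, 1, 1, 1)
)

# Pairwise agreement = proportion of patients given the same recommendation
toy.m <- ncol(toy.rules)
toy.agree <- matrix(NA_real_, toy.m, toy.m,
                    dimnames = list(colnames(toy.rules), colnames(toy.rules)))
for (k in seq_len(toy.m)) {
  for (j in seq_len(toy.m)) {
    toy.agree[k, j] <- mean(toy.rules[, k] == toy.rules[, j])
  }
}
toy.agree
```

The resulting matrix is symmetric with ones on the diagonal, which is why only the lower triangle needs to be displayed.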
## Patient well-being
Patient well-being is evaluated via the value function, which is defined as the expected outcome had all patients followed the specified ITR. Like a fortune teller's crystal ball, this metric tells us how well the patients would do on average under each ITR. We can then compare across different ITRs and identify an optimal ITR. Cross validation is necessary here to mitigate over-fitting, and we visualized the value function results as error bar plots. The mean and standard deviation of the value functions were computed in the preprocessing step above. We use `ggplot()` to generate the error bar plots. Figure 4A shows the original value function estimates, and Figure 4B shows the standardized value ratio estimates, which rescale the value function so that 0 corresponds to the true worst value and 1 to the true optimal value; a ratio closer to 1 is therefore always more desirable.
```{r value, echo = T, eval = T, fig.width = 15, fig.height = 7}
##### Errorbar plot #####
# Figure 4A
p4a <- vhats.dhat %>%
ggplot(aes(x = method, y = meanV)) +
geom_point(size = 8, shape = 16, color = "navy") +
  geom_errorbar(aes(ymin = meanV - sdV, ymax = meanV + sdV), width = 0.3, linewidth = 2, position = position_dodge(0.9), color = "navy") +
theme_classic() + xlab("") + ylab("Cross-validated value (mean +- SD)") +
theme(axis.text = element_text(size = 15), axis.title.y = element_text(size = 15)) +
  geom_hline(yintercept = vhats.dhat$meanV[which(vhats.dhat$method == "All 1")], linetype = 2, linewidth = 1.5, color = "gray")
##### Value ratio #####
# Figure 4B
p4b <- vhats.dhat %>%
dplyr::select(method, contains("VR"), n.nonnaU) %>%
ggplot(aes(x = method, y = meanVR)) +
geom_point(size = 8, color = "navy") +
  geom_errorbar(aes(ymin = meanVR - sdVR, ymax = meanVR + sdVR), width = 0.3, linewidth = 2, position = position_dodge(0.9), color = "navy") +
  geom_hline(yintercept = 1, color = "gray", linetype = 2, linewidth = 1) +
  geom_hline(yintercept = 0, color = "gray", linetype = 2, linewidth = 1) +
scale_y_continuous(breaks = seq(0, 1, length = 6)) +
theme_classic() +
labs(x = "", y = "Value ratio of cross-validated estimated decision rule (mean +- SD)") +
theme(axis.text = element_text(size = 13),
axis.title.y = element_text(size = 15),
axis.title.x = element_text(size = 15),
axis.text.x = element_text(size = 15),
strip.text.x = element_text(size = 12))
# Figure 4
ggarrange(p4a, p4b, ncol = 2, nrow = 1, labels = c("A", "B"))
```
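The value ratio plotted in Figure 4B follows the rescaling used in the preprocessing step, `(V - trueWorstV) / (trueV - trueWorstV)`. The sketch below applies it to made-up numbers; all `toy.*` values are hypothetical and are not estimates from the case study.

```r
# Hypothetical values (not estimates from the case study)
toy.trueV      <- 1.20  # value of the true optimal rule
toy.trueWorstV <- 0.40  # value of the true worst rule
toy.V          <- c(optimal = 1.20, estimated = 1.00, worst = 0.40)

# Value ratio: 0 at the true worst rule, 1 at the true optimal rule
toy.VR <- (toy.V - toy.trueWorstV) / (toy.trueV - toy.trueWorstV)
toy.VR
```

Because the ratio is anchored at the true worst and true optimal values, it is only available for simulated data where both are known.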
## Responder diagnostics
### Validation
The package we used for the two score-based methods (Poisson and contrast regression), `precmed`, has built-in visualization tools to diagnose the results: validation box plots `boxplot()`, validation curves `plot()`, and area between curves (ABC) statistics `abc()`.
```{r validation, echo = T, eval = T, fig.width = 8, fig.height = 12}
##### Validation of ITR scores #####
# Figure 5A
p5a <- boxplot(modcv, ylab = "Rate ratio between T=1 and T=0 in each subgroup")
# Figure 5B
p5b <- plot(modcv, ylab = "Rate ratio between T=1 and T=0 in each subgroup")
# Figure 5
ggarrange(p5a, p5b, ncol = 1, nrow = 2, labels = c("A", "B"))
```
The `precmed` package implements more PM methods than Poisson and contrast regression, such as negative binomial and two regressions. See its documentation for more details.
### Univariate comparison of patient characteristics
The 60/40 cutoff was used in the chapter to split patients into "high responders" and "standard responders". The function `CreateTableOne()` in the `tableone` package was used to generate a table comparing side-by-side the baseline characteristics between the two responder groups.
```{r unicomp.table, echo = T, eval = T, warning = F}
##### Side-by-side baseline characteristic comparison between responder subgroups #####
cutoff <- quantile(modpm$score.poisson, 0.6) # 60/40 high vs standard responder split
df["responder to T=1"] <- ifelse(modpm$score.poisson < cutoff, "High", "Standard")
df["age.z"] <- as.numeric(df$age.z)
df["previous_cost.z"] <- as.numeric(df$previous_cost.z)
labs <- list("age.z" = "Standardized baseline age", "female" = "Female",
             "prevtrtB" = "Previous treatment drug B", "prevtrtC" = "Previous treatment drug C",
             "prevnumsymp1" = "Previous number of symptoms == 1",
             "prevnumsymp2p" = "Previous number of symptoms >= 2",
             "previous_cost.z" = "Standardized previous medical cost\n(excluding medication)",
             "previous_number_relapses" = "Previous number of relapses")
tab <- CreateTableOne(vars = covars, strata = "responder to T=1", data = df, test = F) %>% print(smd = T)
```
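The 60/40 split can be checked on a small example. The sketch below uses ten made-up scores; the vector `toy.scores` is hypothetical, and the only elements taken from the chunk above are the `quantile(..., 0.6)` cutoff and the rule that scores below the cutoff are labelled "High".

```r
# Hypothetical CATE scores for ten patients (lower = more benefit from drug 1)
toy.scores <- c(-1.2, -0.9, -0.7, -0.5, -0.3, -0.1, 0.0, 0.2, 0.4, 0.6)

# The 60th percentile splits the patients 60/40
toy.cutoff <- quantile(toy.scores, 0.6)
toy.group  <- ifelse(toy.scores < toy.cutoff, "High", "Standard")

table(toy.group)
```

With ten patients, six fall below the cutoff and are labelled high responders, and the remaining four are standard responders.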
We can directly present the table or visualize the comparison with errorbar plots, which is what the chapter presented (Figure 6A). Here we show both. Tables are helpful if specific numbers are important but readers would have to perform mental comparison to understand which value is higher, whereas plots are helpful if you want people to quickly identify the larger differences and not focus on the specific values of certain results.
```{r unicomp.visual, echo = T, eval = T, warning = F, fig.width = 10, fig.height = 8}
smd <- as_tibble(tab, rownames = "var") %>%
rowwise() %>%
mutate(variable = as.factor(str_extract(var, ".*(?= \\(mean \\(SD\\)\\))"))) %>%
filter(!is.na(variable)) %>%
arrange(desc(SMD)) %>%
mutate(smd = paste0("SMD =", SMD),
variable = labs[[variable]]) %>%
dplyr::select(variable, smd) %>%
mutate(ID = 1)
levels <- unique(smd$variable)
p6a <- df %>%
mutate(ID = 1:n()) %>%
dplyr::select(all_of(covars), ID, contains("responder")) %>%
melt(id = c("ID", "responder to T=1")) %>%
rowwise() %>%
mutate(variable = labs[[variable]]) %>%
left_join(smd, by = c("variable", "ID")) %>%
mutate(variable2 = factor(variable, levels = levels)) %>%
ggplot(aes(x = reorder(variable2, desc(variable2)), color = `responder to T=1`, y = value, group = `responder to T=1`)) +
stat_summary(fun = mean, geom = "point", size = 4, position = position_dodge(width = 0.5)) +
  stat_summary(fun.data = mean_sdl, fun.args = list(mult = 1), geom = "errorbar", position = position_dodge(width = 0.5), width = 0.3, linewidth = 1.2) + # mult = 1 gives mean +- 1 SD (mean_sdl defaults to 2 SD)
geom_hline(yintercept = 0, color = "gray", linetype = "dashed") +
geom_text(aes(label = smd), hjust = -0.5, y = -1.5, color = "darkgray", size = 3.5) +
# facet_wrap(~ variable2, nrow = 4) +
labs(x = "Baseline patient characteristic",
y = "Mean +- SD") +
coord_flip() +
scale_color_brewer(palette = "Set2") +
theme_classic() +
theme(legend.position = "top",
axis.text = element_text(size = 13),
axis.title.y = element_text(size = 15),
axis.title.x = element_text(size = 15),
axis.text.x = element_text(size = 15),
strip.text.x = element_text(size = 12))
# Figure 6A
ggarrange(p6a, nrow = 1, labels = c("A"))
```
We can also show the density of ITR scores obtained from the score-based methods. The scores are stored in `modpm`, and we used a density plot to visualize them (Figure 6B).
```{r density, echo = T, eval = T}
##### Density of ITR score #####
dataplot <- data.frame(score = factor(rep(c("Naive Poisson", "Contrast Regression"),
each = length(modpm$score.poisson))),
value = c(modpm$score.poisson, modpm$score.contrastReg))
p6b <- dataplot %>%
ggplot(aes(x = value, fill = score)) +
geom_density(alpha = 0.5) +
scale_fill_manual(values = c("dodgerblue", "gray30")) +
  geom_vline(xintercept = 0, color = "darkgray", linetype = "dashed", linewidth = 1) +
labs(x = "Estimated CATE score", y = "Density", fill = "Method") +
theme_classic() +
theme(legend.position = c(0.2, 0.8),
axis.text = element_text(size = 13),
axis.title.y = element_text(size = 15),
axis.title.x = element_text(size = 15),
axis.text.x = element_text(size = 15),
strip.text.x = element_text(size = 12))
```
The ITR scores are essentially a linear combination of the baseline characteristics, so it may also be of interest to know the corresponding coefficients (or weights), which show how much each baseline variable contributes to the ITR score. To make the coefficients comparable across baseline variables measured on different scales, we used the model fitted on the scaled data, `modpm.s`, to extract the coefficients and visualize them as a bar plot. The coefficients can be presented in a table as well.
```{r coefs, echo = T, eval = T, fig.width = 12, fig.height = 12}
# Coefficients
coef <- modpm.s$coefficients
p6c <- coef %>%
as_tibble(rownames = "varname") %>%
melt(id.vars = "varname") %>%
filter(variable == "poisson", varname != "(Intercept)") %>%
mutate(absval = abs(value),
sign = ifelse(value > 0, "+", "-")) %>%
arrange(absval) %>%
mutate(varname = factor(varname, levels = unique(varname))) %>%
ggplot(aes(x = varname, y = absval, fill = sign)) +
geom_bar(stat = "identity", width = 0.5) +
scale_fill_brewer(palette = "Set1") +
scale_x_discrete(labels = labs) +
coord_flip() +
labs(y = "Absolute value of the estimated coefficient of CATE scores\nbased on Poisson regression", x = "Baseline patient characteristic") +
theme_minimal() +
theme(legend.position = c(0.8, 0.2),
axis.text = element_text(size = 13),
axis.title.y = element_text(size = 15),
axis.title.x = element_text(size = 15),
axis.text.x = element_text(size = 15),
strip.text.x = element_text(size = 12))
# Figure 6B, 6C
ggarrange(p6b, p6c, nrow = 2, labels = c("B", "C"))
# Coefficients presented as a table
coef %>% round(2) %>% kable() %>% kable_styling(full_width = F)
```
## Version info {.unnumbered}
This chapter was rendered using the following version of R and its packages:
```{r}
#| echo: false
#| message: false
#| warning: false
sessionInfo()
```
## References {.unnumbered}