Data Preparation
library(dplyr)
library(ggplot2)
library(knitr)
library(splines)
library(nlme)
library(survival)
library(JMbayes2)
library(JointODE)
We use the PBC dataset from JMbayes2 (`pbc2.id` for survival outcomes, `pbc2` for repeated measurements). Log serum bilirubin (standardized) serves as the longitudinal biomarker.
# One row per patient: follow-up time, event indicator (status2) and
# baseline covariates, with drug and sex recoded as 0/1 indicators.
survival_data <- transmute(
  pbc2.id,
  id = as.integer(id),
  time = years,
  status = status2,
  drug = as.numeric(drug == "D-penicil"),
  age = age,
  sex = as.numeric(sex == "female")
)
# One row per visit: the longitudinal biomarker is log serum bilirubin
# (serBilir), with baseline covariates joined in from `survival_data`.
longitudinal_data <- pbc2 |>
  transmute(
    id = as.integer(id), time = year,
    observed = log(serBilir)
  ) |>
  left_join(
    survival_data |> select(id, drug, age, sex),
    by = "id"
  )
# Standardize with the pooled visit-level mean/SD so that association
# coefficients are on a per-SD scale; the mean/SD are kept for later use.
bili_mean <- mean(longitudinal_data$observed)
bili_sd <- sd(longitudinal_data$observed)
longitudinal_data$observed <-
  (longitudinal_data$observed - bili_mean) / bili_sd
# Cohort summary: patients, events (count and percentage), and the number
# of longitudinal observations.
with(survival_data, cat(sprintf(
  "Patients: %d | Events: %d (%.0f%%) | Obs: %d\n",
  length(id),
  sum(status),
  100 * mean(status),
  nrow(longitudinal_data)
)))
# Spaghetti plot of 20 randomly sampled patients, coloured by whether the
# patient's status is 1 (event) or 0 (censored).
set.seed(123)
ids20 <- sample(unique(longitudinal_data$id), 20)
longitudinal_data |>
  filter(id %in% ids20) |>
  left_join(
    survival_data |> select(id, status),
    by = "id"
  ) |>
  ggplot(aes(time, observed,
    group = id,
    color = factor(status)
  )) +
  geom_line(alpha = 0.6) +
  geom_point(alpha = 0.6, size = 0.8) +
  scale_color_manual(
    values = c("0" = "#3498db", "1" = "#e74c3c"),
    # Explicit breaks keep `labels` paired with the status codes even when
    # the sample contains only one status level (otherwise the labels are
    # matched positionally to whatever levels happen to be present).
    breaks = c("0", "1"),
    labels = c("Censored", "Event"), name = "Status"
  ) +
  labs(
    title = "Log Bilirubin Trajectories (20 Patients)",
    x = "Time (years)",
    y = "Log Bilirubin (standardized)"
  ) +
  theme_minimal(base_size = 12)
Model Fitting
JointODE
# JointODE fit: the observed marker is regressed on `biomarker`, `velocity`
# and drug with per-patient random biomarker/velocity effects; the hazard
# uses baseline drug, age and sex.
fit_ode <- JointODE(
  longitudinal_formula = observed ~ biomarker + velocity + drug +
    (biomarker + velocity | id),
  survival_formula = Surv(time, status) ~ drug + age + sex,
  longitudinal_data = longitudinal_data,
  survival_data = survival_data,
  init = "marginal",
  # Parallel only when the CI environment variable is unset or empty.
  control = list(parallel = Sys.getenv("CI") == "")
)
summary(fit_ode)
Extended Cox
# Counting-process (tstart, tstop] data for a time-varying Cox model: each
# measurement is carried forward until the next visit, and the final
# interval ends at the patient's event/censoring time carrying the status.
td <- longitudinal_data |>
  arrange(id, time) |>
  left_join(
    survival_data |>
      select(id, time_event = time, status),
    by = "id"
  ) |>
  # A measurement at (or after) the event/censoring time cannot open an
  # interval. Drop such rows first; otherwise the patient's last row gets
  # tstop == tstart, is removed below, and the event is silently lost.
  filter(time < time_event) |>
  group_by(id) |>
  mutate(
    tstart = time,
    tstop = lead(time, default = first(time_event)),
    event = ifelse(row_number() == n(), first(status), 0L)
  ) |>
  ungroup() |>
  # Remaining zero-length intervals come only from duplicated visit times.
  filter(tstop > tstart)
cox_td <- coxph(
  Surv(tstart, tstop, event) ~ observed + drug + age + sex,
  data = td, x = TRUE
)
JMbayes2
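Four configurations: two random-effects structures (linear `~ time` vs. natural-spline `~ ns(time, 3)`) × two association types (current value only vs. value + slope).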
# Survival sub-model for JMbayes2: one row per patient, baseline covariates
# only (no time-varying terms); jm() adds the longitudinal association.
cox_fit <- coxph(
  formula = Surv(time, status) ~ drug + age + sex,
  data = survival_data,
  x = TRUE
)
# Longitudinal sub-models with identical fixed effects ns(time, 3) * drug,
# differing only in the random-effects structure:
#   lme_lin - random intercept and linear time slope per patient
#   lme_ns  - random natural-spline (3 df) time terms, fitted with the
#             "optim" optimiser and up to 200 iterations
lme_lin <- lme(
  fixed = observed ~ ns(time, 3) * drug,
  random = ~ time | id,
  data = longitudinal_data
)
lme_ns <- lme(
  fixed = observed ~ ns(time, 3) * drug,
  random = ~ ns(time, 3) | id,
  data = longitudinal_data,
  control = lmeControl(opt = "optim", maxIter = 200)
)
# Fit a JMbayes2 joint model combining a Cox sub-model with one lme
# sub-model for `observed`.
#   lme_obj  - fitted nlme::lme model for `observed`
#   ff       - one-sided formula of functional forms, e.g.
#              ~ value(observed) + slope(observed)
#   surv_obj - Cox sub-model; defaults to the global `cox_fit`, so existing
#              two-argument calls behave exactly as before
# Returns the object produced by jm().
jm_fit <- function(lme_obj, ff, surv_obj = cox_fit) {
  jm(surv_obj, lme_obj,
    time_var = "time",
    functional_forms = list("observed" = ff)
  )
}
# 2 x 2 grid: {linear, spline} random effects x {value, value + slope}.
fit_lin_val <- jm_fit(lme_obj = lme_lin, ff = ~ value(observed))
fit_lin_both <- jm_fit(lme_obj = lme_lin, ff = ~ value(observed) + slope(observed))
fit_ns_val <- jm_fit(lme_obj = lme_ns, ff = ~ value(observed))
fit_ns_both <- jm_fit(lme_obj = lme_ns, ff = ~ value(observed) + slope(observed))
Results
Hazard Coefficients
# JointODE estimates and standard errors (square root of the vcov diagonal).
ode_c <- coef(fit_ode)
ode_se <- sqrt(diag(vcov(fit_ode)))
# coxph summary table: column 1 = coef, column 3 = se(coef).
cox_s <- summary(cox_td)[["coefficients"]]
# Survival-part table of a JMbayes2 fit (column 1 = Mean, column 2 = StDev).
get_jm <- function(fit) {
  summary(fit)[["Survival"]]
}
lv <- get_jm(fit_lin_val)
lb <- get_jm(fit_lin_both)
nv <- get_jm(fit_ns_val)
nb <- get_jm(fit_ns_both)
# Format an "estimate (SE)" cell with three decimals.
f <- function(est, se) {
  sprintf("%.3f (%.3f)", est, se)
}
# Side-by-side hazard coefficients, one "estimate (SE)" column per model;
# "--" marks a model without a slope association. Helpers live inside
# local() so nothing besides `comp` is added to the workspace.
comp <- local({
  covs <- c("drug", "age", "sex")

  # Extended Cox: estimate in column 1, se(coef) in column 3.
  cox_col <- c(
    f(cox_s["observed", 1], cox_s["observed", 3]),
    "--",
    f(cox_s[covs, 1], cox_s[covs, 3])
  )

  # JointODE: hazard coefficients are prefixed with "hazard:".
  ode_keys <- paste0("hazard:", c("alpha_1", "alpha_2", covs))
  ode_col <- f(ode_c[ode_keys], ode_se[ode_keys])

  # JMbayes2: posterior mean in column 1, StDev in column 2.
  jm_col <- function(tab, with_slope) {
    slope_cell <- if (with_slope) {
      f(tab["slope(observed)", 1], tab["slope(observed)", 2])
    } else {
      "--"
    }
    c(
      f(tab["value(observed)", 1], tab["value(observed)", 2]),
      slope_cell,
      f(tab[covs, 1], tab[covs, 2])
    )
  }

  data.frame(
    Parameter = c("value", "slope", covs),
    Ext.Cox = cox_col,
    JointODE = ode_col,
    JM.lin.val = jm_col(lv, with_slope = FALSE),
    JM.lin.both = jm_col(lb, with_slope = TRUE),
    JM.ns.val = jm_col(nv, with_slope = FALSE),
    JM.ns.both = jm_col(nb, with_slope = TRUE)
  )
})
# Render the comparison: parameter names left-aligned, the six model
# columns right-aligned.
kable(
  comp,
  caption = paste(
    "Hazard coefficients: estimate (SE).",
    "lin/ns = RE structure;",
    "val = value-only; both = value+slope."
  ),
  align = c("l", rep("r", 6))
)
Dynamic AUC
# Harrell's C-index of a per-patient risk score against observed survival.
#   risk - numeric vector named by patient id (looked up as character)
# reverse = TRUE because a higher risk score should mean an earlier event.
cindex <- function(risk) {
  conc <- survival::concordance(
    Surv(time, status) ~ risk[as.character(id)],
    data = survival_data,
    reverse = TRUE
  )
  conc[["concordance"]]
}
# JointODE: hazard linear predictor
#   lp = alpha_1 * biomarker + alpha_2 * velocity + drug/age/sex effects,
# evaluated on each patient's last row of predict(fit_ode).
pred_ode <- predict(fit_ode)
ode_c <- coef(fit_ode)
last_ode <- pred_ode |>
  mutate(id = as.integer(id)) |>
  group_by(id) |>
  slice_tail(n = 1) |>
  ungroup() |>
  left_join(select(survival_data, id, drug, age, sex), by = "id")
ode_lp <- with(
  last_ode,
  ode_c["hazard:alpha_1"] * biomarker +
    ode_c["hazard:alpha_2"] * velocity +
    ode_c["hazard:drug"] * drug +
    ode_c["hazard:age"] * age +
    ode_c["hazard:sex"] * sex
)
c_ode <- cindex(setNames(ode_lp, last_ode$id))
# Extended Cox: linear predictor from each patient's latest measurement.
# Sort explicitly so slice_tail() really takes the latest visit instead of
# relying on the incoming row order (mirrors how `td` is built).
last <- longitudinal_data |>
  arrange(id, time) |>
  group_by(id) |>
  slice_tail(n = 1) |>
  ungroup()
c_cox <- cindex(
  setNames(predict(cox_td, type = "lp", newdata = last), last$id)
)
# JMbayes2 (ns + both): posterior-mean hazard coefficients applied to the
# lme_ns subject-level fitted value at the last visit and a forward-
# difference slope with a step of 0.01 time units.
sc <- local({
  surv_tab <- summary(fit_ns_both)$Survival
  setNames(surv_tab[, "Mean"], rownames(surv_tab))
})
last$val <- predict(lme_ns, newdata = last, level = 1)
eps_t <- mutate(last, time = time + 0.01)
last$slp <- (predict(lme_ns, newdata = eps_t, level = 1) - last$val) / 0.01
jm_lp <- with(
  last,
  sc["value(observed)"] * val +
    sc["slope(observed)"] * slp +
    sc["drug"] * drug + sc["age"] * age + sc["sex"] * sex
)
c_jm <- cindex(setNames(jm_lp, last$id))
# Discrimination summary: one C-index per model, three decimals.
cindex_tab <- NULL
kable(
  x = data.frame(
    Model = c("JointODE", "Extended Cox", "JMbayes2 (ns+both)"),
    C.index = sprintf("%.3f", c(c_ode, c_cox, c_jm))
  ),
  caption = "Concordance index (C-index) for discrimination",
  align = c("l", "r")
)