Skip to contents

Data Preparation

We use the PBC dataset from JMbayes2. Log serum bilirubin (standardized) serves as the longitudinal biomarker.

# One row per patient, built from `pbc2.id`: follow-up time (`years`),
# status (`status2`), and treatment / sex recoded as 0/1 indicators.
survival_data <- transmute(
  pbc2.id,
  id = as.integer(id),
  time = years,
  status = status2,
  drug = as.numeric(drug == "D-penicil"),
  age = age,
  sex = as.numeric(sex == "female")
)

# Long-format biomarker data: one row per visit holding the log serum
# bilirubin measurement, joined with the patient's baseline covariates.
# The biomarker is `serBilir` (not `albumin`): the text, the `bili_*`
# standardization constants and the plot labels all refer to bilirubin.
longitudinal_data <- pbc2 |>
  transmute(
    id = as.integer(id), time = year,
    observed = log(serBilir)
  ) |>
  left_join(
    survival_data |> select(id, drug, age, sex),
    by = "id"
  )

# Standardize the biomarker column; the centering and scaling constants
# are stored in `bili_mean` and `bili_sd`.
bili_mean <- mean(longitudinal_data[["observed"]])
bili_sd <- sd(longitudinal_data[["observed"]])
longitudinal_data <- longitudinal_data |>
  mutate(observed = (observed - bili_mean) / bili_sd)

# One-line summary: number of patients, number (and share) of events,
# and number of longitudinal observations.
cat(with(
  survival_data,
  sprintf(
    "Patients: %d | Events: %d (%.0f%%) | Obs: %d\n",
    length(status),
    sum(status),
    100 * mean(status),
    nrow(longitudinal_data)
  )
))
# Draw 20 patient ids at random with a fixed seed.
set.seed(123)
ids20 <- sample(unique(longitudinal_data$id), 20)

# Individual trajectories of the sampled patients, coloured by status.
longitudinal_data |>
  filter(id %in% ids20) |>
  left_join(select(survival_data, id, status), by = "id") |>
  ggplot(
    mapping = aes(
      x = time, y = observed,
      group = id, color = factor(status)
    )
  ) +
  geom_line(alpha = 0.6) +
  geom_point(size = 0.8, alpha = 0.6) +
  scale_color_manual(
    name = "Status",
    labels = c("Censored", "Event"),
    values = c("0" = "#3498db", "1" = "#e74c3c")
  ) +
  labs(
    title = "Log Bilirubin Trajectories (20 Patients)",
    x = "Time (years)",
    y = "Log Bilirubin (standardized)"
  ) +
  theme_minimal(base_size = 12)

Model Fitting

JointODE

# Fit the joint model with JointODE(). The longitudinal formula contains
# `biomarker`, `velocity` and `drug`, with `biomarker` and `velocity`
# repeated inside the `| id` grouping term; the survival formula adjusts
# for drug, age and sex.
# NOTE(review): `biomarker` and `velocity` are not columns of
# `longitudinal_data`; presumably they are keywords interpreted by
# JointODE() -- confirm against the package documentation.
fit_ode <- JointODE(
  longitudinal_formula =
    observed ~ biomarker + velocity + drug +
    (biomarker + velocity | id),
  survival_formula = Surv(time, status) ~ drug + age + sex,
  longitudinal_data = longitudinal_data,
  survival_data = survival_data,
  init = "marginal",
  # `parallel` is FALSE whenever the CI environment variable is set to a
  # non-empty value, TRUE otherwise.
  control = list(parallel = !nzchar(Sys.getenv("CI")))
)
summary(fit_ode)

Extended Cox

# Counting-process data for the time-varying Cox model: each visit opens
# an interval (tstart, tstop] that closes at the patient's next visit;
# the final interval closes at the event/censoring time and is the only
# one that can carry the event.
td <- longitudinal_data |>
  arrange(id, time) |>
  left_join(
    survival_data |>
      select(id, time_event = time, status),
    by = "id"
  ) |>
  # Discard visits recorded at or after the event/censoring time before
  # building intervals. Otherwise such a visit would be the patient's
  # last row, receive the event flag, and then be removed by the
  # zero-length-interval filter below -- silently losing the event.
  filter(time < time_event) |>
  group_by(id) |>
  mutate(
    tstart = time,
    tstop = lead(time, default = first(time_event)),
    event = ifelse(row_number() == n(), first(status), 0L)
  ) |>
  ungroup() |>
  # Tied visit times produce empty intervals, which are not valid
  # (start, stop] rows.
  filter(tstop > tstart)

# Extended Cox model with the observed biomarker as a time-varying
# covariate alongside the baseline covariates.
cox_td <- coxph(
  Surv(tstart, tstop, event) ~ observed + drug + age + sex,
  data = td, x = TRUE
)

JMbayes2

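Four configurations: two random-effects (RE) structures (linear or natural-spline in time) × two association types (current value only, or current value plus slope).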

# Survival submodel for JMbayes2: baseline covariates only, one row per
# patient (no time-varying terms).
cox_fit <- coxph(
  formula = Surv(time, status) ~ drug + age + sex,
  data = survival_data,
  x = TRUE
)

# Longitudinal submodel, "lin" variant: spline-by-drug fixed effects,
# random intercept and linear time slope per patient.
lme_lin <- lme(
  fixed = observed ~ ns(time, 3) * drug,
  data = longitudinal_data,
  random = ~ time | id
)

# Longitudinal submodel, "ns" variant: same fixed effects, natural-spline
# random effects per patient; fitted with the "optim" optimizer and a
# maxIter of 200.
lme_ns <- lme(
  fixed = observed ~ ns(time, 3) * drug,
  data = longitudinal_data,
  random = ~ ns(time, 3) | id,
  control = lmeControl(maxIter = 200, opt = "optim")
)

# Fit a JMbayes2 joint model that pairs the shared Cox submodel
# (`cox_fit`) with the given mixed model.
#   lme_obj: fitted lme object for the longitudinal outcome `observed`.
#   ff:      one-sided formula giving the functional form (association
#            structure) for `observed`.
# Returns the object produced by jm().
jm_fit <- function(lme_obj, ff) {
  jm(
    Surv_object = cox_fit,
    Mixed_objects = lme_obj,
    time_var = "time",
    functional_forms = list(observed = ff)
  )
}

# The two association structures shared by both RE variants.
assoc_value <- ~ value(observed)
assoc_both <- ~ value(observed) + slope(observed)

# Linear random effects.
fit_lin_val <- jm_fit(lme_lin, assoc_value)
fit_lin_both <- jm_fit(lme_lin, assoc_both)

# Natural-spline random effects.
fit_ns_val <- jm_fit(lme_ns, assoc_value)
fit_ns_both <- jm_fit(lme_ns, assoc_both)

Results

Hazard Coefficients

# JointODE: coefficient vector and the square roots of the diagonal of
# its variance-covariance matrix.
ode_c <- coef(fit_ode)
ode_se <- fit_ode |>
  vcov() |>
  diag() |>
  sqrt()

# Extended Cox: coefficient table from the model summary.
cox_s <- summary(cox_td)[["coefficients"]]

# JMbayes2: survival-submodel table of a fitted joint model.
get_jm <- function(fit) {
  summary(fit)[["Survival"]]
}
lv <- get_jm(fit_lin_val)
lb <- get_jm(fit_lin_both)
nv <- get_jm(fit_ns_val)
nb <- get_jm(fit_ns_both)

# Format estimates and their standard errors as "est (se)" strings with
# three decimals each; vectorized over both arguments.
f <- function(est, se) {
  sprintf("%.3f (%.3f)", est, se)
}

# Comparison table of hazard coefficients, one column per model. Cells
# are "estimate (SE)" strings; "--" marks a model without a slope term.
comp <- local({
  covars <- c("drug", "age", "sex")
  no_slope <- "--"

  # Format the requested rows of a coefficient table: estimates come
  # from column 1, standard errors from column `se_col`.
  cells <- function(tab, rows, se_col) {
    f(tab[rows, 1], tab[rows, se_col])
  }

  # Column for one JMbayes2 fit (standard errors in column 2).
  jm_column <- function(tab, with_slope) {
    c(
      cells(tab, "value(observed)", 2),
      if (with_slope) cells(tab, "slope(observed)", 2) else no_slope,
      cells(tab, covars, 2)
    )
  }

  ode_rows <- paste0("hazard:", c("alpha_1", "alpha_2", covars))

  data.frame(
    Parameter = c("value", "slope", covars),
    Ext.Cox = c(
      cells(cox_s, "observed", 3),
      no_slope,
      cells(cox_s, covars, 3)
    ),
    JointODE = f(ode_c[ode_rows], ode_se[ode_rows]),
    JM.lin.val = jm_column(lv, with_slope = FALSE),
    JM.lin.both = jm_column(lb, with_slope = TRUE),
    JM.ns.val = jm_column(nv, with_slope = FALSE),
    JM.ns.both = jm_column(nb, with_slope = TRUE)
  )
})

# Render the comparison table: first column left-aligned, the six model
# columns right-aligned.
comp |>
  kable(
    align = c("l", rep("r", 6)),
    caption = paste(
      c(
        "Hazard coefficients: estimate (SE).",
        "lin/ns = RE structure;",
        "val = value-only; both = value+slope."
      ),
      collapse = " "
    )
  )

Concordance Index (C-index)

# Concordance between a per-patient risk score and the outcome in
# `survival_data`.
#   risk: numeric vector named by patient id; it is looked up with
#         as.character(id) for every row of `survival_data`.
# Returns the `concordance` element of survival::concordance(), computed
# with reverse = TRUE.
cindex <- function(risk) {
  conc <- survival::concordance(
    Surv(time, status) ~ risk[as.character(id)],
    data = survival_data,
    reverse = TRUE
  )
  conc$concordance
}

# JointODE: linear predictor from the last predicted row of each patient,
#   lp = alpha_1 * biomarker + alpha_2 * velocity + covariate effects
pred_ode <- predict(fit_ode)
ode_c <- coef(fit_ode)
last_ode <- pred_ode |>
  mutate(id = as.integer(id)) |>
  group_by(id) |>
  slice_tail(n = 1) |>
  ungroup() |>
  left_join(select(survival_data, id, drug, age, sex), by = "id")
ode_lp <- with(
  last_ode,
  ode_c["hazard:alpha_1"] * biomarker +
    ode_c["hazard:alpha_2"] * velocity +
    ode_c["hazard:drug"] * drug +
    ode_c["hazard:age"] * age +
    ode_c["hazard:sex"] * sex
)
c_ode <- cindex(setNames(ode_lp, last_ode$id))

# Extended Cox: linear predictor at each patient's most recent biomarker
# measurement. Sort explicitly so that "last row" means "latest visit"
# regardless of the physical row order of `longitudinal_data`.
last <- longitudinal_data |>
  arrange(id, time) |>
  group_by(id) |>
  slice_tail(n = 1) |>
  ungroup()
cox_lp <- predict(cox_td, type = "lp", newdata = last)
c_cox <- cindex(setNames(cox_lp, last$id))

# JMbayes2 (ns RE, value + slope): linear predictor built from the
# posterior-mean survival coefficients and the lme subject-level fit.
jm_surv <- summary(fit_ns_both)$Survival
sc <- setNames(jm_surv[, "Mean"], rownames(jm_surv))

# Subject-specific fitted value at the last visit ...
last$val <- predict(lme_ns, newdata = last, level = 1)
# ... and its slope, approximated by a forward difference over 0.01
# time units.
eps_t <- mutate(last, time = time + 0.01)
last$slp <- (predict(lme_ns, newdata = eps_t, level = 1) - last$val) / 0.01

jm_lp <- with(
  last,
  sc["value(observed)"] * val +
    sc["slope(observed)"] * slp +
    sc["drug"] * drug + sc["age"] * age + sc["sex"] * sex
)
c_jm <- cindex(setNames(jm_lp, last$id))

# Table of the three C-indices, formatted to three decimals.
data.frame(
  Model = c("JointODE", "Extended Cox", "JMbayes2 (ns+both)"),
  C.index = sprintf("%.3f", c(c_ode, c_cox, c_jm))
) |>
  kable(
    align = c("l", "r"),
    caption = "Concordance index (C-index) for discrimination"
  )

Session Info