library(causalverse)
library(ggplot2)  # used for plotting throughout

1. Introduction

1.1 Why Heterogeneous Treatment Effects Matter

Most causal inference methods report a single summary statistic - the Average Treatment Effect (ATE) - which tells us the average impact of an intervention across the study population. While the ATE is a valuable and interpretable quantity, it can obscure substantial variation in how individuals respond to treatment. When treatment effects differ across units, reporting only the ATE may lead to:

  • Misallocation of resources: Treating individuals who do not benefit or even experience harm.
  • Missed targeting opportunities: Failing to identify sub-populations where treatment is highly effective.
  • Policy errors: Designing one-size-fits-all policies that ignore individual heterogeneity.
  • Replication failures: An ATE that is statistically significant in one context may not replicate if the population composition changes.

Heterogeneous Treatment Effect (HTE) analysis asks: for whom does the treatment work, and by how much? This is the central question in precision medicine, personalized policy, targeted marketing, and individualized education.

Modern causal machine learning methods - including causal forests, meta-learners, and double/debiased machine learning - make it possible to estimate HTE at the individual level, validate those estimates, and translate them into actionable targeting rules, all while maintaining rigorous causal guarantees.

1.2 Key Estimands: CATE, ATE, ATT

Let $Y_i(1)$ and $Y_i(0)$ denote the potential outcomes for unit $i$ under treatment and control, respectively. The treatment indicator is $W_i \in \{0, 1\}$, and $X_i$ is a vector of pre-treatment covariates.

Individual Treatment Effect (ITE): $\tau_i = Y_i(1) - Y_i(0)$. This is never observed due to the fundamental problem of causal inference.

Conditional Average Treatment Effect (CATE): $\tau(x) = \mathbb{E}[Y_i(1) - Y_i(0) \mid X_i = x]$. The CATE is the population-level treatment effect for units with covariate profile $x$. This is the primary estimand of HTE analysis.

Average Treatment Effect (ATE): $\tau_{ATE} = \mathbb{E}[\tau(X_i)] = \mathbb{E}[Y_i(1) - Y_i(0)]$

Average Treatment Effect on the Treated (ATT): $\tau_{ATT} = \mathbb{E}[\tau(X_i) \mid W_i = 1]$

The CATE generalizes the ATE: $\tau_{ATE} = \mathbb{E}[\tau(X_i)]$ is the expectation of the CATE over the marginal distribution of $X$.
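As a quick numerical check (a hypothetical toy example, not the simulation used later), averaging the CATE over the covariate distribution recovers the ATE:

```r
set.seed(1)
x   <- rnorm(1e5)
tau <- 0.5 + x    # CATE: tau(x) = 0.5 + x
mean(tau)         # ATE = E[tau(X)], approximately 0.5
```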

1.3 Identification Assumptions

To identify the CATE from observational data, we require:

  1. Unconfoundedness (Conditional Independence): $(Y(0), Y(1)) \perp W \mid X$. No unmeasured confounders conditional on $X$.

  2. Overlap (Positivity): $0 < \mathbb{P}(W = 1 \mid X = x) < 1$ for all $x$. Every covariate profile has a positive probability of both treatment and control.

  3. SUTVA: Stable Unit Treatment Value Assumption - no interference and no hidden treatment versions.

Under these assumptions, the CATE is identified: $\tau(x) = \mathbb{E}[Y \mid W = 1, X = x] - \mathbb{E}[Y \mid W = 0, X = x]$

In randomized experiments, unconfoundedness holds by design. In observational studies, it must be justified by domain knowledge.
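The identification formula can be checked directly in a randomized toy simulation (hypothetical data, separate from the simulation in Section 2): within a narrow stratum of the covariate, the difference in treated and control means recovers the CATE at that point.

```r
set.seed(1)
n  <- 1e5
x1 <- rnorm(n)
w  <- rbinom(n, 1, 0.5)          # randomized treatment
y  <- x1 * w + rnorm(n)          # true tau(x) = x1
stratum <- abs(x1 - 1) < 0.1     # units with X1 near 1
est <- mean(y[stratum & w == 1]) - mean(y[stratum & w == 0])
round(est, 2)                    # close to the true tau(1) = 1
```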

1.4 Overview of Methods Covered

This vignette covers a comprehensive toolkit for HTE estimation:

Method                | Key Reference                             | Package
S-Learner             | Künzel et al. (2019)                      | xgboost, lm
T-Learner             | Künzel et al. (2019)                      | Any ML
X-Learner             | Künzel et al. (2019)                      | grf
R-Learner / DML       | Robinson (1988); Nie & Wager (2021)       | glmnet, DoubleML
Causal Forest         | Wager & Athey (2018); Athey et al. (2019) | grf
BLP / GATES           | Chernozhukov et al. (2020)                | causalverse
DR-Learner / AIPW     | Kennedy (2020)                            | causalverse
Qini / AUUC           | Radcliffe (2007)                          | causalverse
DML with interactions | Chernozhukov et al. (2018)                | glmnet, DoubleML

2. Setup & Data

2.1 Simulated Data with Known Heterogeneity

We begin with a simulation where the true CATE is known, enabling us to evaluate estimator performance.

set.seed(42)
n <- 2000
p <- 10

# Covariates: standard normal
X <- matrix(rnorm(n * p), n, p)
colnames(X) <- paste0("X", 1:p)

# Treatment: randomized with probability 0.5
W <- rbinom(n, 1, 0.5)

# True CATE: linear in X1 (heterogeneity driven by X1)
tau_true <- X[, 1]

# Outcome: tau * W + noise
Y <- tau_true * W + 0.5 * X[, 2] + rnorm(n)

# Compile into data frame
df <- data.frame(X, W = W, Y = Y, tau_true = tau_true)

cat("ATE (true):", mean(tau_true), "\n")
cat("SD of tau (true):", round(sd(tau_true), 3), "\n")
cat("Fraction positive tau:", round(mean(tau_true > 0), 3), "\n")
cat("n =", n, "| p =", p, "| Treatment fraction:", mean(W), "\n")

2.2 The Lalonde Dataset

We also use the classic LaLonde (1986) dataset from the MatchIt package, which evaluates the effect of a job training program on earnings. This dataset illustrates HTE in a real-world observational setting.

if (requireNamespace("MatchIt", quietly = TRUE)) {
  data(lalonde, package = "MatchIt")
  cat("Lalonde dataset: n =", nrow(lalonde), "\n")
  cat("Outcome: re78 (1978 earnings)\n")
  cat("Treatment: treat (job training, 1 = treated)\n")
  cat("Treatment fraction:", round(mean(lalonde$treat), 3), "\n")
  cat("Mean re78 (treated):", round(mean(lalonde$re78[lalonde$treat == 1]), 1), "\n")
  cat("Mean re78 (control):", round(mean(lalonde$re78[lalonde$treat == 0]), 1), "\n")
}

2.3 Pre-Treatment Covariate Balance

Before estimating treatment effects, always check covariate balance.

# Balance table for simulated data
treated_idx   <- which(W == 1)
control_idx   <- which(W == 0)

balance_df <- data.frame(
  Variable = colnames(X),
  Mean_Treated = round(colMeans(X[treated_idx, ]), 3),
  Mean_Control = round(colMeans(X[control_idx, ]), 3),
  Std_Diff = round(
    (colMeans(X[treated_idx, ]) - colMeans(X[control_idx, ])) /
      sqrt((apply(X[treated_idx, ], 2, var) + apply(X[control_idx, ], 2, var)) / 2),
    3
  )
)
print(balance_df)
# Visualize standardized differences
ggplot(balance_df, aes(x = Std_Diff, y = reorder(Variable, Std_Diff))) +
  geom_vline(xintercept = c(-0.1, 0.1), linetype = "dashed", color = "red", alpha = 0.5) +
  geom_vline(xintercept = 0, color = "gray40") +
  geom_point(size = 3, color = "#2166AC") +
  labs(
    x     = "Standardized Mean Difference",
    y     = NULL,
    title = "Pre-Treatment Covariate Balance",
    subtitle = "Dashed lines at ±0.1 (conventional threshold for good balance)"
  ) +
  theme_minimal(base_size = 12)

In the randomized simulation, all standardized differences are close to zero - expected under random assignment.

2.4 True CATE Distribution

ggplot(data.frame(tau = tau_true), aes(x = tau)) +
  geom_histogram(bins = 50, fill = "#4393C3", alpha = 0.7, color = "white") +
  geom_vline(xintercept = mean(tau_true), color = "#D73027", linewidth = 1.2) +
  geom_vline(xintercept = 0, linetype = "dashed") +
  labs(
    x     = "True Individual Treatment Effect (tau)",
    y     = "Count",
    title = "Distribution of True CATE",
    subtitle = sprintf("ATE = %.3f; SD = %.3f; %.0f%% positive",
                       mean(tau_true), sd(tau_true), 100 * mean(tau_true > 0))
  ) +
  theme_minimal(base_size = 12)

3. Meta-Learners

Meta-learners reduce the CATE estimation problem to off-the-shelf regression or classification. They differ in how they use treatment information.

3.1 S-Learner (Single Learner)

The S-Learner fits a single model $\hat{\mu}(x, w)$ with the treatment indicator as just another feature:

$\hat{\tau}(x) = \hat{\mu}(x, 1) - \hat{\mu}(x, 0)$

Pros: Simple, uses all data. Cons: May underfit treatment heterogeneity if the learner shrinks the treatment coefficient toward zero (e.g., lasso).

# S-Learner with linear model
df_train <- data.frame(X, W = W, Y = Y)

# Fit model with W as a feature (and all interactions for flexibility)
s_fit <- lm(Y ~ . * W, data = df_train[, c(paste0("X", 1:p), "W", "Y")])

# Predict potential outcomes
df1 <- df_train; df1$W <- 1
df0 <- df_train; df0$W <- 0

tau_s <- predict(s_fit, newdata = df1) - predict(s_fit, newdata = df0)

cat("S-Learner (linear with interactions):\n")
cat("  Mean CATE:", round(mean(tau_s), 4), "\n")
cat("  SD CATE:  ", round(sd(tau_s), 4), "\n")
cat("  RMSE vs truth:", round(sqrt(mean((tau_s - tau_true)^2)), 4), "\n")
library(xgboost)

# S-Learner with XGBoost
dtrain <- xgb.DMatrix(
  data  = cbind(X, W = W),
  label = Y
)
xgb_params <- list(
  objective        = "reg:squarederror",
  max_depth        = 4,
  eta              = 0.05,
  subsample        = 0.8,
  colsample_bytree = 0.8,
  nthread          = 1
)
set.seed(42)
xgb_fit <- xgb.train(
  params  = xgb_params,
  data    = dtrain,
  nrounds = 200,
  verbose = 0
)

# Predict with W=1 and W=0
d1_xgb <- xgb.DMatrix(data = cbind(X, W = 1))
d0_xgb <- xgb.DMatrix(data = cbind(X, W = 0))
tau_s_xgb <- predict(xgb_fit, d1_xgb) - predict(xgb_fit, d0_xgb)

cat("S-Learner (XGBoost):\n")
cat("  Mean CATE:", round(mean(tau_s_xgb), 4), "\n")
cat("  SD CATE:  ", round(sd(tau_s_xgb), 4), "\n")
cat("  RMSE vs truth:", round(sqrt(mean((tau_s_xgb - tau_true)^2)), 4), "\n")

3.2 T-Learner (Two-Learner)

The T-Learner fits separate outcome models for treated and control units:

$\hat{\mu}_1(x) = \mathbb{E}[Y \mid X = x, W = 1], \quad \hat{\mu}_0(x) = \mathbb{E}[Y \mid X = x, W = 0], \qquad \hat{\tau}(x) = \hat{\mu}_1(x) - \hat{\mu}_0(x)$

Pros: No regularization bias on treatment. Cons: Can overfit when one arm is small; ignores information sharing across arms.

# T-Learner with linear models
X_df <- as.data.frame(X)

# Treated model
fit1 <- lm(Y ~ ., data = cbind(X_df, Y = Y)[W == 1, ])
# Control model
fit0 <- lm(Y ~ ., data = cbind(X_df, Y = Y)[W == 0, ])

# Predict on full data
tau_t <- predict(fit1, newdata = X_df) - predict(fit0, newdata = X_df)

cat("T-Learner (linear):\n")
cat("  Mean CATE:", round(mean(tau_t), 4), "\n")
cat("  SD CATE:  ", round(sd(tau_t), 4), "\n")
cat("  RMSE vs truth:", round(sqrt(mean((tau_t - tau_true)^2)), 4), "\n")
# T-Learner with XGBoost
set.seed(42)

# Treated model
xgb_fit1 <- xgb.train(
  params  = xgb_params,
  data    = xgb.DMatrix(data = X[W == 1, ], label = Y[W == 1]),
  nrounds = 200,
  verbose = 0
)
# Control model
xgb_fit0 <- xgb.train(
  params  = xgb_params,
  data    = xgb.DMatrix(data = X[W == 0, ], label = Y[W == 0]),
  nrounds = 200,
  verbose = 0
)

dX <- xgb.DMatrix(data = X)
tau_t_xgb <- predict(xgb_fit1, dX) - predict(xgb_fit0, dX)

cat("T-Learner (XGBoost):\n")
cat("  Mean CATE:", round(mean(tau_t_xgb), 4), "\n")
cat("  SD CATE:  ", round(sd(tau_t_xgb), 4), "\n")
cat("  RMSE vs truth:", round(sqrt(mean((tau_t_xgb - tau_true)^2)), 4), "\n")

3.3 X-Learner (Künzel et al. 2019)

The X-Learner is a two-stage procedure designed to improve on the T-Learner when treatment arms are imbalanced:

Stage 1: Fit T-Learner models $\hat{\mu}_1$ and $\hat{\mu}_0$.

Stage 2: Compute imputed treatment effects: $\tilde{D}_i^1 = Y_i(1) - \hat{\mu}_0(X_i)$ for treated units and $\tilde{D}_i^0 = \hat{\mu}_1(X_i) - Y_i(0)$ for control units.

Fit regression models $\hat{\tau}^1$ and $\hat{\tau}^0$ for $\tilde{D}^1$ and $\tilde{D}^0$, then combine: $\hat{\tau}(x) = g(x)\,\hat{\tau}^1(x) + (1 - g(x))\,\hat{\tau}^0(x)$, where $g(x)$ is typically the propensity score $\hat{e}(x)$.

# X-Learner implementation using linear models
# Stage 1: T-Learner (reuse fit1, fit0 from above)
mu1_hat <- predict(fit1, newdata = X_df)
mu0_hat <- predict(fit0, newdata = X_df)

# Stage 2: Imputed effects
D_tilde_1 <- Y[W == 1] - mu0_hat[W == 1]
D_tilde_0 <- mu1_hat[W == 0] - Y[W == 0]

# Fit stage-2 models
tau_fit1 <- lm(D_tilde_1 ~ ., data = X_df[W == 1, ])
tau_fit0 <- lm(D_tilde_0 ~ ., data = X_df[W == 0, ])

tau1_hat <- predict(tau_fit1, newdata = X_df)
tau0_hat <- predict(tau_fit0, newdata = X_df)

# Propensity score as weighting function g(x)
ps_fit    <- glm(W ~ ., data = cbind(X_df, W = W), family = binomial())
e_hat     <- predict(ps_fit, type = "response")

# Combine
tau_x <- e_hat * tau1_hat + (1 - e_hat) * tau0_hat

cat("X-Learner (linear):\n")
cat("  Mean CATE:", round(mean(tau_x), 4), "\n")
cat("  SD CATE:  ", round(sd(tau_x), 4), "\n")
cat("  RMSE vs truth:", round(sqrt(mean((tau_x - tau_true)^2)), 4), "\n")

3.4 R-Learner (Robinson 1988 / Nie & Wager 2021)

The R-Learner uses Robinson’s (1988) partialling-out decomposition. Define residuals: $\tilde{Y}_i = Y_i - \hat{m}(X_i), \quad \tilde{W}_i = W_i - \hat{e}(X_i)$

where $\hat{m}(x) = \mathbb{E}[Y \mid X = x]$ and $\hat{e}(x) = \mathbb{P}(W = 1 \mid X = x)$ are cross-fit nuisance functions.

The R-Learner minimizes the R-loss: $\hat{\tau} = \arg\min_\tau \frac{1}{n} \sum_{i=1}^n \left[\tilde{Y}_i - \tau(X_i)\,\tilde{W}_i\right]^2$

For a linear CATE $\tau(x) = x^\top \beta$, this is equivalent to regressing $\tilde{Y}_i$ on the scaled covariates $X_i \tilde{W}_i$, or to regressing the pseudo-outcome $\tilde{Y}_i / \tilde{W}_i$ on $X_i$ with weights $\tilde{W}_i^2$.
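The equivalence of the two formulations can be verified numerically on toy residuals (a sketch with stand-in residual vectors, not the cross-fitted ones used below):

```r
set.seed(1)
n  <- 1000
x  <- rnorm(n)
wt <- rbinom(n, 1, 0.5) - 0.5 + rnorm(n, sd = 0.1)  # stand-in W residuals
yt <- x * wt + rnorm(n, sd = 0.1)                   # stand-in Y residuals

f1 <- lm(yt ~ I(x * wt) - 1)                  # regress Y~ on tau(X) * W~
f2 <- lm(I(yt / wt) ~ x - 1, weights = wt^2)  # weighted pseudo-outcome form
all.equal(unname(coef(f1)), unname(coef(f2))) # the two slopes coincide
```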

library(glmnet)

# Cross-fitting with K=5 folds
K <- 5
set.seed(42)
folds <- sample(rep(1:K, length.out = n))

m_hat <- numeric(n)   # E[Y|X] cross-fitted
e_hat_r <- numeric(n) # E[W|X] cross-fitted

for (k in 1:K) {
  train_idx <- which(folds != k)
  test_idx  <- which(folds == k)

  # Outcome model: lasso on X
  cv_m <- cv.glmnet(X[train_idx, ], Y[train_idx], alpha = 1, nfolds = 5)
  m_hat[test_idx] <- predict(cv_m, X[test_idx, ], s = "lambda.min")

  # Propensity model: lasso logistic on X
  cv_e <- cv.glmnet(X[train_idx, ], W[train_idx], family = "binomial",
                    alpha = 1, nfolds = 5)
  e_hat_r[test_idx] <- predict(cv_e, X[test_idx, ], s = "lambda.min",
                                type = "response")
}

# Clip propensity scores
e_hat_r <- pmax(pmin(e_hat_r, 0.99), 0.01)

# R-residuals
Y_tilde <- Y - m_hat
W_tilde <- W - e_hat_r

# R-learner: weighted lasso
# Regress the pseudo-outcome Y_tilde / W_tilde on X with weights W_tilde^2;
# this minimizes the R-loss for a linear tau(x) and avoids dividing
# predictions by near-zero treatment residuals
pseudo_Y <- Y_tilde / W_tilde
wts      <- W_tilde^2

cv_r  <- cv.glmnet(X, pseudo_Y, weights = wts, alpha = 1, nfolds = 5)
tau_r <- as.numeric(predict(cv_r, X, s = "lambda.min"))

cat("R-Learner (lasso cross-fit):\n")
cat("  Mean CATE:", round(mean(tau_r), 4), "\n")
cat("  SD CATE:  ", round(sd(tau_r), 4), "\n")
cat("  RMSE vs truth:", round(sqrt(mean((tau_r - tau_true)^2)), 4), "\n")

3.5 Meta-Learner Comparison

# Compile RMSE comparison
rmse_df <- data.frame(
  Method = c("S-Learner (OLS)", "S-Learner (XGB)",
             "T-Learner (OLS)", "T-Learner (XGB)",
             "X-Learner (OLS)", "R-Learner (lasso)"),
  RMSE = c(
    sqrt(mean((tau_s - tau_true)^2)),
    sqrt(mean((tau_s_xgb - tau_true)^2)),
    sqrt(mean((tau_t - tau_true)^2)),
    sqrt(mean((tau_t_xgb - tau_true)^2)),
    sqrt(mean((tau_x - tau_true)^2)),
    sqrt(mean((tau_r - tau_true)^2))
  ),
  Bias = c(
    mean(tau_s - tau_true),
    mean(tau_s_xgb - tau_true),
    mean(tau_t - tau_true),
    mean(tau_t_xgb - tau_true),
    mean(tau_x - tau_true),
    mean(tau_r - tau_true)
  )
)
rmse_df$RMSE <- round(rmse_df$RMSE, 4)
rmse_df$Bias <- round(rmse_df$Bias, 4)

print(rmse_df)

# Plot comparison
ggplot(rmse_df, aes(x = reorder(Method, RMSE), y = RMSE, fill = Method)) +
  geom_col(alpha = 0.8, show.legend = FALSE) +
  geom_text(aes(label = round(RMSE, 3)), hjust = -0.1, size = 3.5) +
  coord_flip() +
  scale_fill_manual(values = c(
    "#4393C3", "#2166AC", "#74ADD1", "#ABD9E9", "#D73027", "#F46D43"
  )) +
  labs(
    x     = NULL,
    y     = "RMSE (vs. True CATE)",
    title = "Meta-Learner Comparison: Oracle RMSE",
    subtitle = "Lower RMSE = better CATE recovery"
  ) +
  theme_minimal(base_size = 12)

4. Causal Forests (GRF)

Causal forests (Wager & Athey 2018; Athey et al. 2019) extend random forests to estimate CATE by:

  1. Adaptive nearest neighbors: Each tree partition defines a local neighborhood around each target point.
  2. Honest splitting: Uses separate data for splitting and estimation within each leaf (reduces overfitting).
  3. Double-robustness: Uses cross-fitted nuisance estimates m̂(x)\hat{m}(x) and ê(x)\hat{e}(x).
  4. Asymptotic normality: Provides pointwise confidence intervals via the infinitesimal jackknife.

The key estimating equation at a target point $x$ is: $\hat{\tau}(x) = \arg\min_\theta \sum_{i=1}^n \alpha_i(x) \left[(Y_i - \hat{m}(X_i)) - \theta\,(W_i - \hat{e}(X_i))\right]^2$

where $\alpha_i(x)$ are forest weights (the fraction of trees in which unit $i$ falls in the same leaf as $x$).
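The forest weights $\alpha_i(x)$ can be inspected directly via grf’s `get_forest_weights()`. A small sketch on toy data (a deliberately small forest for speed; not the forest fit below):

```r
library(grf)

set.seed(1)
n  <- 500
Xs <- matrix(rnorm(n * 3), n, 3)
Ws <- rbinom(n, 1, 0.5)
Ys <- Xs[, 1] * Ws + rnorm(n)

cf_small <- causal_forest(Xs, Ys, Ws, num.trees = 200)
alpha <- get_forest_weights(cf_small, newdata = Xs[1, , drop = FALSE])
sum(alpha)   # the weights for a target point sum to 1
```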

4.1 Fitting a Causal Forest

library(grf)

set.seed(42)
cf <- causal_forest(
  X         = X,
  Y         = Y,
  W         = W,
  num.trees = 2000,
  tune.parameters = "all"
)

cat("Causal forest fitted with", cf$`_num_trees`, "trees.\n")
cat("Tuned min.node.size:", cf$tunable.params$min.node.size, "\n")
cat("Tuned sample.fraction:", cf$tunable.params$sample.fraction, "\n")

4.2 ATE Estimation

ate_grf <- average_treatment_effect(cf, target.sample = "all")
att_grf <- average_treatment_effect(cf, target.sample = "treated")

cat("GRF ATE estimate:", round(ate_grf["estimate"], 4),
    "| SE:", round(ate_grf["std.err"], 4), "\n")
cat("GRF ATT estimate:", round(att_grf["estimate"], 4),
    "| SE:", round(att_grf["std.err"], 4), "\n")
cat("True ATE:", round(mean(tau_true), 4), "\n")

4.3 CATE Predictions with Confidence Intervals

# Out-of-bag predictions (honest, avoids overfitting)
preds <- predict(cf, estimate.variance = TRUE)
tau_cf  <- preds$predictions
tau_var <- preds$variance.estimates
tau_se  <- sqrt(tau_var)

cat("CATE distribution (causal forest):\n")
print(round(quantile(tau_cf, c(0, 0.1, 0.25, 0.5, 0.75, 0.9, 1)), 4))
cat("Oracle RMSE (vs true tau):", round(sqrt(mean((tau_cf - tau_true)^2)), 4), "\n")

4.4 Calibration Test

The calibration test (Chernozhukov et al. 2020) checks whether the CATE predictions are informative:

cal <- test_calibration(cf)
print(cal)

A significant coefficient on differential.forest.prediction (the BLP coefficient $\beta_2$) indicates meaningful heterogeneity captured by the forest. A coefficient close to 1 on mean.forest.prediction indicates good calibration.

4.5 Variable Importance

varimp <- variable_importance(cf)
varimp_df <- data.frame(
  feature    = colnames(X),
  importance = as.numeric(varimp)
)
varimp_df <- varimp_df[order(varimp_df$importance, decreasing = TRUE), ]

ggplot(varimp_df, aes(x = reorder(feature, importance), y = importance)) +
  geom_col(fill = "#4393C3", alpha = 0.8) +
  coord_flip() +
  labs(
    x     = NULL,
    y     = "Variable Importance",
    title = "Causal Forest: Variable Importance",
    subtitle = "X1 drives heterogeneity by construction"
  ) +
  theme_minimal(base_size = 12)

4.6 CATE Distribution

cate_df <- data.frame(
  tau_cf   = tau_cf,
  tau_true = tau_true,
  tau_se   = tau_se
)

ggplot(cate_df, aes(x = tau_cf)) +
  geom_histogram(bins = 50, fill = "#4393C3", alpha = 0.7, color = "white") +
  geom_vline(xintercept = ate_grf["estimate"], color = "#D73027", linewidth = 1.2) +
  geom_vline(xintercept = 0, linetype = "dashed") +
  labs(
    x     = "Estimated CATE (Causal Forest)",
    y     = "Count",
    title = "Distribution of Estimated CATE",
    subtitle = sprintf("ATE = %.3f (red); %.1f%% positive, %.1f%% negative",
                       ate_grf["estimate"],
                       100 * mean(tau_cf > 0),
                       100 * mean(tau_cf < 0))
  ) +
  theme_minimal(base_size = 12)

4.7 CATE vs. True Treatment Effect

ggplot(cate_df, aes(x = tau_true, y = tau_cf)) +
  geom_point(alpha = 0.15, size = 0.8, color = "#2166AC") +
  geom_smooth(method = "lm", color = "#D73027", se = FALSE) +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  labs(
    x     = "True CATE (oracle)",
    y     = "Estimated CATE (Causal Forest)",
    title = "CATE Estimates vs. True Values",
    subtitle = sprintf("Correlation: %.3f | RMSE: %.3f",
                       cor(tau_cf, tau_true),
                       sqrt(mean((tau_cf - tau_true)^2)))
  ) +
  theme_minimal(base_size = 12)

4.8 Comprehensive Summary via causal_forest_summary()

cf_summary <- causal_forest_summary(
  cf            = cf,
  X             = X,
  feature_names = colnames(X),
  n_groups      = 5,
  top_features  = 10,
  plot          = TRUE
)

# ATE table
print(cf_summary$ate)

# CATE distribution summary
print(round(cf_summary$cate_summary, 4))
cf_summary$plot_cate_dist
cf_summary$plot_importance

5. BLP Analysis (Best Linear Predictor)

The Best Linear Predictor (BLP) framework (Chernozhukov et al. 2020) provides a model-free test for treatment effect heterogeneity and a summary of its structure.

5.1 The BLP Regression

Starting from the partially linear model: $Y_i - \hat{m}(X_i) = \beta_1 (W_i - \hat{e}(X_i)) + \beta_2 (\hat{\tau}(X_i) - \overline{\hat{\tau}})(W_i - \hat{e}(X_i)) + \varepsilon_i$

  • $\beta_1 \approx$ ATE: overall average treatment effect.
  • $\beta_2 > 0$: CATE predictions correlate with actual treatment effects (HTE is real and captured by $\hat{\tau}$).
  • $\beta_2 = 1$: CATE predictions are correctly scaled (well calibrated).
  • $\beta_2 = 0$: no evidence of heterogeneity beyond the ATE.
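To demystify the regression, here is a by-hand BLP fit on a toy RCT (a sketch assuming a known propensity of 0.5 and a noisy CATE proxy; the `blp_analysis()` call below handles nuisance estimation properly):

```r
set.seed(1)
n <- 4000
x <- rnorm(n)
w <- rbinom(n, 1, 0.5)
y <- x * w + rnorm(n)                 # true tau(x) = x, ATE = 0
tau_hat <- x + rnorm(n, sd = 0.1)     # imperfect CATE predictions

w_res <- w - 0.5                      # known propensity in an RCT
y_res <- y - 0.5 * x                  # E[Y | X] = e * tau(x) = 0.5 x
blp <- lm(y_res ~ w_res + I((tau_hat - mean(tau_hat)) * w_res) - 1)
round(coef(blp), 2)                   # beta_1 near 0 (ATE), beta_2 near 1
```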

5.2 Running BLP Analysis

blp_res <- blp_analysis(
  cate_hat  = tau_cf,
  Y         = Y,
  W         = W,
  Y_hat     = cf$Y.hat,
  W_hat     = cf$W.hat,
  run_gates = TRUE,
  n_groups  = 5
)

cat("BLP Coefficients:\n")
print(blp_res$blp)

A significant $\beta_2$ confirms that the causal forest captures genuine heterogeneity. The estimate should be close to 1 for a well-calibrated CATE estimator.

5.3 BLP Plot

blp_res$blp_plot

5.4 GATES: Group Average Treatment Effects

GATES summarizes heterogeneity by ranking units into quantile groups by estimated CATE and computing the average treatment effect within each group:

cat("GATES Results:\n")
gates_df <- blp_res$gates
# Only round numeric columns (group is character)
num_cols <- sapply(gates_df, is.numeric)
gates_df[num_cols] <- round(gates_df[num_cols], 4)
print(gates_df)
blp_res$gates_plot

The GATES plot should show a clear monotone pattern from low-CATE to high-CATE groups if the forest is well-calibrated.
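The same quantity can be computed by hand: bin units by the CATE proxy and take the difference in means inside each bin. A toy sketch (hypothetical data; the `blp_analysis()` call above computes GATES from the fitted objects):

```r
set.seed(1)
n <- 5000
x <- rnorm(n)
w <- rbinom(n, 1, 0.5)
y <- x * w + rnorm(n)                 # true tau(x) = x
tau_hat <- x + rnorm(n, sd = 0.1)     # noisy CATE proxy

g <- cut(tau_hat, quantile(tau_hat, 0:5 / 5),
         include.lowest = TRUE, labels = paste0("G", 1:5))
gates <- tapply(seq_len(n), g, function(i)
  mean(y[i][w[i] == 1]) - mean(y[i][w[i] == 0]))
round(gates, 2)   # effects rise from G1 to G5
```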

5.5 GATES Without GRF (Using OLS)

# GATES using any CATE estimator: use tau_t (T-Learner OLS)
blp_ols <- blp_analysis(
  cate_hat  = tau_t,
  Y         = Y,
  W         = W,
  Y_hat     = fitted(lm(Y ~ ., data = cbind(X_df, Y = Y))),
  W_hat     = rep(mean(W), n),
  run_gates = TRUE,
  n_groups  = 5
)

cat("BLP for T-Learner OLS:\n")
print(blp_ols$blp)

6. Qini Curve & Policy Evaluation

The Qini curve (Radcliffe 2007) evaluates the quality of CATE estimates for targeting: which fraction of the population should we treat to maximize aggregate benefit?

6.1 The Qini Curve

Rank units by estimated CATE (highest first). For a treated fraction $f$, the Qini curve plots the cumulative uplift: $\text{Uplift}(f) = \left(\bar{Y}_1^{(f)} - \bar{Y}_0^{(f)}\right) \cdot f$

where $\bar{Y}_1^{(f)}$ and $\bar{Y}_0^{(f)}$ are the mean outcomes of treated and control units within the top fraction $f$, i.e., units with $\hat{\tau}(X_i)$ at or above the $(1-f)$ quantile of $\hat{\tau}(X)$.

AUUC (Area Under the Uplift Curve) compares targeting strategies:

  • Random targeting: uniform baseline.
  • CATE-based targeting: our estimator.
  • Oracle targeting: sort by true $\tau_i$ (requires knowing true effects).

6.2 Computing the Qini Curve

qini_res <- qini_curve(
  cate_hat       = tau_cf,
  Y              = Y,
  W              = W,
  compare_random = TRUE,
  n_bins         = 100,
  boot_reps      = 500,
  seed           = 42
)

cat("AUUC:", round(qini_res$auuc, 4), "\n")
cat("AUUC 95% CI: [", round(qini_res$auuc_ci[1], 4),
    ",", round(qini_res$auuc_ci[2], 4), "]\n")
qini_res$plot +
  labs(title = "Qini Curve: Causal Forest CATE Targeting")

6.3 Comparing Qini Curves Across Estimators

# Compute Qini curves for each estimator
compute_qini_simple <- function(tau_hat, Y, W, n_bins = 100) {
  n   <- length(Y)
  ord <- order(tau_hat, decreasing = TRUE)
  fracs <- seq(0, 1, length.out = n_bins + 1)
  uplifts <- numeric(n_bins + 1)
  for (k in seq_len(n_bins + 1)) {
    if (k == 1) { uplifts[k] <- 0; next }
    n_top <- max(1, round(fracs[k] * n))
    top   <- ord[seq_len(n_top)]
    m1 <- if (sum(W[top] == 1) > 0) mean(Y[top][W[top] == 1]) else 0
    m0 <- if (sum(W[top] == 0) > 0) mean(Y[top][W[top] == 0]) else 0
    uplifts[k] <- (m1 - m0) * fracs[k]
  }
  data.frame(fraction = fracs, uplift = uplifts)
}

q_cf  <- compute_qini_simple(tau_cf, Y, W)
q_t   <- compute_qini_simple(tau_t, Y, W)
q_s   <- compute_qini_simple(tau_s, Y, W)
q_ora <- compute_qini_simple(tau_true, Y, W)  # Oracle

# Combine
q_all <- rbind(
  data.frame(q_cf,  Method = "Causal Forest"),
  data.frame(q_t,   Method = "T-Learner (OLS)"),
  data.frame(q_s,   Method = "S-Learner (OLS)"),
  data.frame(q_ora, Method = "Oracle (true tau)")
)

# Random baseline
overall_tau <- mean(Y[W == 1]) - mean(Y[W == 0])
q_rand <- data.frame(
  fraction = seq(0, 1, length.out = 101),
  uplift   = seq(0, 1, length.out = 101) * overall_tau,
  Method   = "Random Targeting"
)
q_all <- rbind(q_all, q_rand)

ggplot(q_all, aes(x = fraction, y = uplift, color = Method, linetype = Method)) +
  geom_line(linewidth = 0.9) +
  scale_color_manual(values = c(
    "Causal Forest"    = "#2166AC",
    "T-Learner (OLS)"  = "#4DAC26",
    "S-Learner (OLS)"  = "#E66101",
    "Oracle (true tau)" = "#762A83",
    "Random Targeting" = "gray50"
  )) +
  scale_linetype_manual(values = c(
    "Causal Forest"    = "solid",
    "T-Learner (OLS)"  = "solid",
    "S-Learner (OLS)"  = "solid",
    "Oracle (true tau)" = "dotted",
    "Random Targeting" = "dashed"
  )) +
  labs(
    x     = "Fraction Treated (Ranked by Estimated CATE)",
    y     = "Cumulative Uplift",
    title = "Qini Curve Comparison Across Estimators",
    color = NULL, linetype = NULL
  ) +
  theme_minimal(base_size = 12) +
  theme(legend.position = "bottom")

6.4 Optimal Treatment Targeting

If treatment has a per-unit cost $c$, the optimal targeting rule is: treat unit $i \iff \hat{\tau}(X_i) \geq c$

# Policy evaluation at different cost thresholds
cost_thresholds <- c(0, 0.25, 0.5, 0.75, 1.0)
policy_values <- sapply(cost_thresholds, function(c_val) {
  treated_by_policy <- tau_cf >= c_val
  # Policy value = avg tau for those targeted
  if (sum(treated_by_policy) == 0) return(0)
  mean(tau_true[treated_by_policy])  # Oracle value
})

policy_df <- data.frame(
  Cost          = cost_thresholds,
  Pct_Treat     = sapply(cost_thresholds, function(c) mean(tau_cf >= c)),
  Avg_True_CATE = round(policy_values, 4)  # oracle: mean true tau among targeted
)
print(policy_df)

6.5 Policy Value via AIPW

# Estimate policy value using doubly-robust scores
evaluate_policy <- function(policy, Y, W, Y_hat, W_hat, tau_hat) {
  # AIPW subgroup ATE for the units the policy selects, using the
  # cross-fitted marginals Y_hat = E[Y|X] and W_hat = E[W|X]
  idx <- which(policy == 1)
  if (length(idx) == 0) return(NA_real_)

  # Doubly robust score:
  # psi_i = tau_hat_i + (W_i - e_i) / (e_i * (1 - e_i)) *
  #         (Y_i - m_i - (W_i - e_i) * tau_hat_i)
  W_res <- W[idx] - W_hat[idx]
  resid <- Y[idx] - Y_hat[idx] - W_res * tau_hat[idx]
  psi   <- tau_hat[idx] + W_res / (W_hat[idx] * (1 - W_hat[idx])) * resid
  mean(psi)
}

cat("Policy: treat top 50% by CATE\n")
top50 <- as.integer(tau_cf >= median(tau_cf))
cat("Fraction treated:", mean(top50), "\n")
cat("AIPW policy value (subgroup ATE for targeted units):",
    round(evaluate_policy(top50, Y, W, cf$Y.hat, cf$W.hat, tau_cf), 4), "\n")
cat("Oracle policy value (avg true tau for targeted):",
    round(mean(tau_true[top50 == 1]), 4), "\n")
cat("Vs. ATE (treat all):", round(mean(tau_true), 4), "\n")

7. HTE Visualization Suite

7.1 Forest Plot by Subgroups

# Create pre-specified subgroups based on X1 and X2
df_plot <- data.frame(X,
  cate   = tau_cf,
  W      = W,
  Y      = Y
)
df_plot$X1_group <- ifelse(X[, 1] > 0, "High X1", "Low X1")
df_plot$X2_group <- ifelse(X[, 2] > 0, "High X2", "Low X2")
df_plot$Subgroup <- paste(df_plot$X1_group, df_plot$X2_group, sep = " / ")

p_forest <- heterogeneity_plot(
  data           = df_plot,
  cate_var       = "cate",
  subgroup_var   = "Subgroup",
  overall_effect = ate_grf["estimate"],
  plot_type      = "forest",
  title          = "HTE by Pre-Specified Subgroups"
)
p_forest

7.2 CATE vs. Continuous Moderator

p_scatter <- heterogeneity_plot(
  data          = df_plot,
  cate_var      = "cate",
  moderator_var = "X1",
  plot_type     = "scatter",
  title         = "CATE vs. Key Moderator X1"
)
p_scatter

The LOESS curve confirms the linear relationship between $X_1$ and the CATE, as specified by the data-generating process.

7.3 Violin Plot by Treatment Group

df_plot$Treatment <- factor(W, levels = c(0, 1), labels = c("Control", "Treated"))
p_violin <- heterogeneity_plot(
  data         = df_plot,
  cate_var     = "cate",
  subgroup_var = "Treatment",
  plot_type    = "violin",
  title        = "CATE Distribution by Treatment Group"
)
p_violin

7.4 Partial Dependence Plot

The partial dependence plot shows the marginal effect of $X_1$ on the CATE, averaging over the distribution of the other covariates.

# Partial dependence of CATE on X1: vary X1, hold others at observed values
x1_grid  <- seq(min(X[, 1]), max(X[, 1]), length.out = 50)
pdp_vals <- sapply(x1_grid, function(x1_val) {
  X_temp      <- X
  X_temp[, 1] <- x1_val
  mean(predict(cf, X_temp)$predictions)
})

pdp_df <- data.frame(X1 = x1_grid, CATE = pdp_vals)
ggplot(pdp_df, aes(x = X1, y = CATE)) +
  geom_line(color = "#2166AC", linewidth = 1.2) +
  geom_hline(yintercept = ate_grf["estimate"], linetype = "dashed", color = "gray50") +
  labs(
    x     = "X1 (Moderator)",
    y     = "Partial Average CATE",
    title = "Partial Dependence Plot: CATE vs. X1",
    subtitle = "Marginalizing over other covariates; dashed = ATE"
  ) +
  theme_minimal(base_size = 12)

7.5 Calibration Plot: Predicted vs. Actual CATE by Decile

# Bin by predicted CATE decile, compute realized treatment effect
calib_df <- data.frame(
  tau_cf   = tau_cf,
  tau_true = tau_true,
  W        = W,
  Y        = Y
)
calib_df$decile <- cut(tau_cf,
  breaks = quantile(tau_cf, seq(0, 1, 0.1)),
  include.lowest = TRUE, labels = paste0("D", 1:10)
)

library(dplyr)

calib_tab <- calib_df %>%
  group_by(decile) %>%
  summarise(
    pred_cate    = mean(tau_cf),
    true_cate    = mean(tau_true),
    realized_eff = mean(Y[W == 1]) - mean(Y[W == 0]),
    se_eff       = sqrt(var(Y[W == 1]) / sum(W == 1) +
                        var(Y[W == 0]) / sum(W == 0)),
    .groups = "drop"
  )

ggplot(calib_tab, aes(x = pred_cate, y = realized_eff)) +
  geom_point(size = 3, color = "#2166AC") +
  geom_errorbar(aes(ymin = realized_eff - 1.96 * se_eff,
                    ymax = realized_eff + 1.96 * se_eff),
                width = 0.05, color = "#2166AC", alpha = 0.5) +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  geom_line(aes(y = true_cate), color = "#D73027", linewidth = 0.8,
            linetype = "solid") +
  labs(
    x     = "Mean Predicted CATE (by decile)",
    y     = "Realized Treatment Effect",
    title = "CATE Calibration by Decile",
    subtitle = "Dashed: perfect calibration; red: true CATE by decile"
  ) +
  theme_minimal(base_size = 12)

8. Double Machine Learning (DML)

8.1 The Partially Linear Model

Robinson (1988) introduced the partially linear model for flexible treatment effect estimation: $Y_i = \theta_0 W_i + g_0(X_i) + \varepsilon_i, \quad \mathbb{E}[\varepsilon_i \mid X_i, W_i] = 0$

Key insight: partialling out $X$ from both $Y$ and $W$ recovers $\theta_0$ as the coefficient in $\tilde{Y}_i = \theta_0 \tilde{W}_i + \varepsilon_i$, where $\tilde{Y}_i = Y_i - \mathbb{E}[Y_i \mid X_i]$ and $\tilde{W}_i = W_i - \mathbb{E}[W_i \mid X_i]$.

Chernozhukov et al. (2018) showed that cross-fitting (using separate data folds to estimate nuisance functions) enables $\sqrt{n}$-consistent inference on $\theta_0$ even when the nuisance models converge at slower rates.

8.2 DML Implementation with Cross-Fitting

library(glmnet)

# DML with K=5 cross-fitting
set.seed(42)
K   <- 5
n   <- nrow(X)
folds <- sample(rep(1:K, length.out = n))

m_hat_dml <- numeric(n)
e_hat_dml <- numeric(n)

for (k in 1:K) {
  tr  <- which(folds != k)
  te  <- which(folds == k)

  # Outcome model: lasso
  cv_m <- cv.glmnet(X[tr, ], Y[tr], alpha = 1)
  m_hat_dml[te] <- predict(cv_m, X[te, ], s = "lambda.min")

  # Propensity: lasso logistic
  cv_e <- cv.glmnet(X[tr, ], W[tr], family = "binomial", alpha = 1)
  e_hat_dml[te] <- predict(cv_e, X[te, ], s = "lambda.min", type = "response")
}

# Clip propensity
e_hat_dml <- pmax(pmin(e_hat_dml, 0.99), 0.01)

# Residualize
Y_res_dml <- Y - m_hat_dml
W_res_dml <- W - e_hat_dml

# DML estimator: OLS of Y_res on W_res
dml_mod  <- lm(Y_res_dml ~ W_res_dml - 1)
theta_dml <- coef(dml_mod)["W_res_dml"]
se_dml    <- sqrt(vcov(dml_mod)["W_res_dml", "W_res_dml"])

cat("DML ATE estimate:", round(theta_dml, 4), "\n")
cat("DML SE:          ", round(se_dml, 4), "\n")
cat("95% CI: [", round(theta_dml - 1.96 * se_dml, 4),
    ",", round(theta_dml + 1.96 * se_dml, 4), "]\n")
cat("True ATE:", round(mean(tau_true), 4), "\n")

8.3 Heterogeneous DML via Interactions

To allow $\theta$ to vary with a key moderator $V$ (e.g., $X_1$), interact the moderator with the treatment residual:

$$\tilde{Y}_i = \theta_0 \tilde{W}_i + \theta_1 \tilde{W}_i \cdot V_i + \eta_i$$

# Moderator: X1 (centered)
V       <- X[, 1]
V_c     <- V - mean(V)

# Interaction design matrix
Z_int <- cbind(W_res_dml, W_res_dml * V_c)
colnames(Z_int) <- c("W_res", "W_res_x_V")

dml_het <- lm(Y_res_dml ~ Z_int - 1)
coef_het <- coef(dml_het)
se_het   <- sqrt(diag(vcov(dml_het)))

cat("Heterogeneous DML:\n")
cat("  theta_0 (ATE at mean V):", round(coef_het[1], 4),
    "| SE:", round(se_het[1], 4), "\n")
cat("  theta_1 (moderation by X1):", round(coef_het[2], 4),
    "| SE:", round(se_het[2], 4), "\n")
cat("  (True slope = 1)\n")

The coefficient $\theta_1$ estimates how the treatment effect changes per unit increase in $X_1$. Since $\tau(X_i) = X_{i1}$, the true slope is 1.
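The implied CATE at any moderator value, $\hat{\tau}(v) = \hat{\theta}_0 + \hat{\theta}_1 v$, and its standard error follow from the coefficient covariance matrix via the delta method. A self-contained sketch (simulated residualized data with true coefficients 0.5 and 1.0; in the vignette the same computation applies to dml_het):

```r
# Delta-method SE for tau_hat(v) = theta0 + theta1 * v
set.seed(1)
n_t   <- 2000
v_t   <- rnorm(n_t)                       # moderator
w_res <- rnorm(n_t)                       # residualized treatment
y_res <- (0.5 + 1.0 * v_t) * w_res + rnorm(n_t)

fit <- lm(y_res ~ w_res + w_res:v_t - 1)  # coefficients (theta0, theta1)
b   <- coef(fit)
V   <- vcov(fit)

tau_at <- function(v) {
  g   <- c(1, v)                          # gradient of theta0 + theta1 * v
  est <- sum(g * b)
  se  <- sqrt(drop(t(g) %*% V %*% g))
  c(estimate = est, se = se)
}
print(tau_at(1))   # CATE at v = 1; the true value here is 0.5 + 1.0 = 1.5
```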

8.4 Using the DoubleML Package

library(DoubleML)
library(mlr3)
library(mlr3learners)

# Setup DoubleML data object
dml_data <- DoubleMLData$new(
  data          = data.frame(X, Y = Y, W = W),
  y_col         = "Y",
  d_cols        = "W",
  x_cols        = paste0("X", 1:p)
)

# Partially linear model with lasso learners
set.seed(42)
learner_m <- lrn("regr.cv_glmnet", alpha = 1)
learner_r <- lrn("regr.cv_glmnet", alpha = 1)

dml_plr <- DoubleMLPLR$new(
  obj_dml_data  = dml_data,
  ml_l          = learner_m,
  ml_m          = learner_r,
  n_folds       = 5,
  score         = "partialling out"
)
dml_plr$fit()
print(dml_plr$summary())

8.5 Interactive Regression Model (IRM) for CATE

# IRM: allows fully flexible treatment effect heterogeneity
learner_g <- lrn("regr.cv_glmnet", alpha = 1)
learner_m2 <- lrn("classif.cv_glmnet", alpha = 1)

dml_irm <- DoubleMLIRM$new(
  obj_dml_data = dml_data,
  ml_g          = learner_g,
  ml_m          = learner_m2,
  n_folds       = 5,
  score         = "ATE"
)
dml_irm$fit()
print(dml_irm$summary())

9. AIPW / Augmented IPW for HTE

9.1 The AIPW Score

The AIPW (Augmented Inverse Probability Weighting) estimator for the ATE combines outcome regression with IPW to achieve double robustness:

$$\hat{\tau}_{AIPW} = \frac{1}{n} \sum_{i=1}^n \left[\hat{\mu}_1(X_i) - \hat{\mu}_0(X_i) + \frac{W_i(Y_i - \hat{\mu}_1(X_i))}{\hat{e}(X_i)} - \frac{(1-W_i)(Y_i - \hat{\mu}_0(X_i))}{1 - \hat{e}(X_i)}\right]$$

This estimator is doubly robust: it is consistent if either $\hat{\mu}$ or $\hat{e}$ is correctly specified.
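Double robustness can be checked directly in simulation: misspecify the outcome model while keeping the propensity score correct, and the AIPW estimate stays near the truth while the regression-only estimate drifts. A minimal sketch (simulated data with true ATE = 1; all object names are illustrative):

```r
# Double robustness check: wrong outcome model, correct propensity
set.seed(1)
n_dr <- 20000
x_dr <- rnorm(n_dr, mean = 1)
e_tr <- plogis(x_dr)                          # true propensity score
w_dr <- rbinom(n_dr, 1, e_tr)
y_dr <- w_dr + x_dr^2 + rnorm(n_dr)           # true ATE = 1, nonlinear baseline

# Outcome models misspecified as linear in x (true baseline is x^2)
d_dr    <- data.frame(x_dr = x_dr)
mu1_bad <- predict(lm(y_dr ~ x_dr, subset = w_dr == 1), newdata = d_dr)
mu0_bad <- predict(lm(y_dr ~ x_dr, subset = w_dr == 0), newdata = d_dr)

reg_only <- mean(mu1_bad - mu0_bad)           # biased: wrong outcome model
aipw     <- mean(mu1_bad - mu0_bad +
                 w_dr * (y_dr - mu1_bad) / e_tr -
                 (1 - w_dr) * (y_dr - mu0_bad) / (1 - e_tr))

cat("Regression-only:", round(reg_only, 3),
    "| AIPW:", round(aipw, 3), "| truth: 1\n")
```

The correct propensity score rescues the AIPW estimate; swapping the misspecification (correct outcome model, wrong propensity) works symmetrically.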

9.2 AIPW ATE Estimation via dr_ate()

if (requireNamespace("MatchIt", quietly = TRUE)) {
  data(lalonde, package = "MatchIt")

  aipw_res <- dr_ate(
    data       = lalonde,
    outcome    = "re78",
    treatment  = "treat",
    covariates = c("age", "educ", "married", "nodegree", "re74", "re75"),
    estimand   = "ATE",
    ps_trim    = c(0.01, 0.99)
  )

  cat("AIPW ATE for Lalonde:\n")
  cat("  Estimate:", round(aipw_res$estimate, 2), "\n")
  cat("  SE:      ", round(aipw_res$se, 2), "\n")
  cat("  95% CI: [", round(aipw_res$ci_lower, 2), ",",
      round(aipw_res$ci_upper, 2), "]\n")
  cat("  p-value:", round(aipw_res$p_value, 4), "\n")
  cat("  n trimmed:", aipw_res$n_trimmed, "\n")
}

9.3 DR-Learner for CATE

The DR-Learner (Kennedy 2020) uses AIPW pseudo-outcomes as targets for CATE estimation:

$$\tilde{\tau}_i^{DR} = \hat{\mu}_1(X_i) - \hat{\mu}_0(X_i) + \frac{W_i(Y_i - \hat{\mu}_1(X_i))}{\hat{e}(X_i)} - \frac{(1-W_i)(Y_i - \hat{\mu}_0(X_i))}{1 - \hat{e}(X_i)}$$

Note that $\mathbb{E}[\tilde{\tau}_i^{DR} \mid X_i] = \tau(X_i)$ under correct specification. The DR-Learner regresses these pseudo-outcomes on $X$ to obtain $\hat{\tau}(x)$.

# Compute AIPW pseudo-outcomes using cross-fitted nuisances
# Reuse m_hat_dml and e_hat_dml from Section 8

# Separate mu1 and mu0 via cross-fitting
mu1_hat_dr <- numeric(n)
mu0_hat_dr <- numeric(n)

set.seed(42)
for (k in 1:K) {
  tr  <- which(folds != k)
  te  <- which(folds == k)

  # Arm-specific outcome models, fit on training folds only
  X_df_k <- as.data.frame(X)
  tr1 <- tr[W[tr] == 1]
  tr0 <- tr[W[tr] == 0]
  lm1 <- lm(Y ~ ., data = data.frame(Y = Y[tr1], X_df_k[tr1, , drop = FALSE]))
  lm0 <- lm(Y ~ ., data = data.frame(Y = Y[tr0], X_df_k[tr0, , drop = FALSE]))

  mu1_hat_dr[te] <- predict(lm1, newdata = X_df_k[te, ])
  mu0_hat_dr[te] <- predict(lm0, newdata = X_df_k[te, ])
}

# AIPW pseudo-outcomes
psi_dr <- mu1_hat_dr - mu0_hat_dr +
  W * (Y - mu1_hat_dr) / e_hat_dml -
  (1 - W) * (Y - mu0_hat_dr) / (1 - e_hat_dml)

cat("Mean DR pseudo-outcome (ATE proxy):", round(mean(psi_dr), 4), "\n")
cat("True ATE:", round(mean(tau_true), 4), "\n")

# DR-Learner: regress pseudo-outcomes on X
cv_dr <- cv.glmnet(X, psi_dr, alpha = 1)
tau_dr <- predict(cv_dr, X, s = "lambda.min")[, 1]

cat("DR-Learner RMSE:", round(sqrt(mean((tau_dr - tau_true)^2)), 4), "\n")

9.4 Efficiency Comparison: IPW vs. AIPW

# Naive IPW estimate
ipw_ate <- mean(W * Y / e_hat_dml) - mean((1 - W) * Y / (1 - e_hat_dml))

# AIPW estimate
aipw_ate <- mean(psi_dr)

# OLS estimate
ols_ate <- coef(lm(Y ~ W + X))[["W"]]

cat("Comparison of ATE estimators:\n")
cat(sprintf("  OLS:   %.4f\n", ols_ate))
cat(sprintf("  IPW:   %.4f\n", ipw_ate))
cat(sprintf("  AIPW:  %.4f\n", aipw_ate))
cat(sprintf("  True:  %.4f\n", mean(tau_true)))

# Variance comparison
se_ipw  <- sd(W * Y / e_hat_dml - (1 - W) * Y / (1 - e_hat_dml)) / sqrt(n)
se_aipw <- sd(psi_dr) / sqrt(n)
cat(sprintf("  SE IPW:  %.4f\n", se_ipw))
cat(sprintf("  SE AIPW: %.4f  (efficiency gain = %.1f%%)\n",
            se_aipw, 100 * (se_ipw - se_aipw) / se_ipw))

10. Moderated Treatment Effects

10.1 Treatment-Covariate Interaction Models

The simplest parametric approach to HTE is to include treatment-covariate interactions:

$$Y_i = \alpha + \gamma W_i + X_i'\beta + W_i \cdot X_i'\delta + \varepsilon_i$$

The vector $\delta$ captures how the treatment effect varies with $X$. This approach is interpretable but restrictive: it assumes linearity and the correct functional form.

# Fit interaction model
X_df_int <- as.data.frame(X)
colnames(X_df_int) <- paste0("X", 1:p)

# Main effects + treatment interactions
int_formula <- as.formula(
  paste("Y ~ W + (",
        paste(paste0("X", 1:p), collapse = " + "),
        ") + W:(",
        paste(paste0("X", 1:p), collapse = " + "),
        ")")
)
int_mod <- lm(int_formula, data = cbind(X_df_int, W = W, Y = Y))

# Extract interaction coefficients
int_coefs <- coef(int_mod)
int_ses   <- sqrt(diag(vcov(int_mod)))
int_names <- names(int_coefs)

# Focus on interaction terms
int_idx   <- grepl("W:X", int_names)
int_summary <- data.frame(
  Term     = int_names[int_idx],
  Estimate = round(int_coefs[int_idx], 4),
  SE       = round(int_ses[int_idx], 4),
  t_stat   = round(int_coefs[int_idx] / int_ses[int_idx], 3),
  p_value  = round(2 * pt(-abs(int_coefs[int_idx] / int_ses[int_idx]),
                           df = n - length(int_coefs)), 4)
)
print(int_summary)

The coefficient on the interaction term W:X1 should be close to 1 (the true slope, since $\tau(X_i) = X_{i1}$).

10.2 Omnibus Test for Moderation

An omnibus F-test compares the interaction model to the main-effects-only model:

main_mod <- lm(Y ~ W + ., data = cbind(X_df_int, W = W, Y = Y))
f_test   <- anova(main_mod, int_mod)
cat("Omnibus F-test for treatment-covariate interactions:\n")
print(f_test)
cat("Result: ", if (f_test[2, "Pr(>F)"] < 0.05) "Significant HTE detected."
              else "No significant HTE.", "\n")

10.3 Marginal Effects of Treatment at Different Covariate Values

# Manual marginal effects: E[tau | X1 = x1], averaging over X2, ..., X10
x1_vals  <- seq(-2.5, 2.5, length.out = 50)
me_vals  <- sapply(x1_vals, function(x1) {
  X_temp    <- X_df_int
  X_temp$X1 <- x1
  d1 <- cbind(X_temp, W = 1)
  d0 <- cbind(X_temp, W = 0)
  mean(predict(int_mod, d1) - predict(int_mod, d0))
})

me_df <- data.frame(X1 = x1_vals, AME = me_vals)
ggplot(me_df, aes(x = X1, y = AME)) +
  geom_line(color = "#2166AC", linewidth = 1.2) +
  geom_hline(yintercept = mean(tau_true), linetype = "dashed") +
  geom_ribbon(aes(ymin = AME - 0.15, ymax = AME + 0.15),
              alpha = 0.15, fill = "#2166AC") +
  labs(
    x     = "X1 (Moderator)",
    y     = "Average Marginal Effect of Treatment",
    title = "Marginal Treatment Effect by X1",
    subtitle = "Dashed = ATE; shaded = approximate ±95% CI"
  ) +
  theme_minimal(base_size = 12)

10.4 Marginal Effects with marginaleffects

library(marginaleffects)

# Average marginal effect of W at different X1 levels
ame_W <- avg_slopes(int_mod, variables = "W")
cat("Average Marginal Effect of W:\n")
print(ame_W)

# Marginal effect of W by X1 quantile
me_by_x1 <- slopes(
  int_mod,
  variables  = "W",
  newdata    = datagrid(X1 = quantile(X[, 1], c(0.1, 0.25, 0.5, 0.75, 0.9)))
)
cat("\nMarginal effect of W at X1 quantiles:\n")
print(me_by_x1[, c("X1", "estimate", "std.error", "conf.low", "conf.high")])

11. Subgroup Analysis Best Practices

11.1 Pre-Specified vs. Data-Driven Subgroups

Pre-specified subgroups (defined before seeing outcome data):

  • Control Type I error
  • More credible for publication
  • Require a pre-analysis plan (PAP)
  • May miss data-driven opportunities

Data-driven subgroups (from ML or recursive partitioning):

  • Can discover unexpected heterogeneity
  • Risk of overfitting and false discoveries
  • Require honest estimation (separate sample or honesty) and correction for multiple testing
  • Should be validated on held-out data

Best practice: Use pre-specified primary analyses; treat data-driven results as exploratory.

11.2 Multiple Testing Correction

# Simulate testing treatment effects in 10 pre-specified subgroups
set.seed(42)
n_subgroups <- 10
subgroup_pvals <- numeric(n_subgroups)

for (s in 1:n_subgroups) {
  # Define subgroup: quantile of X_s
  Xs       <- X[, s]
  in_group <- Xs > median(Xs)
  Y_g      <- Y[in_group]
  W_g      <- W[in_group]

  mod_g    <- lm(Y_g ~ W_g)
  subgroup_pvals[s] <- coef(summary(mod_g))["W_g", "Pr(>|t|)"]
}

# Multiple testing corrections
mt_df <- data.frame(
  Subgroup  = paste0("X", 1:n_subgroups, " > Median"),
  p_raw     = round(subgroup_pvals, 4),
  p_bonf    = round(p.adjust(subgroup_pvals, method = "bonferroni"), 4),
  p_bh      = round(p.adjust(subgroup_pvals, method = "BH"), 4),
  sig_raw   = subgroup_pvals < 0.05,
  sig_bonf  = p.adjust(subgroup_pvals, "bonferroni") < 0.05,
  sig_bh    = p.adjust(subgroup_pvals, "BH") < 0.05
)
print(mt_df)
cat("\nSignificant at 5%:\n")
cat("  Raw p-value: ", sum(mt_df$sig_raw), "subgroups\n")
cat("  Bonferroni:  ", sum(mt_df$sig_bonf), "subgroups\n")
cat("  BH-FDR:      ", sum(mt_df$sig_bh), "subgroups\n")

11.3 Causal Tree (Honest Partitioning)

The causal tree (Athey & Imbens 2016) uses honest estimation: the sample used to choose the partition is disjoint from the sample used to estimate effects within leaves. This removes the overfitting bias introduced by searching over splits.

# Use GRF's causal_forest as an approximation to causal trees
# GRF uses honest sample splitting within each tree

# Identify best splits using variable importance
top_var <- which.max(variable_importance(cf))
cat("Top moderator by GRF importance:", colnames(X)[top_var], "\n")

# Use the top variable to define binary subgroups
split_val <- median(X[, top_var])
group_hi  <- X[, top_var] > split_val
group_lo  <- !group_hi

# Honest ATE within each group (using GRF predictions)
cat("\nHTE by top moderator subgroup:\n")
cat("  High", colnames(X)[top_var], "group: mean CATE =",
    round(mean(tau_cf[group_hi]), 4),
    "| n =", sum(group_hi), "\n")
cat("  Low", colnames(X)[top_var], "group: mean CATE  =",
    round(mean(tau_cf[group_lo]), 4),
    "| n =", sum(group_lo), "\n")
cat("  True HTE: High =", round(mean(tau_true[group_hi]), 4),
    "| Low =", round(mean(tau_true[group_lo]), 4), "\n")

11.4 Virtual Twins Method

The virtual twins method (Foster et al. 2011) estimates individual treatment effects, then builds a decision tree on the CATE estimates to identify interpretable subgroups:

# Step 1: Estimate CATE (already have tau_cf from causal forest)
# Step 2: Classify as "responders" vs "non-responders"
threshold   <- 0.0   # treat CATE > 0 as responders
responder   <- as.integer(tau_cf > threshold)

cat("Virtual Twins:\n")
cat("  Responders (CATE > 0):", sum(responder), "/", n,
    sprintf("(%.1f%%)\n", 100 * mean(responder)))
cat("  Mean CATE for responders:", round(mean(tau_cf[responder == 1]), 4), "\n")
cat("  Mean CATE for non-responders:", round(mean(tau_cf[responder == 0]), 4), "\n")

# Step 3: Fit a classification tree to identify subgroup predictors
if (requireNamespace("rpart", quietly = TRUE)) {
  library(rpart)
  vt_tree <- rpart(
    responder ~ .,
    data   = cbind(as.data.frame(X), responder = responder),
    method = "class",
    control = rpart.control(cp = 0.02, maxdepth = 3)
  )
  cat("\nVirtual twins decision tree splits:\n")
  print(vt_tree$frame[, c("var", "n", "yval")])
}

11.5 Subgroup Forest Plot via heterogeneity_plot()

# Define richer subgroups
df_sg <- data.frame(X, cate = tau_cf, W = W, Y = Y)
df_sg$Sex      <- ifelse(X[, 3] > 0, "Male", "Female")
df_sg$Age_Grp  <- cut(X[, 4], breaks = quantile(X[, 4], seq(0, 1, 0.25)),
                       include.lowest = TRUE,
                       labels = c("Q1", "Q2", "Q3", "Q4"))
df_sg$Industry <- sample(c("Tech", "Finance", "Health", "Retail"),
                          n, replace = TRUE, prob = c(0.3, 0.2, 0.25, 0.25))

# Forest plot by industry
p_sg <- heterogeneity_plot(
  data           = df_sg,
  cate_var       = "cate",
  subgroup_var   = "Industry",
  overall_effect = mean(tau_cf),
  plot_type      = "forest",
  title          = "HTE by Industry Subgroup"
)
p_sg

12. Evaluation Framework

12.1 Oracle Evaluation (Simulation)

When the true $\tau_i$ is known (as in simulation), we can compute oracle metrics:

eval_metrics <- function(tau_hat, tau_true, name) {
  data.frame(
    Estimator = name,
    RMSE      = round(sqrt(mean((tau_hat - tau_true)^2)), 4),
    MAE       = round(mean(abs(tau_hat - tau_true)), 4),
    Bias      = round(mean(tau_hat - tau_true), 4),
    Corr      = round(cor(tau_hat, tau_true), 4),
    Pct_Sign  = round(mean(sign(tau_hat) == sign(tau_true)), 4)
  )
}

eval_df <- rbind(
  eval_metrics(tau_s,   tau_true, "S-Learner (OLS)"),
  eval_metrics(tau_t,   tau_true, "T-Learner (OLS)"),
  eval_metrics(tau_x,   tau_true, "X-Learner (OLS)"),
  eval_metrics(tau_cf,  tau_true, "Causal Forest"),
  eval_metrics(tau_dr,  tau_true, "DR-Learner")
)

if (exists("tau_r")) {
  eval_df <- rbind(eval_df, eval_metrics(tau_r, tau_true, "R-Learner"))
}

print(eval_df)

12.2 Proxy Evaluation: R-Loss and Tau-Risk

When the true $\tau_i$ is unknown, the R-loss provides a proxy: $$\mathcal{L}_R(\hat{\tau}) = \frac{1}{n} \sum_{i=1}^n \left[\tilde{Y}_i - \hat{\tau}(X_i) \tilde{W}_i\right]^2$$

Lower R-loss indicates better CATE recovery without requiring knowledge of true effects.

# Compute R-loss for each estimator
r_loss <- function(tau_hat, Y_res, W_res) {
  mean((Y_res - tau_hat * W_res)^2)
}

rloss_df <- data.frame(
  Estimator = c("S-Learner (OLS)", "T-Learner (OLS)",
                "X-Learner (OLS)", "Causal Forest", "DR-Learner"),
  R_Loss    = round(c(
    r_loss(tau_s,  Y_res_dml, W_res_dml),
    r_loss(tau_t,  Y_res_dml, W_res_dml),
    r_loss(tau_x,  Y_res_dml, W_res_dml),
    r_loss(tau_cf, Y_res_dml, W_res_dml),
    r_loss(tau_dr, Y_res_dml, W_res_dml)
  ), 5)
)
print(rloss_df[order(rloss_df$R_Loss), ])

12.3 Cross-Validated AUUC

# 5-fold cross-validated AUUC for causal forest
K_auuc  <- 5
set.seed(42)
folds_auuc <- sample(rep(1:K_auuc, length.out = n))
auuc_cv <- numeric(K_auuc)

for (k in 1:K_auuc) {
  tr <- which(folds_auuc != k)
  te <- which(folds_auuc == k)

  cf_k <- causal_forest(X[tr, ], Y[tr], W[tr], num.trees = 500, seed = k)
  tau_k <- predict(cf_k, X[te, ])$predictions

  qini_k <- qini_curve(tau_k, Y[te], W[te],
                        boot_reps = 100, seed = k)
  auuc_cv[k] <- qini_k$auuc
}

cat("Cross-validated AUUC (causal forest):\n")
cat("  Mean:", round(mean(auuc_cv), 4), "\n")
cat("  SD:  ", round(sd(auuc_cv), 4), "\n")
cat("  95% CI: [", round(mean(auuc_cv) - 2 * sd(auuc_cv) / sqrt(K_auuc), 4),
    ",", round(mean(auuc_cv) + 2 * sd(auuc_cv) / sqrt(K_auuc), 4), "]\n")

12.4 Bootstrap Confidence Intervals for AUUC

# Bootstrap CI for AUUC of causal forest
B <- 300
set.seed(42)
boot_auuc_vals <- numeric(B)

for (b in seq_len(B)) {
  idx   <- sample(n, replace = TRUE)
  q_b   <- qini_curve(tau_cf[idx], Y[idx], W[idx],
                       boot_reps = 50, seed = b)
  boot_auuc_vals[b] <- q_b$auuc
}

cat("Bootstrap AUUC distribution:\n")
cat("  Mean:", round(mean(boot_auuc_vals), 4), "\n")
cat("  SD:  ", round(sd(boot_auuc_vals), 4), "\n")
cat("  95% CI (percentile): [",
    round(quantile(boot_auuc_vals, 0.025), 4), ",",
    round(quantile(boot_auuc_vals, 0.975), 4), "]\n")

13. CATE Estimation Uncertainty

13.1 Pointwise Confidence Intervals from Causal Forest

# GRF provides honest confidence intervals via IJ variance
ci_df <- data.frame(
  i         = 1:n,
  tau_hat   = tau_cf,
  tau_se    = tau_se,
  tau_lower = tau_cf - 1.96 * tau_se,
  tau_upper = tau_cf + 1.96 * tau_se,
  tau_true  = tau_true
)

# Coverage check
coverage <- mean(ci_df$tau_true >= ci_df$tau_lower &
                   ci_df$tau_true <= ci_df$tau_upper)
cat("Empirical coverage of 95% CIs:", round(coverage, 3), "\n")
cat("(Nominal: 0.95)\n")

# Plot CIs for a subset of units, sorted by predicted CATE
sub <- ci_df[order(ci_df$tau_hat)[seq(1, n, by = n %/% 100)], ]
ggplot(sub, aes(x = seq_along(tau_hat), y = tau_hat)) +
  geom_ribbon(aes(ymin = tau_lower, ymax = tau_upper), alpha = 0.2, fill = "#4393C3") +
  geom_line(color = "#2166AC", linewidth = 0.5) +
  geom_point(aes(y = tau_true), color = "#D73027", size = 0.8, alpha = 0.6) +
  labs(
    x     = "Unit (sorted by estimated CATE)",
    y     = "CATE",
    title = "Causal Forest: Pointwise 95% Confidence Intervals",
    subtitle = sprintf("Coverage: %.1f%% | Red: true CATE | Blue band: 95%% CI",
                       100 * coverage)
  ) +
  theme_minimal(base_size = 12)

13.2 Bootstrap of Entire ML Pipeline

Bootstrapping the full pipeline (nuisance estimation + CATE estimation) provides valid uncertainty quantification without relying on asymptotic theory:

set.seed(42)
B_pipe <- 100
ate_boots <- numeric(B_pipe)

for (b in seq_len(B_pipe)) {
  idx  <- sample(n, replace = TRUE)
  cf_b <- causal_forest(X[idx, ], Y[idx], W[idx],
                         num.trees = 500, seed = b)
  ate_b <- average_treatment_effect(cf_b)
  ate_boots[b] <- ate_b["estimate"]
}

cat("Bootstrap ATE distribution (full pipeline):\n")
cat("  Mean:", round(mean(ate_boots), 4), "\n")
cat("  SD:  ", round(sd(ate_boots), 4), "\n")
cat("  95% CI (percentile): [",
    round(quantile(ate_boots, 0.025), 4), ",",
    round(quantile(ate_boots, 0.975), 4), "]\n")
cat("  True ATE:", round(mean(tau_true), 4), "\n")

13.3 Selective Inference Issues

When subgroups are selected based on estimated CATE and then tested, naive p-values are invalid - this is the winner’s curse or data snooping problem.

Solutions:

  1. Sample splitting: Use one half for subgroup discovery, the other for inference.
  2. Honest estimation: GRF uses honesty by construction.
  3. Simultaneous confidence bands: Bonferroni or bootstrap-based corrections.
  4. Pre-specification: Define subgroups before seeing outcome data.

# Illustration: naive vs. split-sample inference
set.seed(42)
n_half <- n %/% 2
idx_disc  <- 1:n_half         # discovery sample
idx_valid <- (n_half+1):n     # validation sample

# "Discover" best subgroup on discovery sample
X_disc  <- X[idx_disc, ]
tau_disc <- tau_cf[idx_disc]

best_var <- which.max(apply(X_disc, 2, function(xv) {
  abs(cor(xv, tau_disc))
}))
split_d <- median(X_disc[, best_var])

cat("Best subgroup variable (discovery):", colnames(X)[best_var], "\n")
cat("Split value:", round(split_d, 3), "\n")

# Naive p-value (same sample as discovery)
in_disc_hi <- X_disc[, best_var] > split_d
ate_naive  <- mean(tau_disc[in_disc_hi]) - mean(tau_disc[!in_disc_hi])

# Honest p-value (validation sample)
X_val   <- X[idx_valid, ]
tau_val <- tau_cf[idx_valid]
in_val_hi  <- X_val[, best_var] > split_d
ate_honest <- mean(tau_val[in_val_hi]) - mean(tau_val[!in_val_hi])

cat("\nNaive CATE diff (discovery sample):", round(ate_naive, 4), "\n")
cat("Honest CATE diff (validation sample):", round(ate_honest, 4), "\n")
cat("True CATE diff (full sample):",
    round(mean(tau_true[X[, best_var] > split_d]) -
          mean(tau_true[X[, best_var] <= split_d]), 4), "\n")

14. Practical HTE Analysis for Publications

14.1 Complete Publication Pipeline

A complete HTE analysis for publication should include:

  1. Data and design description: Sample, randomization (or selection on observables), covariates.
  2. Balance check: Standardized differences, love plot.
  3. ATE estimation: Primary result with inference.
  4. CATE estimation: Method selection and validation.
  5. Heterogeneity tests: BLP, calibration test.
  6. Subgroup analysis: Pre-specified and data-driven.
  7. Policy implications: Qini curve, targeting rule, policy value.
  8. Sensitivity analysis: Robustness to model choice, functional form.

14.2 Subgroup ATE Table

# Publication-quality subgroup table
subgroup_vars <- list(
  "X1 (High)"    = X[, 1] > 0,
  "X1 (Low)"     = X[, 1] <= 0,
  "X2 (High)"    = X[, 2] > 0,
  "X2 (Low)"     = X[, 2] <= 0,
  "Both High"    = X[, 1] > 0 & X[, 2] > 0,
  "Both Low"     = X[, 1] <= 0 & X[, 2] <= 0,
  "X1>0, X2<=0"  = X[, 1] > 0 & X[, 2] <= 0,
  "X1<=0, X2>0"  = X[, 1] <= 0 & X[, 2] > 0,
  "All"          = rep(TRUE, n)
)

sg_table <- lapply(names(subgroup_vars), function(sg_name) {
  idx <- which(subgroup_vars[[sg_name]])
  n_g <- length(idx)

  # ATE via GRF predictions (honest)
  ate_g <- mean(tau_cf[idx])
  se_g  <- sd(tau_cf[idx]) / sqrt(n_g)

  # True ATE
  true_g <- mean(tau_true[idx])

  data.frame(
    Subgroup    = sg_name,
    N           = n_g,
    CATE_Est    = round(ate_g, 3),
    SE          = round(se_g, 3),
    CI_lo       = round(ate_g - 1.96 * se_g, 3),
    CI_hi       = round(ate_g + 1.96 * se_g, 3),
    p_val       = round(2 * pnorm(-abs(ate_g / se_g)), 4),
    True_CATE   = round(true_g, 3),
    stringsAsFactors = FALSE
  )
})
sg_df <- do.call(rbind, sg_table)
sg_df$Sig <- ifelse(sg_df$p_val < 0.001, "***",
              ifelse(sg_df$p_val < 0.01, "**",
               ifelse(sg_df$p_val < 0.05, "*", "")))
print(sg_df)

14.3 Publication Forest Plot

sg_plot_df <- sg_df[sg_df$Subgroup != "All", ]
overall_ate <- sg_df$CATE_Est[sg_df$Subgroup == "All"]

sg_plot_df$Subgroup <- factor(sg_plot_df$Subgroup,
  levels = sg_plot_df$Subgroup[order(sg_plot_df$CATE_Est)])

ggplot(sg_plot_df, aes(x = CATE_Est, y = Subgroup)) +
  geom_vline(xintercept = overall_ate, linetype = "dashed",
             color = "#D73027", linewidth = 0.8) +
  geom_vline(xintercept = 0, color = "gray70") +
  geom_errorbar(aes(xmin = CI_lo, xmax = CI_hi),
                width = 0.3, color = "#2166AC", linewidth = 0.7, orientation = "y") +
  geom_point(aes(size = N), color = "#2166AC", shape = 18) +
  geom_text(aes(label = sprintf("%.3f%s", CATE_Est, Sig)),
            hjust = -0.3, size = 3) +
  scale_size_continuous(range = c(2, 6), guide = "none") +
  labs(
    x        = "Estimated CATE (95% CI)",
    y        = NULL,
    title    = "Subgroup Treatment Effects",
    subtitle = sprintf("Red dashed = Overall ATE (%.3f) | *** p<.001, ** p<.01, * p<.05",
                       overall_ate),
    caption  = "Estimates from GRF causal forest (out-of-bag predictions)"
  ) +
  theme_minimal(base_size = 12) +
  theme(
    panel.grid.minor = element_blank(),
    axis.text.y      = element_text(face = "bold")
  )

14.4 Feature Importance for HTE

# GRF variable importance
vi_df <- data.frame(
  Feature    = colnames(X),
  Importance = as.numeric(variable_importance(cf))
)
vi_df <- vi_df[order(vi_df$Importance, decreasing = TRUE), ]
vi_df$Feature <- factor(vi_df$Feature, levels = rev(vi_df$Feature))
vi_df$IsTrue  <- vi_df$Feature == "X1"

ggplot(vi_df, aes(x = Feature, y = Importance, fill = IsTrue)) +
  geom_col(alpha = 0.85, show.legend = FALSE) +
  scale_fill_manual(values = c("FALSE" = "#4393C3", "TRUE" = "#D73027")) +
  coord_flip() +
  labs(
    x     = NULL,
    y     = "Variable Importance (GRF)",
    title = "Feature Importance for Treatment Effect Heterogeneity",
    subtitle = "Red = true driver of heterogeneity (X1)"
  ) +
  theme_minimal(base_size = 12)

14.5 CATE Distribution Summary Figure

p1 <- ggplot(data.frame(tau = tau_cf), aes(x = tau)) +
  geom_histogram(bins = 50, fill = "#4393C3", alpha = 0.7, color = "white") +
  geom_vline(xintercept = mean(tau_cf), color = "#D73027",
             linewidth = 1.2, linetype = "solid") +
  geom_vline(xintercept = 0, linetype = "dashed") +
  labs(x = "Estimated CATE", y = "Count",
       title = "CATE Distribution") +
  theme_minimal(base_size = 11)

p2 <- ggplot(data.frame(tau_cf = tau_cf, X1 = X[, 1]),
             aes(x = X1, y = tau_cf)) +
  geom_point(alpha = 0.2, size = 0.8, color = "#2166AC") +
  geom_smooth(method = "lm", color = "#D73027", se = TRUE) +
  labs(x = "X1 (top moderator)", y = "Estimated CATE",
       title = "CATE vs. Top Moderator") +
  theme_minimal(base_size = 11)

if (requireNamespace("patchwork", quietly = TRUE)) {
  library(patchwork)
  p1 + p2
} else {
  print(p1)
  print(p2)
}

15. Stochastic Treatment Regimes & Targeting Rules

15.1 Optimal Policy Rules

A deterministic targeting rule assigns treatment based on covariates: $$\pi^*(x) = \mathbf{1}[\tau(x) \geq c]$$

where $c$ is a cost threshold (e.g., the cost of treatment). The policy value is: $$V(\pi) = \mathbb{E}[\pi(X_i) \cdot Y_i(1) + (1 - \pi(X_i)) \cdot Y_i(0)]$$

15.2 Policy Value Estimation via GRF

# Compute policy value for different cost thresholds
costs     <- seq(-1.5, 1.5, by = 0.25)
pv_tab <- lapply(costs, function(c_val) {
  # Policy: treat if CATE >= c
  pi_i      <- as.integer(tau_cf >= c_val)
  pct_treat <- mean(pi_i)

  # Policy value E[Y(pi)] = E[mu0(X) + pi(X) * tau(X)]
  pv_oracle <- mean(mu0_hat_dr + pi_i * tau_true)  # using true tau
  pv_est    <- mean(mu0_hat_dr + pi_i * tau_cf)    # using estimated tau

  data.frame(
    Cost       = c_val,
    Pct_Treat  = round(pct_treat, 3),
    PV_Oracle  = round(pv_oracle, 4),
    PV_Estimated = round(pv_est, 4)
  )
})
pv_df <- do.call(rbind, pv_tab)
print(pv_df)
# Find optimal cost threshold
optimal_c <- pv_df$Cost[which.max(pv_df$PV_Oracle)]

ggplot(pv_df, aes(x = Cost)) +
  geom_line(aes(y = PV_Oracle, linetype = "Oracle"), color = "#D73027",
            linewidth = 1.1) +
  geom_line(aes(y = PV_Estimated, linetype = "Estimated"), color = "#2166AC",
            linewidth = 1.1) +
  geom_vline(xintercept = optimal_c, linetype = "dashed", color = "gray40") +
  scale_linetype_manual(values = c("Oracle" = "solid", "Estimated" = "dashed")) +
  labs(
    x        = "Cost Threshold c (treat if CATE > c)",
    y        = "Policy Value",
    title    = "Policy Value vs. Cost Threshold",
    subtitle = sprintf("Optimal threshold (oracle) = %.2f", optimal_c),
    linetype = NULL
  ) +
  theme_minimal(base_size = 12) +
  theme(legend.position = "bottom")

15.3 Welfare Gain from Targeting

# Welfare gain of CATE-based targeting vs. uniform policies,
# measured relative to treating nobody (i.e., relative to E[Y(0)])
pv_no_treat  <- 0                # treat nobody: zero gain by definition
pv_all_treat <- mean(tau_true)   # treat everyone: gain = ATE

# Optimal targeting
pi_opt    <- as.integer(tau_cf >= 0)
pv_target <- mean(pi_opt * tau_true)

cat("Welfare analysis (relative to baseline):\n")
cat("  Treat nobody:  ATE gain = 0\n")
cat("  Treat everyone: ATE gain =", round(pv_all_treat, 4), "\n")
cat("  Optimal targeting (CATE > 0): gain =", round(pv_target, 4), "\n")
cat("  Gain from targeting vs. treating all:",
    round(pv_target - pv_all_treat, 4), "\n")
cat("  Fraction treated under optimal rule:", round(mean(pi_opt), 3), "\n")

15.4 Stochastic (Probabilistic) Policies

A stochastic policy assigns treatment probabilistically: $\pi(x) = P(\text{treat} \mid X = x) \in [0, 1]$. This generalizes the binary rule and allows smooth interpolation:

# Stochastic policy: probability = sigmoid of CATE
sigmoid_policy <- function(tau, temperature = 1) {
  1 / (1 + exp(-tau / temperature))
}

# Different temperatures (higher = closer to a uniform 50/50 policy)
temps  <- c(0.1, 0.5, 1.0, 2.0, Inf)
sp_df  <- lapply(temps, function(tmp) {
  pi_soft <- if (is.infinite(tmp)) rep(0.5, n) else sigmoid_policy(tau_cf, tmp)
  pv_soft <- mean(pi_soft * tau_true)   # gain relative to treating nobody
  data.frame(
    Temperature  = tmp,
    Mean_Pi      = round(mean(pi_soft), 3),
    Policy_Value = round(pv_soft, 4)
  )
})
print(do.call(rbind, sp_df))

16. Sensitivity Analysis for HTE

16.1 Sensitivity to Model Specification

# How do CATE estimates change with model specification?
set.seed(42)
specs <- list(
  "Default (2000 trees)"  = list(num.trees = 2000, min.node.size = NULL),
  "Small forest (500)"   = list(num.trees = 500, min.node.size = NULL),
  "Large nodes"          = list(num.trees = 1000, min.node.size = 50),
  "Small nodes"          = list(num.trees = 1000, min.node.size = 5)
)

sens_df <- lapply(names(specs), function(spec_name) {
  args <- specs[[spec_name]]
  cf_s <- do.call(causal_forest,
    c(list(X = X, Y = Y, W = W), args[!sapply(args, is.null)]))
  tau_s_hat <- predict(cf_s)$predictions
  ate_s <- average_treatment_effect(cf_s)
  data.frame(
    Specification = spec_name,
    ATE           = round(ate_s["estimate"], 4),
    ATE_SE        = round(ate_s["std.err"], 4),
    CATE_SD       = round(sd(tau_s_hat), 4),
    RMSE_oracle   = round(sqrt(mean((tau_s_hat - tau_true)^2)), 4),
    Corr_oracle   = round(cor(tau_s_hat, tau_true), 4)
  )
})
print(do.call(rbind, sens_df))

16.2 Sensitivity to Overlap Violations

# What happens when propensity scores are extreme?
set.seed(42)
n_ov  <- 1000
X_ov  <- matrix(rnorm(n_ov * 5), n_ov, 5)
# Treatment strongly predicted by X1 (poor overlap)
ps_ov <- plogis(2 * X_ov[, 1])
W_ov  <- rbinom(n_ov, 1, ps_ov)
tau_ov <- X_ov[, 1]
Y_ov   <- tau_ov * W_ov + rnorm(n_ov)

cat("Overlap summary:\n")
cat("  Mean PS:", round(mean(ps_ov), 3), "\n")
cat("  PS < 0.1:", sum(ps_ov < 0.1), "| PS > 0.9:", sum(ps_ov > 0.9), "\n")
cat("  Effective n after 0.1-0.9 trimming:",
    sum(ps_ov >= 0.1 & ps_ov <= 0.9), "\n")

# Trimming at the more permissive 0.05-0.95 bounds
trim_idx <- ps_ov >= 0.05 & ps_ov <= 0.95
cat("  Fraction retained after 0.05-0.95 trimming:", round(mean(trim_idx), 3), "\n")

17. Summary & Best Practices

17.1 Choosing a CATE Estimator

  • Large N, high p, no strong priors: Causal Forest (GRF) - nonparametric, honest, CIs available.
  • Balanced arms, moderate N: T-Learner (XGBoost) - simple, flexible.
  • Imbalanced arms: X-Learner - designed for imbalance.
  • Linear heterogeneity hypothesized: DML + interaction - interpretable.
  • Semiparametric efficiency desired: DR-Learner - optimal convergence rate.
  • Publication with pre-specified model: Parametric interaction - transparent, interpretable.

17.2 Key Diagnostics Checklist

Before reporting HTE results:

  • Balance: Standardized differences < 0.1 for all covariates
  • Overlap: Propensity scores bounded away from 0 and 1
  • ATE calibration: BLP β₁ ≈ ATE, β₂ significant
  • Heterogeneity signal: Calibration test significant; GATES show variation
  • RMSE / R-loss: Cross-validated proxy evaluation performed
  • AUUC > random: Qini curve shows benefit of targeting
  • Subgroup pre-specification: Primary analysis based on pre-specified groups
  • Multiple testing correction: Applied for exploratory subgroups
  • Sensitivity analysis: Results robust to model specification

17.3 Reporting Checklist for Papers

When writing up HTE results:

  1. Method section: Clearly describe the estimator, nuisance model, and cross-fitting.
  2. Assumptions: State unconfoundedness and overlap; justify with domain knowledge.
  3. Primary result: ATE with SE and CI; then heterogeneity tests.
  4. Heterogeneity evidence: BLP, GATES, calibration test statistics.
  5. Subgroup table: Pre-specified subgroups with corrected p-values.
  6. Forest plot: Visual display of subgroup effects.
  7. Policy implications: Qini curve or optimal treatment fraction.
  8. Limitations: Address unmeasured confounding, external validity, model dependence.
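For the corrected subgroup p-values in item 5, base R's p.adjust covers the standard corrections; the raw p-values below are made up purely for illustration:

```r
# Hypothetical raw p-values from four exploratory subgroup tests
p_raw <- c(age = 0.012, educ = 0.048, married = 0.300, re74 = 0.021)

# Bonferroni controls FWER; Benjamini-Hochberg controls FDR (less conservative)
round(cbind(raw        = p_raw,
            bonferroni = p.adjust(p_raw, method = "bonferroni"),
            BH         = p.adjust(p_raw, method = "BH")), 3)
```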

17.4 Cross-References to Other causalverse Vignettes

  • [Vignette 1: RCT] - foundational estimands and experimental design
  • [Vignette 3: DiD] - panel data HTE and group-time ATTs
  • [Vignette 9: Matching] - matching-based CATE estimation, balance diagnostics
  • [Vignette 7: IV] - LATE and complier-specific heterogeneity
  • [Vignette 10: Sensitivity] - sensitivity analysis for unmeasured confounding

18. Advanced Topics in HTE

18.1 Continuous and Multi-Valued Treatments

When the treatment W_i is continuous (a dose or intensity) rather than binary, the estimand generalizes to the Dose-Response Function (DRF) or the Average Derivative Effect:

\theta(w, x) = \frac{\partial}{\partial w} \mathbb{E}[Y(w) \mid X = x]

The grf package handles continuous treatments with the same causal_forest function: pass the treatment as a numeric vector, and average_treatment_effect then reports the average partial effect.

# Continuous treatment simulation
set.seed(42)
n_cont <- 1500
p_cont <- 8
X_cont <- matrix(rnorm(n_cont * p_cont), n_cont, p_cont)
# Continuous dose: uniform on [0, 2]
W_cont <- runif(n_cont, 0, 2)
# Nonlinear dose-response: quadratic in dose, heterogeneous in X1
# (implied marginal effect: dY/dW = X1 - W)
tau_cont <- X_cont[, 1] * W_cont - 0.5 * W_cont^2
Y_cont   <- tau_cont + 0.3 * X_cont[, 2] + rnorm(n_cont, sd = 0.5)

# Causal forest with continuous treatment
cf_cont <- causal_forest(
  X = X_cont, Y = Y_cont, W = W_cont,
  num.trees = 1000, seed = 42
)

ate_cont <- average_treatment_effect(cf_cont)
cat("Continuous treatment ATE (avg marginal effect):\n")
cat("  Estimate:", round(ate_cont["estimate"], 4),
    "| SE:", round(ate_cont["std.err"], 4), "\n")

18.2 Instrumental Variables with Heterogeneity

When treatment is endogenous, IV methods identify the Local Average Treatment Effect (LATE) - the average treatment effect among compliers. With an instrument and covariates, grf's instrumental_forest can estimate heterogeneous LATEs:

# IV simulation: instrument Z induces compliance
set.seed(42)
n_iv <- 1500
X_iv <- matrix(rnorm(n_iv * 6), n_iv, 6)
Z_iv <- rbinom(n_iv, 1, 0.5)    # instrument (randomized)

# Compliance: P(W=1|Z=1,X) > P(W=1|Z=0,X)
p_comply <- plogis(0.5 * X_iv[, 1] + 0.5)
W_iv     <- rbinom(n_iv, 1, Z_iv * p_comply + (1 - Z_iv) * p_comply * 0.2)

# Heterogeneous LATE: driven by X1
tau_iv <- pmax(X_iv[, 1], 0)
Y_iv   <- tau_iv * W_iv + 0.4 * X_iv[, 2] + rnorm(n_iv)

# IV forest
iv_cf <- instrumental_forest(
  X = X_iv, Y = Y_iv, W = W_iv, Z = Z_iv,
  num.trees = 1000, seed = 42
)

late_est <- average_treatment_effect(iv_cf)
cat("IV Forest LATE estimate:\n")
cat("  Estimate:", round(late_est["estimate"], 4),
    "| SE:", round(late_est["std.err"], 4), "\n")

# Heterogeneous LATE predictions
tau_iv_hat <- predict(iv_cf)$predictions
cat("  LATE heterogeneity (SD):", round(sd(tau_iv_hat), 4), "\n")
cat("  Compliance rate:", round(mean(W_iv[Z_iv == 1]) - mean(W_iv[Z_iv == 0]), 4), "\n")

18.3 Difference-in-Differences with Heterogeneous Effects

In panel data settings with staggered adoption, the CATE generalizes to group-time ATTs. The did package (Callaway & Sant’Anna 2021) estimates ATT(g, t), which can be aggregated to obtain heterogeneous effects by subgroup:

# DiD with heterogeneous treatment timing
set.seed(42)
n_units <- 500
n_times <- 6
# Staggered adoption: some treated in t=3, others in t=5
unit_fe   <- rnorm(n_units)
time_fe   <- seq(0, 1, length.out = n_times)
treated   <- rbinom(n_units, 1, 0.6)
timing    <- ifelse(treated == 1, sample(c(3, 5), n_units, replace = TRUE), Inf)

# Heterogeneous treatment effect: larger for units with unit_fe > 0
tau_het   <- ifelse(unit_fe > 0, 2.0, 0.5) * treated

panel_df <- data.frame(
  unit = rep(1:n_units, each = n_times),
  time = rep(1:n_times, times = n_units),
  unit_fe = rep(unit_fe, each = n_times),
  timing  = rep(timing, each = n_times),
  tau     = rep(tau_het, each = n_times)
)
panel_df$treated_period <- panel_df$time >= panel_df$timing
panel_df$Y <- panel_df$unit_fe +
  rep(time_fe, times = n_units) +
  panel_df$treated_period * panel_df$tau +
  rnorm(n_units * n_times, sd = 0.5)

# Simple DiD: high vs low unit_fe subgroups
panel_df$group_label <- ifelse(panel_df$unit_fe > 0, "High FE", "Low FE")

# Subgroup-specific DiD ATT
did_results <- lapply(c("High FE", "Low FE"), function(grp) {
  sub  <- panel_df[panel_df$group_label == grp, ]
  post <- sub$time >= 3 & sub$treated_period
  # Crude DiD: mean of treated post-period cells minus all other cells,
  # net of the never-treated (timing = Inf) post-vs-pre change
  diff <- mean(sub$Y[post]) - mean(sub$Y[!post]) -
          (mean(sub$Y[sub$time >= 3 & sub$timing > 6]) -
           mean(sub$Y[sub$time < 3 & sub$timing > 6]))
  data.frame(Group      = grp,
             Naive_Diff = round(diff, 4),
             True_ATT   = round(mean(sub$tau[sub$treated_period]), 4))
})
cat("DiD Subgroup ATTs:\n")
print(do.call(rbind, did_results))
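The hand-rolled comparison above can be replaced by proper group-time ATT estimation when the did package is installed. This is a sketch: it assumes the panel_df built above, recodes never-treated units (timing = Inf) to gname = 0 per the package's convention, and aggregates to one ATT per adoption cohort.

```r
# Sketch: group-time ATTs via did::att_gt (Callaway & Sant'Anna 2021)
if (requireNamespace("did", quietly = TRUE)) {
  panel_df$first_treat <- ifelse(is.finite(panel_df$timing), panel_df$timing, 0)
  att_res <- did::att_gt(
    yname = "Y", tname = "time", idname = "unit",
    gname = "first_treat", data = panel_df
  )
  # One aggregated ATT per treatment-timing cohort (g = 3 and g = 5)
  summary(did::aggte(att_res, type = "group"))
}
```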

19. Practical Workflow with causalverse

19.1 End-to-End HTE Pipeline

The following shows the complete workflow using causalverse functions:

# Step 1: Data setup (already done: X, Y, W, tau_true)

# Step 2: Fit causal forest
set.seed(42)
cf_final <- causal_forest(X, Y, W, num.trees = 2000, tune.parameters = "all")

# Step 3: Comprehensive summary
pipeline_summary <- causal_forest_summary(
  cf            = cf_final,
  X             = X,
  feature_names = colnames(X),
  n_groups      = 4,
  top_features  = 10
)

# Step 4: Print ATE
cat("=== PIPELINE RESULTS ===\n")
cat("\n--- ATE ---\n")
print(pipeline_summary$ate)

cat("\n--- CATE Summary ---\n")
print(round(pipeline_summary$cate_summary, 4))

cat("\n--- BLP Test ---\n")
print(pipeline_summary$blp$blp)
# Step 5: Visualize
pipeline_summary$plot_cate_dist
pipeline_summary$plot_gates
# Step 6: Policy evaluation
tau_final <- predict(cf_final)$predictions
qini_final <- qini_curve(
  cate_hat   = tau_final,
  Y          = Y,
  W          = W,
  boot_reps  = 200
)
cat("\n--- Qini / AUUC ---\n")
cat("AUUC:", round(qini_final$auuc, 4), "\n")
cat("95% CI:", round(qini_final$auuc_ci[1], 4), "-",
    round(qini_final$auuc_ci[2], 4), "\n")

19.2 Integrating AIPW and Causal Forest

A common workflow combines GRF’s cross-fitted nuisances with the AIPW estimator for the final ATE, and uses the causal forest CATE for targeting:

# Extract GRF's cross-fitted nuisances
e_grf <- cf_final$W.hat   # propensity score
m_grf <- cf_final$Y.hat   # E[Y|X]
tau_grf <- predict(cf_final)$predictions

# AIPW ATE using GRF nuisances.
# Note: cf_final$Y.hat is the marginal regression E[Y|X], so the arm-specific
# means are mu0(x) = Y.hat - W.hat * tau(x) and mu1(x) = Y.hat + (1 - W.hat) * tau(x)
Y_cf0 <- m_grf - e_grf * tau_grf
Y_cf1 <- m_grf + (1 - e_grf) * tau_grf

psi_grf <- Y_cf1 - Y_cf0 +
  W * (Y - Y_cf1) / e_grf -
  (1 - W) * (Y - Y_cf0) / (1 - e_grf)

aipw_grf_ate <- mean(psi_grf)
aipw_grf_se  <- sd(psi_grf) / sqrt(length(psi_grf))

cat("AIPW ATE (GRF nuisances):\n")
cat("  Estimate:", round(aipw_grf_ate, 4), "\n")
cat("  SE:      ", round(aipw_grf_se, 4), "\n")
cat("  95% CI: [",
    round(aipw_grf_ate - 1.96 * aipw_grf_se, 4), ",",
    round(aipw_grf_ate + 1.96 * aipw_grf_se, 4), "]\n")
cat("  True ATE:", round(mean(tau_true), 4), "\n")

19.3 Using causalverse::dr_ate() on Real Data

if (requireNamespace("MatchIt", quietly = TRUE)) {
  data(lalonde, package = "MatchIt")

  # ATE
  dr_res_ate <- dr_ate(
    data       = lalonde,
    outcome    = "re78",
    treatment  = "treat",
    covariates = c("age", "educ", "married", "nodegree", "re74", "re75"),
    estimand   = "ATE",
    boot_se    = FALSE
  )

  # ATT
  dr_res_att <- dr_ate(
    data       = lalonde,
    outcome    = "re78",
    treatment  = "treat",
    covariates = c("age", "educ", "married", "nodegree", "re74", "re75"),
    estimand   = "ATT"
  )

  cat("Lalonde AIPW Results:\n")
  cat(sprintf("  ATE: %.0f (SE = %.0f, p = %.4f)\n",
              dr_res_ate$estimate, dr_res_ate$se, dr_res_ate$p_value))
  cat(sprintf("  ATT: %.0f (SE = %.0f, p = %.4f)\n",
              dr_res_att$estimate, dr_res_att$se, dr_res_att$p_value))
  cat(sprintf("  Propensity score range: [%.3f, %.3f]\n",
              dr_res_ate$ps_summary["Min."],
              dr_res_ate$ps_summary["Max."]))
}

19.4 HTE on Real Data: Lalonde Job Training

data(lalonde, package = "MatchIt")

# Prepare covariate matrix
cov_names <- c("age", "educ", "married", "nodegree", "re74", "re75")
X_la <- model.matrix(reformulate(cov_names, intercept = FALSE), data = lalonde)
Y_la <- lalonde$re78
W_la <- lalonde$treat

set.seed(42)
cf_la <- causal_forest(X_la, Y_la, W_la,
                        num.trees = 2000, tune.parameters = "all")

tau_la <- predict(cf_la)$predictions
ate_la <- average_treatment_effect(cf_la)

cat("Lalonde CATE Analysis:\n")
cat("  ATE:", round(ate_la["estimate"], 1),
    "| SE:", round(ate_la["std.err"], 1), "\n")
cat("  CATE range: [", round(min(tau_la), 0), ",",
    round(max(tau_la), 0), "]\n")
cat("  Fraction with CATE > 0:", round(mean(tau_la > 0), 3), "\n")
# Subgroup analysis for Lalonde
la_df <- data.frame(X_la, cate = tau_la, W = W_la, Y = Y_la)
la_df$educ_grp <- ifelse(lalonde$educ >= 12, "HS Graduate", "No HS Diploma")
la_df$married_grp <- ifelse(lalonde$married == 1, "Married", "Single")
la_df$re74_grp <- ifelse(lalonde$re74 > 0, "Employed in 1974", "Not Employed")
la_df$nodegree_grp <- ifelse(lalonde$nodegree == 1, "No Degree", "Has Degree")
la_df$age_grp <- cut(lalonde$age, breaks = c(16, 24, 30, 55),
                      include.lowest = TRUE,
                      labels = c("Young (16-24)", "Middle (25-30)", "Older (31+)"))

# Use heterogeneity_plot for forest plot
p_la_edu <- heterogeneity_plot(
  data           = la_df,
  cate_var       = "cate",
  subgroup_var   = "educ_grp",
  overall_effect = ate_la["estimate"],
  plot_type      = "forest",
  title          = "Lalonde: HTE by Education"
)

p_la_re74 <- heterogeneity_plot(
  data           = la_df,
  cate_var       = "cate",
  subgroup_var   = "re74_grp",
  overall_effect = ate_la["estimate"],
  plot_type      = "forest",
  title          = "Lalonde: HTE by 1974 Employment"
)

if (requireNamespace("patchwork", quietly = TRUE)) {
  library(patchwork)
  p_la_edu + p_la_re74
} else {
  print(p_la_edu)
  print(p_la_re74)
}
# Qini for Lalonde
qini_la <- qini_curve(tau_la, Y_la, W_la, boot_reps = 200)
cat("Lalonde AUUC:", round(qini_la$auuc, 4), "\n")
qini_la$plot + labs(title = "Lalonde: Qini Curve for Job Training Targeting")

20. Common Pitfalls and Troubleshooting

20.1 The SUTVA Violation

If treatment of one unit affects another (spillovers), SUTVA is violated and CATE estimates may be biased. Signs include:

  • Geographic clustering of outcomes
  • Peer effects in the outcome variable
  • Market equilibrium effects

Remedies: Cluster-randomized designs, partial interference models, network causal inference.

20.2 Outcome Leakage into Covariates

Never condition on variables that are affected by treatment (colliders or mediators). Including post-treatment variables in X can introduce bias:

# Illustration: including a mediator biases CATE
set.seed(42)
n_col  <- 1000
X_col  <- rnorm(n_col)
W_col  <- rbinom(n_col, 1, 0.5)
tau_col <- X_col           # true CATE = X
M_col   <- 0.5 * W_col + 0.2 * X_col + rnorm(n_col, sd = 0.3)  # mediator
Y_col   <- tau_col * W_col + M_col + rnorm(n_col, sd = 0.5)

# Correct model: no mediator
lm_correct <- lm(Y_col ~ W_col + X_col)
cat("Correct (no mediator) W coef:", round(coef(lm_correct)["W_col"], 4), "\n")

# Incorrect model: include mediator (post-treatment)
lm_biased  <- lm(Y_col ~ W_col + X_col + M_col)
cat("Biased  (mediator included) W coef:", round(coef(lm_biased)["W_col"], 4), "\n")
cat("True total ATE (direct E[tau] ~ 0 plus mediated 0.5):",
    round(mean(tau_col) + 0.5, 4), "\n")
cat("Warning: conditioning on the post-treatment mediator M strips out the\n")
cat("mediated pathway and biases the W coefficient toward 0 here!\n")

20.3 Extrapolation in CATE Estimation

ML models can extrapolate poorly in regions with low support. Always check:

  1. Covariate overlap between treated and control.
  2. CATE confidence interval width (wider in sparse regions).
  3. Never interpret CATE at covariate values outside the data range.

# Overlap plot: propensity score by treatment arm
overlap_df <- data.frame(
  ps    = cf_final$W.hat,
  treat = factor(W, levels = c(0, 1), labels = c("Control", "Treated"))
)

ggplot(overlap_df, aes(x = ps, fill = treat)) +
  geom_histogram(bins = 40, alpha = 0.6, position = "identity") +
  scale_fill_manual(values = c("Control" = "#4393C3", "Treated" = "#D73027")) +
  geom_vline(xintercept = c(0.1, 0.9), linetype = "dashed") +
  labs(
    x     = "Estimated Propensity Score",
    y     = "Count",
    fill  = NULL,
    title = "Propensity Score Overlap",
    subtitle = "Dashed lines at 0.1 and 0.9 (overlap region)"
  ) +
  theme_minimal(base_size = 12) +
  theme(legend.position = "bottom")

20.4 Model Selection for Nuisance Functions

The choice of nuisance model (the outcome regression m̂ and the propensity score ê) affects CATE estimation:

  • Underfitting nuisances: Residual confounding leaks into the CATE estimates.
  • Overfitting nuisances: Without cross-fitting, own-observation overfitting contaminates the residuals and can shrink CATE estimates toward zero.

Best practice: Use cross-fitting with a flexible but regularized learner (lasso, random forest, gradient boosting). Cross-validate the nuisance models separately.

# Cross-validate nuisance model performance
set.seed(42)
cv_m_final <- cv.glmnet(X, Y, alpha = 1, nfolds = 10)
cv_e_final <- cv.glmnet(X, W, family = "binomial", alpha = 1, nfolds = 10)

cat("Nuisance model cross-validation:\n")
cat("  Outcome model (lasso) R^2:",
    round(1 - min(cv_m_final$cvm) / var(Y), 4), "\n")
cat("  Propensity model: expect chance-level fit (AUC ~0.5) under random assignment\n")
cat("  Optimal lambda (outcome):", round(cv_m_final$lambda.min, 6), "\n")
cat("  Optimal lambda (propensity):", round(cv_e_final$lambda.min, 6), "\n")

20.5 Testing for Effect Modification vs. Effect Measure Modification

The same data can show heterogeneity on the absolute scale (the CATE τ_i) but not on the relative scale (the relative risk Y_i(1)/Y_i(0)), or vice versa. Always specify which scale you are studying:

  • Additive HTE (causal forest, DML): τ(x) = μ_1(x) − μ_0(x)
  • Multiplicative HTE: RR(x) = μ_1(x) / μ_0(x); requires the outcome on a log scale

For binary outcomes, both scales are relevant. For continuous outcomes, the additive scale is most common.
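A two-number toy calculation makes the scale dependence concrete: below, the relative risk is identical across subgroups while the risk difference varies fourfold (the baseline risks are purely illustrative):

```r
# Constant relative risk, heterogeneous risk difference
mu0 <- c(LowRisk = 0.10, HighRisk = 0.40)  # baseline risks by subgroup
mu1 <- 2 * mu0                             # treatment doubles risk everywhere

print(mu1 - mu0)   # risk difference: 0.10 vs 0.40 -- heterogeneous
print(mu1 / mu0)   # relative risk:   2 in both subgroups -- homogeneous
```

An additive-scale HTE analysis would flag the HighRisk subgroup as a prime target; a multiplicative-scale analysis would find no heterogeneity at all.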


References

Athey, S., & Imbens, G. W. (2016). Recursive partitioning for heterogeneous causal effects. Proceedings of the National Academy of Sciences, 113(27), 7353–7360.

Athey, S., Tibshirani, J., & Wager, S. (2019). Generalized random forests. Annals of Statistics, 47(2), 1148–1178.

Callaway, B., & Sant’Anna, P. H. C. (2021). Difference-in-differences with multiple time periods. Journal of Econometrics, 225(2), 200–230.

Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C., Newey, W., & Robins, J. (2018). Double/debiased machine learning for treatment and structural parameters. Econometrics Journal, 21(1), C1–C68.

Chernozhukov, V., Demirer, M., Duflo, E., & Fernandez-Val, I. (2020). Generic machine learning inference on heterogeneous treatment effects in randomized experiments. NBER Working Paper 24678.

Foster, J. C., Taylor, J. M., & Ruberg, S. J. (2011). Subgroup identification from randomized clinical trial data. Statistics in Medicine, 30(24), 2867–2880.

Kennedy, E. H. (2020). Optimal doubly robust estimation of heterogeneous causal effects. arXiv preprint arXiv:2004.14497.

Künzel, S. R., Sekhon, J. S., Bickel, P. J., & Yu, B. (2019). Metalearners for estimating heterogeneous treatment effects using machine learning. Proceedings of the National Academy of Sciences, 116(10), 4156–4165.

Lalonde, R. J. (1986). Evaluating the econometric evaluations of training programs with experimental data. American Economic Review, 76(4), 604–620.

Nie, X., & Wager, S. (2021). Quasi-oracle estimation of heterogeneous treatment effects. Biometrika, 108(2), 299–319.

Radcliffe, N. J. (2007). Using control groups to target on predicted lift. Direct Marketing Analytics Journal, 1(3), 14–21.

Robinson, P. M. (1988). Root-N-consistent semiparametric regression. Econometrica, 56(4), 931–954.

Wager, S., & Athey, S. (2018). Estimation and inference of heterogeneous treatment effects using random forests. Journal of the American Statistical Association, 113(523), 1228–1242.