library(causalverse)

Most causal inference methods report a single summary statistic - the Average Treatment Effect (ATE) - which tells us the average impact of an intervention across the study population. While the ATE is a valuable and interpretable quantity, it can obscure substantial variation in how individuals respond to treatment. When treatment effects differ across units, reporting only the ATE can mask subgroups that benefit strongly, respond weakly, or are even harmed by the intervention.
Heterogeneous Treatment Effect (HTE) analysis asks: for whom does the treatment work, and by how much? This is the central question in precision medicine, personalized policy, targeted marketing, and individualized education.
Modern causal machine learning methods - including causal forests, meta-learners, and double/debiased machine learning - make it possible to estimate HTE at the individual level, validate those estimates, and translate them into actionable targeting rules, all while maintaining rigorous causal guarantees.
Let $Y_i(1)$ and $Y_i(0)$ denote the potential outcomes for unit $i$ under treatment and control, respectively. The treatment indicator is $W_i \in \{0, 1\}$, and $X_i$ is a vector of pre-treatment covariates.
Individual Treatment Effect (ITE): $\tau_i = Y_i(1) - Y_i(0)$. This is never observed due to the fundamental problem of causal inference.
Conditional Average Treatment Effect (CATE): $\tau(x) = E[Y(1) - Y(0) \mid X = x]$. The CATE is the population-level treatment effect for units with covariate profile $x$. This is the primary estimand of HTE analysis.
Average Treatment Effect (ATE): $\tau = E[Y(1) - Y(0)]$.
Average Treatment Effect on the Treated (ATT): $\tau_{\text{ATT}} = E[Y(1) - Y(0) \mid W = 1]$.
The CATE generalizes the ATE: $\tau = E[\tau(X)]$ is the expectation of the CATE over the marginal distribution of $X$.
To identify the CATE from observational data, we require:
Unconfoundedness (Conditional Independence): $(Y(1), Y(0)) \perp\!\!\!\perp W \mid X$. No unmeasured confounders conditional on $X$.
Overlap (Positivity): $0 < P(W = 1 \mid X = x) < 1$ for all $x$. Every covariate profile has a positive probability of both treatment and control.
SUTVA: Stable Unit Treatment Value Assumption - no interference and no hidden treatment versions.
Under these assumptions, the CATE is identified: $\tau(x) = E[Y \mid W = 1, X = x] - E[Y \mid W = 0, X = x]$.
In randomized experiments, unconfoundedness holds by design. In observational studies, it must be justified by domain knowledge.
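As a quick sanity check on the identification result, here is a self-contained sketch (toy data; all object names are local to this illustration, not the vignette's) showing that the conditional difference in means recovers $\tau(x)$ under randomization:

```r
# Toy check (names local to this sketch): under randomization, the
# conditional difference in means recovers tau(x).
set.seed(1)
n_toy <- 100000
x_toy <- rbinom(n_toy, 1, 0.5)           # binary covariate
w_toy <- rbinom(n_toy, 1, 0.5)           # randomized treatment
tau_toy <- ifelse(x_toy == 1, 2, -1)     # true CATE: +2 if x = 1, -1 if x = 0
y_toy <- tau_toy * w_toy + rnorm(n_toy)  # outcome
cate_hat <- sapply(c(0, 1), function(xv) {
  mean(y_toy[w_toy == 1 & x_toy == xv]) - mean(y_toy[w_toy == 0 & x_toy == xv])
})
round(cate_hat, 2)  # approximately c(-1, 2)
```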
This vignette covers a comprehensive toolkit for HTE estimation:
| Method | Key Reference | Package |
|---|---|---|
| S-Learner | Künzel et al. (2019) | xgboost, lm |
| T-Learner | Künzel et al. (2019) | Any ML |
| X-Learner | Künzel et al. (2019) | grf |
| R-Learner / DML | Robinson (1988), Nie & Wager (2021) | glmnet, DoubleML |
| Causal Forest | Wager & Athey (2018), Athey et al. (2019) | grf |
| BLP / GATES | Chernozhukov et al. (2020) | causalverse |
| DR-Learner / AIPW | Kennedy (2020) | causalverse |
| Qini / AUUC | Radcliffe (2007) | causalverse |
| DML with interactions | Chernozhukov et al. (2018) | glmnet, DoubleML |
We begin with a simulation where the true CATE is known, enabling us to evaluate estimator performance.
set.seed(42)
n <- 2000
p <- 10
# Covariates: standard normal
X <- matrix(rnorm(n * p), n, p)
colnames(X) <- paste0("X", 1:p)
# Treatment: randomized with probability 0.5
W <- rbinom(n, 1, 0.5)
# True CATE: linear in X1 (heterogeneity driven by X1)
tau_true <- X[, 1]
# Outcome: tau * W + noise
Y <- tau_true * W + 0.5 * X[, 2] + rnorm(n)
# Compile into data frame
df <- data.frame(X, W = W, Y = Y, tau_true = tau_true)
cat("ATE (true):", mean(tau_true), "\n")
cat("SD of tau (true):", round(sd(tau_true), 3), "\n")
cat("Fraction positive tau:", round(mean(tau_true > 0), 3), "\n")
cat("n =", n, "| p =", p, "| Treatment fraction:", mean(W), "\n")We also use the classic Lalonde (1986) dataset from the MatchIt package, which evaluates a job training program on earnings. This dataset illustrates HTE in a real-world observational setting.
if (requireNamespace("MatchIt", quietly = TRUE)) {
data(lalonde, package = "MatchIt")
cat("Lalonde dataset: n =", nrow(lalonde), "\n")
cat("Outcome: re78 (1978 earnings)\n")
cat("Treatment: treat (job training, 1 = treated)\n")
cat("Treatment fraction:", round(mean(lalonde$treat), 3), "\n")
cat("Mean re78 (treated):", round(mean(lalonde$re78[lalonde$treat == 1]), 1), "\n")
cat("Mean re78 (control):", round(mean(lalonde$re78[lalonde$treat == 0]), 1), "\n")
}

Before estimating treatment effects, always check covariate balance.
# Balance table for simulated data
treated_idx <- which(W == 1)
control_idx <- which(W == 0)
balance_df <- data.frame(
Variable = colnames(X),
Mean_Treated = round(colMeans(X[treated_idx, ]), 3),
Mean_Control = round(colMeans(X[control_idx, ]), 3),
Std_Diff = round(
(colMeans(X[treated_idx, ]) - colMeans(X[control_idx, ])) /
sqrt((apply(X[treated_idx, ], 2, var) + apply(X[control_idx, ], 2, var)) / 2),
3
)
)
print(balance_df)
# Visualize standardized differences
ggplot(balance_df, aes(x = Std_Diff, y = reorder(Variable, Std_Diff))) +
geom_vline(xintercept = c(-0.1, 0.1), linetype = "dashed", color = "red", alpha = 0.5) +
geom_vline(xintercept = 0, color = "gray40") +
geom_point(size = 3, color = "#2166AC") +
labs(
x = "Standardized Mean Difference",
y = NULL,
title = "Pre-Treatment Covariate Balance",
subtitle = "Dashed lines at ±0.1 (conventional threshold for good balance)"
) +
theme_minimal(base_size = 12)

In the randomized simulation, all standardized differences are close to zero - expected under random assignment.
ggplot(data.frame(tau = tau_true), aes(x = tau)) +
geom_histogram(bins = 50, fill = "#4393C3", alpha = 0.7, color = "white") +
geom_vline(xintercept = mean(tau_true), color = "#D73027", linewidth = 1.2) +
geom_vline(xintercept = 0, linetype = "dashed") +
labs(
x = "True Individual Treatment Effect (tau)",
y = "Count",
title = "Distribution of True CATE",
subtitle = sprintf("ATE = %.3f; SD = %.3f; %.0f%% positive",
mean(tau_true), sd(tau_true), 100 * mean(tau_true > 0))
) +
theme_minimal(base_size = 12)

Meta-learners reduce the CATE estimation problem to off-the-shelf regression or classification. They differ in how they use treatment information.
The S-Learner fits a single model $\hat\mu(x, w)$ of $E[Y \mid X = x, W = w]$ with the treatment indicator as just another feature, then predicts $\hat\tau_S(x) = \hat\mu(x, 1) - \hat\mu(x, 0)$.
Pros: Simple, uses all data. Cons: May underfit treatment heterogeneity if the learner shrinks the treatment coefficient toward zero (e.g., lasso).
# S-Learner with linear model
df_train <- data.frame(X, W = W, Y = Y)
# Fit model with W as a feature (and all interactions for flexibility)
s_fit <- lm(Y ~ . * W, data = df_train[, c(paste0("X", 1:p), "W", "Y")])
# Predict potential outcomes
df1 <- df_train; df1$W <- 1
df0 <- df_train; df0$W <- 0
tau_s <- predict(s_fit, newdata = df1) - predict(s_fit, newdata = df0)
cat("S-Learner (linear with interactions):\n")
cat(" Mean CATE:", round(mean(tau_s), 4), "\n")
cat(" SD CATE: ", round(sd(tau_s), 4), "\n")
cat(" RMSE vs truth:", round(sqrt(mean((tau_s - tau_true)^2)), 4), "\n")
library(xgboost)
# S-Learner with XGBoost
dtrain <- xgb.DMatrix(
data = cbind(X, W = W),
label = Y
)
xgb_params <- list(
objective = "reg:squarederror",
max_depth = 4,
eta = 0.05,
subsample = 0.8,
colsample_bytree = 0.8,
nthread = 1
)
set.seed(42)
xgb_fit <- xgb.train(
params = xgb_params,
data = dtrain,
nrounds = 200,
verbose = 0
)
# Predict with W=1 and W=0
d1_xgb <- xgb.DMatrix(data = cbind(X, W = 1))
d0_xgb <- xgb.DMatrix(data = cbind(X, W = 0))
tau_s_xgb <- predict(xgb_fit, d1_xgb) - predict(xgb_fit, d0_xgb)
cat("S-Learner (XGBoost):\n")
cat(" Mean CATE:", round(mean(tau_s_xgb), 4), "\n")
cat(" SD CATE: ", round(sd(tau_s_xgb), 4), "\n")
cat(" RMSE vs truth:", round(sqrt(mean((tau_s_xgb - tau_true)^2)), 4), "\n")The T-Learner fits separate outcome models for treated and control units:
Pros: No regularization bias on treatment. Cons: Can overfit when one arm is small; ignores information sharing across arms.
# T-Learner with linear models
X_df <- as.data.frame(X)
# Treated model
fit1 <- lm(Y ~ ., data = cbind(X_df, Y = Y)[W == 1, ])
# Control model
fit0 <- lm(Y ~ ., data = cbind(X_df, Y = Y)[W == 0, ])
# Predict on full data
tau_t <- predict(fit1, newdata = X_df) - predict(fit0, newdata = X_df)
cat("T-Learner (linear):\n")
cat(" Mean CATE:", round(mean(tau_t), 4), "\n")
cat(" SD CATE: ", round(sd(tau_t), 4), "\n")
cat(" RMSE vs truth:", round(sqrt(mean((tau_t - tau_true)^2)), 4), "\n")
# T-Learner with XGBoost
set.seed(42)
# Treated model
xgb_fit1 <- xgb.train(
params = xgb_params,
data = xgb.DMatrix(data = X[W == 1, ], label = Y[W == 1]),
nrounds = 200,
verbose = 0
)
# Control model
xgb_fit0 <- xgb.train(
params = xgb_params,
data = xgb.DMatrix(data = X[W == 0, ], label = Y[W == 0]),
nrounds = 200,
verbose = 0
)
dX <- xgb.DMatrix(data = X)
tau_t_xgb <- predict(xgb_fit1, dX) - predict(xgb_fit0, dX)
cat("T-Learner (XGBoost):\n")
cat(" Mean CATE:", round(mean(tau_t_xgb), 4), "\n")
cat(" SD CATE: ", round(sd(tau_t_xgb), 4), "\n")
cat(" RMSE vs truth:", round(sqrt(mean((tau_t_xgb - tau_true)^2)), 4), "\n")The X-Learner is a two-stage procedure designed to improve on the T-Learner when treatment arms are imbalanced:
Stage 1: Fit T-Learner models $\hat\mu_1(x)$ and $\hat\mu_0(x)$.

Stage 2: Compute imputed treatment effects: $\tilde{D}_i^1 = Y_i - \hat\mu_0(X_i)$ for treated units and $\tilde{D}_i^0 = \hat\mu_1(X_i) - Y_i$ for control units.

Fit regression models $\hat\tau_1(x)$ to the $\tilde{D}_i^1$ and $\hat\tau_0(x)$ to the $\tilde{D}_i^0$, then combine: $\hat\tau_X(x) = g(x)\,\hat\tau_0(x) + (1 - g(x))\,\hat\tau_1(x)$, where $g(x)$ is typically the propensity score $\hat{e}(x)$.
# X-Learner implementation using linear models
# Stage 1: T-Learner (reuse fit1, fit0 from above)
mu1_hat <- predict(fit1, newdata = X_df)
mu0_hat <- predict(fit0, newdata = X_df)
# Stage 2: Imputed effects
D_tilde_1 <- Y[W == 1] - mu0_hat[W == 1]
D_tilde_0 <- mu1_hat[W == 0] - Y[W == 0]
# Fit stage-2 models
tau_fit1 <- lm(D_tilde_1 ~ ., data = X_df[W == 1, ])
tau_fit0 <- lm(D_tilde_0 ~ ., data = X_df[W == 0, ])
tau1_hat <- predict(tau_fit1, newdata = X_df)
tau0_hat <- predict(tau_fit0, newdata = X_df)
# Propensity score as weighting function g(x)
ps_fit <- glm(W ~ ., data = cbind(X_df, W = W), family = binomial())
e_hat <- predict(ps_fit, type = "response")
# Combine with Künzel et al.'s weighting: weight tau0 by e(x) and tau1 by
# 1 - e(x); under randomization with e(x) ~ 0.5 the choice matters little
tau_x <- e_hat * tau0_hat + (1 - e_hat) * tau1_hat
cat("X-Learner (linear):\n")
cat(" Mean CATE:", round(mean(tau_x), 4), "\n")
cat(" SD CATE: ", round(sd(tau_x), 4), "\n")
cat(" RMSE vs truth:", round(sqrt(mean((tau_x - tau_true)^2)), 4), "\n")The R-Learner uses Robinson’s (1988) partialling-out decomposition. Define residuals:
where and are cross-fit nuisance functions.
The R-Learner minimizes the R-loss: $\hat\tau = \arg\min_\tau \sum_i \big(\tilde{Y}_i - \tau(X_i)\,\tilde{W}_i\big)^2$ (plus regularization).
This is equivalent to regressing $\tilde{Y}_i / \tilde{W}_i$ on $X_i$ with weights $\tilde{W}_i^2$, or to an unweighted regression of $\tilde{Y}_i$ on the transformed regressors $\tilde{W}_i$ and $X_i \tilde{W}_i$.
library(glmnet)
# Cross-fitting with K=5 folds
K <- 5
set.seed(42)
folds <- sample(rep(1:K, length.out = n))
m_hat <- numeric(n) # E[Y|X] cross-fitted
e_hat_r <- numeric(n) # E[W|X] cross-fitted
for (k in 1:K) {
train_idx <- which(folds != k)
test_idx <- which(folds == k)
# Outcome model: lasso on X
cv_m <- cv.glmnet(X[train_idx, ], Y[train_idx], alpha = 1, nfolds = 5)
m_hat[test_idx] <- predict(cv_m, X[test_idx, ], s = "lambda.min")
# Propensity model: lasso logistic on X
cv_e <- cv.glmnet(X[train_idx, ], W[train_idx], family = "binomial",
alpha = 1, nfolds = 5)
e_hat_r[test_idx] <- predict(cv_e, X[test_idx, ], s = "lambda.min",
type = "response")
}
# Clip propensity scores
e_hat_r <- pmax(pmin(e_hat_r, 0.99), 0.01)
# R-residuals
Y_tilde <- Y - m_hat
W_tilde <- W - e_hat_r
# R-learner: fit the R-loss directly via a transformed design.
# Modeling tau(x) = b0 + x'b gives Y_tilde ~ b0 * W_tilde + (x * W_tilde)'b,
# so an unweighted lasso on [W_tilde, X * W_tilde] with no intercept
# minimizes the R-loss exactly (equivalently: regress Y_tilde / W_tilde
# on X with weights W_tilde^2)
X_r <- cbind(W_tilde = W_tilde, X * W_tilde)
cv_r <- cv.glmnet(X_r, Y_tilde, alpha = 1, nfolds = 5, intercept = FALSE)
beta_r <- as.numeric(coef(cv_r, s = "lambda.min"))[-1] # drop glmnet's (zero) intercept
tau_r <- as.numeric(beta_r[1] + X %*% beta_r[-1])
cat("R-Learner (lasso cross-fit):\n")
cat(" Mean CATE:", round(mean(tau_r), 4), "\n")
cat(" SD CATE: ", round(sd(tau_r), 4), "\n")
cat(" RMSE vs truth:", round(sqrt(mean((tau_r - tau_true)^2)), 4), "\n")
# Compile RMSE comparison
rmse_df <- data.frame(
Method = c("S-Learner (OLS)", "S-Learner (XGB)",
"T-Learner (OLS)", "T-Learner (XGB)",
"X-Learner (OLS)", "R-Learner (lasso)"),
RMSE = c(
sqrt(mean((tau_s - tau_true)^2)),
sqrt(mean((tau_s_xgb - tau_true)^2)),
sqrt(mean((tau_t - tau_true)^2)),
sqrt(mean((tau_t_xgb - tau_true)^2)),
sqrt(mean((tau_x - tau_true)^2)),
sqrt(mean((tau_r - tau_true)^2))
),
Bias = c(
mean(tau_s - tau_true),
mean(tau_s_xgb - tau_true),
mean(tau_t - tau_true),
mean(tau_t_xgb - tau_true),
mean(tau_x - tau_true),
mean(tau_r - tau_true)
)
)
rmse_df$RMSE <- round(rmse_df$RMSE, 4)
rmse_df$Bias <- round(rmse_df$Bias, 4)
print(rmse_df)
# Plot comparison
ggplot(rmse_df, aes(x = reorder(Method, RMSE), y = RMSE, fill = Method)) +
geom_col(alpha = 0.8, show.legend = FALSE) +
geom_text(aes(label = round(RMSE, 3)), hjust = -0.1, size = 3.5) +
coord_flip() +
scale_fill_manual(values = c(
"#4393C3", "#2166AC", "#74ADD1", "#ABD9E9", "#D73027", "#F46D43"
)) +
labs(
x = NULL,
y = "RMSE (vs. True CATE)",
title = "Meta-Learner Comparison: Oracle RMSE",
subtitle = "Lower RMSE = better CATE recovery"
) +
theme_minimal(base_size = 12)

Causal forests (Wager & Athey 2018; Athey et al. 2019) extend random forests to estimate the CATE by: (1) choosing splits that maximize heterogeneity in treatment effects rather than outcome fit; (2) using honest estimation, with separate subsamples for placing splits and estimating effects; and (3) treating the forest as an adaptive kernel that determines each unit's weight in the estimate of $\hat\tau(x)$.
The key estimating equation is a forest-weighted residual-on-residual regression:

$$\hat\tau(x) = \frac{\sum_i \alpha_i(x)\,(W_i - \bar{W}_\alpha)(Y_i - \bar{Y}_\alpha)}{\sum_i \alpha_i(x)\,(W_i - \bar{W}_\alpha)^2},$$

where the $\alpha_i(x)$ are forest weights (the frequency with which unit $i$ falls in the same leaf as the target point $x$).
library(grf)
set.seed(42)
cf <- causal_forest(
X = X,
Y = Y,
W = W,
num.trees = 2000,
tune.parameters = "all"
)
cat("Causal forest fitted with", cf$`_num_trees`, "trees.\n")
cat("Tuned min.node.size:", cf$tunable.params$min.node.size, "\n")
cat("Tuned sample.fraction:", cf$tunable.params$sample.fraction, "\n")
ate_grf <- average_treatment_effect(cf, target.sample = "all")
att_grf <- average_treatment_effect(cf, target.sample = "treated")
cat("GRF ATE estimate:", round(ate_grf["estimate"], 4),
"| SE:", round(ate_grf["std.err"], 4), "\n")
cat("GRF ATT estimate:", round(att_grf["estimate"], 4),
"| SE:", round(att_grf["std.err"], 4), "\n")
cat("True ATE:", round(mean(tau_true), 4), "\n")
# Out-of-bag predictions (honest, avoids overfitting)
preds <- predict(cf, estimate.variance = TRUE)
tau_cf <- preds$predictions
tau_var <- preds$variance.estimates
tau_se <- sqrt(tau_var)
cat("CATE distribution (causal forest):\n")
print(round(quantile(tau_cf, c(0, 0.1, 0.25, 0.5, 0.75, 0.9, 1)), 4))
cat("Oracle RMSE (vs true tau):", round(sqrt(mean((tau_cf - tau_true)^2)), 4), "\n")The calibration test (Chernozhukov et al. 2020) checks whether the CATE predictions are informative:
cal <- test_calibration(cf)
print(cal)

A significant coefficient on differential.forest.prediction (the BLP heterogeneity coefficient $\beta_2$) indicates meaningful heterogeneity captured by the forest. A coefficient close to 1 on mean.forest.prediction indicates good calibration.
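For intuition, the regression behind this test can be sketched by hand on toy data (all names below are local to this sketch, and the true nuisance functions are plugged in, whereas test_calibration() estimates them): regress outcome residuals on the mean prediction and the demeaned prediction, each interacted with the treatment residual.

```r
# Hand-rolled calibration regression on toy data; assumes the true
# nuisance functions are known (test_calibration() estimates them).
set.seed(7)
n_toy <- 5000
x_toy <- rnorm(n_toy)
w_toy <- rbinom(n_toy, 1, 0.5)
tau_toy <- 1 + x_toy                        # true CATE
y_toy <- tau_toy * w_toy + rnorm(n_toy)
tau_pred <- tau_toy                         # stand-in for forest predictions
y_res <- y_toy - 0.5 * tau_toy              # residual vs m(x) = 0.5 * tau(x)
w_res <- w_toy - 0.5                        # residual vs e(x) = 0.5
mfp <- mean(tau_pred) * w_res               # "mean prediction" regressor
dfp <- (tau_pred - mean(tau_pred)) * w_res  # "differential" regressor
calib_fit <- lm(y_res ~ mfp + dfp - 1)
round(coef(calib_fit), 2)  # both coefficients near 1
```

With perfectly scaled predictions, both coefficients are close to 1 by construction; shrinking or noising up tau_pred pulls the differential coefficient away from 1.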
varimp <- variable_importance(cf)
varimp_df <- data.frame(
feature = colnames(X),
importance = as.numeric(varimp)
)
varimp_df <- varimp_df[order(varimp_df$importance, decreasing = TRUE), ]
ggplot(varimp_df, aes(x = reorder(feature, importance), y = importance)) +
geom_col(fill = "#4393C3", alpha = 0.8) +
coord_flip() +
labs(
x = NULL,
y = "Variable Importance",
title = "Causal Forest: Variable Importance",
subtitle = "X1 drives heterogeneity by construction"
) +
theme_minimal(base_size = 12)
cate_df <- data.frame(
tau_cf = tau_cf,
tau_true = tau_true,
tau_se = tau_se
)
ggplot(cate_df, aes(x = tau_cf)) +
geom_histogram(bins = 50, fill = "#4393C3", alpha = 0.7, color = "white") +
geom_vline(xintercept = ate_grf["estimate"], color = "#D73027", linewidth = 1.2) +
geom_vline(xintercept = 0, linetype = "dashed") +
labs(
x = "Estimated CATE (Causal Forest)",
y = "Count",
title = "Distribution of Estimated CATE",
subtitle = sprintf("ATE = %.3f (red); %.1f%% positive, %.1f%% negative",
ate_grf["estimate"],
100 * mean(tau_cf > 0),
100 * mean(tau_cf < 0))
) +
theme_minimal(base_size = 12)
ggplot(cate_df, aes(x = tau_true, y = tau_cf)) +
geom_point(alpha = 0.15, size = 0.8, color = "#2166AC") +
geom_smooth(method = "lm", color = "#D73027", se = FALSE) +
geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
labs(
x = "True CATE (oracle)",
y = "Estimated CATE (Causal Forest)",
title = "CATE Estimates vs. True Values",
subtitle = sprintf("Correlation: %.3f | RMSE: %.3f",
cor(tau_cf, tau_true),
sqrt(mean((tau_cf - tau_true)^2)))
) +
theme_minimal(base_size = 12)

The causalverse helper causal_forest_summary() bundles these diagnostics:
cf_summary <- causal_forest_summary(
cf = cf,
X = X,
feature_names = colnames(X),
n_groups = 5,
top_features = 10,
plot = TRUE
)
# ATE table
print(cf_summary$ate)
# CATE distribution summary
print(round(cf_summary$cate_summary, 4))
cf_summary$plot_cate_dist
cf_summary$plot_importance

The Best Linear Predictor (BLP) framework (Chernozhukov et al. 2020) provides a model-free test for treatment effect heterogeneity and a summary of its structure.
Starting from the partially linear model, the BLP regresses the outcome residual on the treatment residual and its interaction with the demeaned CATE prediction:

$$\tilde{Y}_i = \beta_1 \tilde{W}_i + \beta_2\, \tilde{W}_i \big(\hat\tau(X_i) - \bar{\hat\tau}\big) + \varepsilon_i,$$

where $\beta_1$ estimates the ATE and $\beta_2$ tests for heterogeneity.
blp_res <- blp_analysis(
cate_hat = tau_cf,
Y = Y,
W = W,
Y_hat = cf$Y.hat,
W_hat = cf$W.hat,
run_gates = TRUE,
n_groups = 5
)
cat("BLP Coefficients:\n")
print(blp_res$blp)

A significant $\beta_2$ confirms that the causal forest captures genuine heterogeneity. The $\beta_2$ estimate should be close to 1 for a well-calibrated CATE estimator.
GATES summarizes heterogeneity by ranking units into quantile groups by estimated CATE and computing the average treatment effect within each group:
cat("GATES Results:\n")
gates_df <- blp_res$gates
# Only round numeric columns (group is character)
num_cols <- sapply(gates_df, is.numeric)
gates_df[num_cols] <- round(gates_df[num_cols], 4)
print(gates_df)
blp_res$gates_plot

The GATES plot should show a clear monotone pattern from low-CATE to high-CATE groups if the forest is well-calibrated.
# GATES using any CATE estimator: use tau_t (T-Learner OLS)
blp_ols <- blp_analysis(
cate_hat = tau_t,
Y = Y,
W = W,
Y_hat = fitted(lm(Y ~ ., data = cbind(X_df, Y = Y))),
W_hat = rep(mean(W), n),
run_gates = TRUE,
n_groups = 5
)
cat("BLP for T-Learner OLS:\n")
print(blp_ols$blp)

The Qini curve (Radcliffe 2007) evaluates the quality of CATE estimates for targeting: which fraction of the population should we treat to maximize aggregate benefit?
Rank units by estimated CATE (highest first). The Qini curve plots the cumulative uplift

$$Q(q) = q\,\big(\bar{Y}_1(q) - \bar{Y}_0(q)\big)$$

against the treated fraction $q$, where $\bar{Y}_w(q)$ is the mean outcome in arm $w$ among the top-$q$ fraction of units ranked by $\hat\tau$.
AUUC (Area Under the Uplift Curve) compares targeting strategies:

- Random targeting: uniform baseline
- CATE-based targeting: our estimator
- Oracle targeting: sort by true $\tau$ (requires knowing true effects)
qini_res <- qini_curve(
cate_hat = tau_cf,
Y = Y,
W = W,
compare_random = TRUE,
n_bins = 100,
boot_reps = 500,
seed = 42
)
cat("AUUC:", round(qini_res$auuc, 4), "\n")
cat("AUUC 95% CI: [", round(qini_res$auuc_ci[1], 4),
",", round(qini_res$auuc_ci[2], 4), "]\n")
qini_res$plot +
labs(title = "Qini Curve: Causal Forest CATE Targeting")
# Compute Qini curves for each estimator
compute_qini_simple <- function(tau_hat, Y, W, n_bins = 100) {
n <- length(Y)
ord <- order(tau_hat, decreasing = TRUE)
fracs <- seq(0, 1, length.out = n_bins + 1)
uplifts <- numeric(n_bins + 1)
for (k in seq_len(n_bins + 1)) {
if (k == 1) { uplifts[k] <- 0; next }
n_top <- max(1, round(fracs[k] * n))
top <- ord[seq_len(n_top)]
m1 <- if (sum(W[top] == 1) > 0) mean(Y[top][W[top] == 1]) else 0
m0 <- if (sum(W[top] == 0) > 0) mean(Y[top][W[top] == 0]) else 0
uplifts[k] <- (m1 - m0) * fracs[k]
}
data.frame(fraction = fracs, uplift = uplifts)
}
q_cf <- compute_qini_simple(tau_cf, Y, W)
q_t <- compute_qini_simple(tau_t, Y, W)
q_s <- compute_qini_simple(tau_s, Y, W)
q_ora <- compute_qini_simple(tau_true, Y, W) # Oracle
# Combine
q_all <- rbind(
data.frame(q_cf, Method = "Causal Forest"),
data.frame(q_t, Method = "T-Learner (OLS)"),
data.frame(q_s, Method = "S-Learner (OLS)"),
data.frame(q_ora, Method = "Oracle (true tau)")
)
# Random baseline
overall_tau <- mean(Y[W == 1]) - mean(Y[W == 0])
q_rand <- data.frame(
fraction = seq(0, 1, length.out = 101),
uplift = seq(0, 1, length.out = 101) * overall_tau,
Method = "Random Targeting"
)
q_all <- rbind(q_all, q_rand)
ggplot(q_all, aes(x = fraction, y = uplift, color = Method, linetype = Method)) +
geom_line(linewidth = 0.9) +
scale_color_manual(values = c(
"Causal Forest" = "#2166AC",
"T-Learner (OLS)" = "#4DAC26",
"S-Learner (OLS)" = "#E66101",
"Oracle (true tau)" = "#762A83",
"Random Targeting" = "gray50"
)) +
scale_linetype_manual(values = c(
"Causal Forest" = "solid",
"T-Learner (OLS)" = "solid",
"S-Learner (OLS)" = "solid",
"Oracle (true tau)" = "dotted",
"Random Targeting" = "dashed"
)) +
labs(
x = "Fraction Treated (Ranked by Estimated CATE)",
y = "Cumulative Uplift",
title = "Qini Curve Comparison Across Estimators",
color = NULL, linetype = NULL
) +
theme_minimal(base_size = 12) +
theme(legend.position = "bottom")If treatment has a per-unit cost , the optimal targeting rule is:
# Policy evaluation at different cost thresholds
cost_thresholds <- c(0, 0.25, 0.5, 0.75, 1.0)
policy_values <- sapply(cost_thresholds, function(c_val) {
treated_by_policy <- tau_cf >= c_val
# Policy value = avg tau for those targeted
if (sum(treated_by_policy) == 0) return(0)
mean(tau_true[treated_by_policy]) # Oracle value
})
policy_df <- data.frame(
Cost = cost_thresholds,
Pct_Treat = sapply(cost_thresholds, function(c) mean(tau_cf >= c)),
Avg_CATE = round(policy_values, 4)
)
print(policy_df)
# Illustrative helper: estimate policy value using doubly-robust scores
# (simplified sketch; the summary below uses the oracle tau instead)
evaluate_policy <- function(policy, Y, W, Y_hat, W_hat) {
# DR policy value: E[policy * (Y(1) - cost) + (1-policy) * Y(0)]
# Simplified: ATE in subgroup where policy = 1
idx <- which(policy == 1)
if (length(idx) == 0) return(NA)
# AIPW estimate in subgroup
W_res <- W[idx] - W_hat[idx]
Y_res <- Y[idx] - Y_hat[idx]
psi <- Y_hat[idx] + W_res * (Y[idx] - Y_hat[idx]) /
ifelse(W[idx] == 1, W_hat[idx], 1 - W_hat[idx])
# Policy value relative to no treatment (all control)
treated_score <- mean(psi[W[idx] == 1]) - mean(Y[idx][W[idx] == 0])
treated_score
}
cat("Policy: treat top 50% by CATE\n")
top50 <- as.integer(tau_cf >= median(tau_cf))
cat("Fraction treated:", mean(top50), "\n")
cat("Oracle policy value (avg true tau for targeted):",
round(mean(tau_true[top50 == 1]), 4), "\n")
cat("Vs. ATE (treat all):", round(mean(tau_true), 4), "\n")
# Create pre-specified subgroups based on X1 and X2
df_plot <- data.frame(X,
cate = tau_cf,
W = W,
Y = Y
)
df_plot$X1_group <- ifelse(X[, 1] > 0, "High X1", "Low X1")
df_plot$X2_group <- ifelse(X[, 2] > 0, "High X2", "Low X2")
df_plot$Subgroup <- paste(df_plot$X1_group, df_plot$X2_group, sep = " / ")
p_forest <- heterogeneity_plot(
data = df_plot,
cate_var = "cate",
subgroup_var = "Subgroup",
overall_effect = ate_grf["estimate"],
plot_type = "forest",
title = "HTE by Pre-Specified Subgroups"
)
p_forest
p_scatter <- heterogeneity_plot(
data = df_plot,
cate_var = "cate",
moderator_var = "X1",
plot_type = "scatter",
title = "CATE vs. Key Moderator X1"
)
p_scatter

The LOESS curve confirms the linear relationship between X1 and the CATE, as specified by the data-generating process.
df_plot$Treatment <- factor(W, levels = c(0, 1), labels = c("Control", "Treated"))
p_violin <- heterogeneity_plot(
data = df_plot,
cate_var = "cate",
subgroup_var = "Treatment",
plot_type = "violin",
title = "CATE Distribution by Treatment Group"
)
p_violin

The partial dependence plot shows the marginal effect of X1 on the CATE, averaging over the distribution of the other covariates.
# Partial dependence of CATE on X1: vary X1, hold others at observed values
x1_grid <- seq(min(X[, 1]), max(X[, 1]), length.out = 50)
pdp_vals <- sapply(x1_grid, function(x1_val) {
X_temp <- X
X_temp[, 1] <- x1_val
mean(predict(cf, X_temp)$predictions)
})
pdp_df <- data.frame(X1 = x1_grid, CATE = pdp_vals)
ggplot(pdp_df, aes(x = X1, y = CATE)) +
geom_line(color = "#2166AC", linewidth = 1.2) +
geom_hline(yintercept = ate_grf["estimate"], linetype = "dashed", color = "gray50") +
labs(
x = "X1 (Moderator)",
y = "Partial Average CATE",
title = "Partial Dependence Plot: CATE vs. X1",
subtitle = "Marginalizing over other covariates; dashed = ATE"
) +
theme_minimal(base_size = 12)
# Bin by predicted CATE decile, compute realized treatment effect
calib_df <- data.frame(
tau_cf = tau_cf,
tau_true = tau_true,
W = W,
Y = Y
)
calib_df$decile <- cut(tau_cf,
breaks = quantile(tau_cf, seq(0, 1, 0.1)),
include.lowest = TRUE, labels = paste0("D", 1:10)
)
library(dplyr)
calib_tab <- calib_df %>%
group_by(decile) %>%
dplyr::summarise(
pred_cate = mean(tau_cf),
true_cate = mean(tau_true),
realized_eff = mean(Y[W == 1]) - mean(Y[W == 0]),
.groups = "drop"
)
ggplot(calib_tab, aes(x = pred_cate, y = realized_eff)) +
geom_point(size = 3, color = "#2166AC") +
  # Illustrative +/- 0.2 band (not an estimated standard error)
  geom_errorbar(aes(ymin = realized_eff - 0.2, ymax = realized_eff + 0.2),
    width = 0.05, color = "#2166AC", alpha = 0.5) +
geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
geom_line(aes(y = true_cate), color = "#D73027", linewidth = 0.8,
linetype = "solid") +
labs(
x = "Mean Predicted CATE (by decile)",
y = "Realized Treatment Effect",
title = "CATE Calibration by Decile",
subtitle = "Dashed: perfect calibration; red: true CATE by decile"
) +
theme_minimal(base_size = 12)

Robinson (1988) introduced the partially linear model for flexible treatment effect estimation:

$$Y = \theta W + g(X) + \varepsilon, \qquad E[\varepsilon \mid X, W] = 0.$$
Key insight: Partialling out $X$ from both $Y$ and $W$ recovers $\theta$ as the coefficient in the residual regression $\tilde{Y} = \theta \tilde{W} + \varepsilon$, where $\tilde{Y} = Y - E[Y \mid X]$ and $\tilde{W} = W - E[W \mid X]$.
Chernozhukov et al. (2018) showed that cross-fitting (using separate data folds to estimate nuisance functions) enables $\sqrt{n}$-consistent inference on $\theta$ even when the nuisance models converge at slower rates.
library(glmnet)
# DML with K=5 cross-fitting
set.seed(42)
K <- 5
n <- nrow(X)
folds <- sample(rep(1:K, length.out = n))
m_hat_dml <- numeric(n)
e_hat_dml <- numeric(n)
for (k in 1:K) {
tr <- which(folds != k)
te <- which(folds == k)
# Outcome model: lasso
cv_m <- cv.glmnet(X[tr, ], Y[tr], alpha = 1)
m_hat_dml[te] <- predict(cv_m, X[te, ], s = "lambda.min")
# Propensity: lasso logistic
cv_e <- cv.glmnet(X[tr, ], W[tr], family = "binomial", alpha = 1)
e_hat_dml[te] <- predict(cv_e, X[te, ], s = "lambda.min", type = "response")
}
# Clip propensity
e_hat_dml <- pmax(pmin(e_hat_dml, 0.99), 0.01)
# Residualize
Y_res_dml <- Y - m_hat_dml
W_res_dml <- W - e_hat_dml
# DML estimator: OLS of Y_res on W_res
dml_mod <- lm(Y_res_dml ~ W_res_dml - 1)
theta_dml <- coef(dml_mod)["W_res_dml"]
se_dml <- sqrt(vcov(dml_mod)["W_res_dml", "W_res_dml"])
cat("DML ATE estimate:", round(theta_dml, 4), "\n")
cat("DML SE: ", round(se_dml, 4), "\n")
cat("95% CI: [", round(theta_dml - 1.96 * se_dml, 4),
",", round(theta_dml + 1.96 * se_dml, 4), "]\n")
cat("True ATE:", round(mean(tau_true), 4), "\n")To allow to vary with a key moderator (e.g., ), interact the moderator with the treatment residual:
# Moderator: X1 (centered)
V <- X[, 1]
V_c <- V - mean(V)
# Interaction design matrix
Z_int <- cbind(W_res_dml, W_res_dml * V_c)
colnames(Z_int) <- c("W_res", "W_res_x_V")
dml_het <- lm(Y_res_dml ~ Z_int - 1)
coef_het <- coef(dml_het)
se_het <- sqrt(diag(vcov(dml_het)))
cat("Heterogeneous DML:\n")
cat(" theta_0 (ATE at V=0):", round(coef_het[1], 4),
"| SE:", round(se_het[1], 4), "\n")
cat(" theta_1 (moderation by X1):", round(coef_het[2], 4),
"| SE:", round(se_het[2], 4), "\n")
cat(" (True slope = 1)\n")The coefficient estimates how the treatment effect changes per unit increase in . Since , the true slope is 1.
library(DoubleML)
library(mlr3)
library(mlr3learners)
# Setup DoubleML data object
dml_data <- DoubleMLData$new(
data = data.frame(X, Y = Y, W = W),
y_col = "Y",
d_cols = "W",
x_cols = paste0("X", 1:p)
)
# Partially linear model with lasso learners
set.seed(42)
learner_m <- lrn("regr.cv_glmnet", alpha = 1)
learner_r <- lrn("regr.cv_glmnet", alpha = 1)
dml_plr <- DoubleMLPLR$new(
obj_dml_data = dml_data,
ml_l = learner_m,
ml_m = learner_r,
n_folds = 5,
score = "partialling out"
)
dml_plr$fit()
print(dml_plr$summary())
# IRM: allows fully flexible treatment effect heterogeneity
learner_g <- lrn("regr.cv_glmnet", alpha = 1)
learner_m2 <- lrn("classif.cv_glmnet", alpha = 1)
dml_irm <- DoubleMLIRM$new(
obj_dml_data = dml_data,
ml_g = learner_g,
ml_m = learner_m2,
n_folds = 5,
score = "ATE"
)
dml_irm$fit()
print(dml_irm$summary())

The AIPW (Augmented Inverse Probability Weighting) estimator for the ATE combines outcome regression with IPW to achieve double robustness:

$$\hat\tau_{\text{AIPW}} = \frac{1}{n}\sum_{i=1}^n \left[\hat\mu_1(X_i) - \hat\mu_0(X_i) + \frac{W_i\,(Y_i - \hat\mu_1(X_i))}{\hat{e}(X_i)} - \frac{(1 - W_i)\,(Y_i - \hat\mu_0(X_i))}{1 - \hat{e}(X_i)}\right].$$
This is doubly robust: consistent if either the outcome models $\hat\mu_w(x)$ or the propensity model $\hat{e}(x)$ is correctly specified.
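A quick numerical illustration of that property on self-contained toy data (all names local to this sketch): with deliberately misspecified constant outcome models but the true propensity score, the AIPW average stays close to the true ATE while the naive difference in means does not.

```r
# Double robustness, numerically: bad outcome models + correct propensity
# still yield a nearly unbiased AIPW estimate. True ATE = 1 here.
set.seed(3)
n_toy <- 20000
x_toy <- rnorm(n_toy)
e_toy <- plogis(x_toy)                     # true propensity score
w_toy <- rbinom(n_toy, 1, e_toy)
y_toy <- w_toy + 2 * x_toy + rnorm(n_toy)  # confounded outcome; true ATE = 1
mu1_bad <- rep(mean(y_toy[w_toy == 1]), n_toy)  # misspecified (constant) models
mu0_bad <- rep(mean(y_toy[w_toy == 0]), n_toy)
psi_toy <- mu1_bad - mu0_bad +
  w_toy * (y_toy - mu1_bad) / e_toy -
  (1 - w_toy) * (y_toy - mu0_bad) / (1 - e_toy)
c(naive = mean(y_toy[w_toy == 1]) - mean(y_toy[w_toy == 0]),
  aipw = mean(psi_toy))  # naive is badly biased upward; AIPW is close to 1
```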
The causalverse helper dr_ate() implements this estimator:
if (requireNamespace("MatchIt", quietly = TRUE)) {
data(lalonde, package = "MatchIt")
aipw_res <- dr_ate(
data = lalonde,
outcome = "re78",
treatment = "treat",
covariates = c("age", "educ", "married", "nodegree", "re74", "re75"),
estimand = "ATE",
ps_trim = c(0.01, 0.99)
)
cat("AIPW ATE for Lalonde:\n")
cat(" Estimate:", round(aipw_res$estimate, 2), "\n")
cat(" SE: ", round(aipw_res$se, 2), "\n")
cat(" 95% CI: [", round(aipw_res$ci_lower, 2), ",",
round(aipw_res$ci_upper, 2), "]\n")
cat(" p-value:", round(aipw_res$p_value, 4), "\n")
cat(" n trimmed:", aipw_res$n_trimmed, "\n")
}

The DR-Learner (Kennedy 2020) uses AIPW pseudo-outcomes

$$\hat\psi_i = \hat\mu_1(X_i) - \hat\mu_0(X_i) + \frac{W_i\,(Y_i - \hat\mu_1(X_i))}{\hat{e}(X_i)} - \frac{(1 - W_i)\,(Y_i - \hat\mu_0(X_i))}{1 - \hat{e}(X_i)}$$

as targets for CATE estimation. Note that $E[\hat\psi_i \mid X_i = x] = \tau(x)$ under correct specification. The DR-Learner regresses these pseudo-outcomes on $X$ to obtain $\hat\tau_{\text{DR}}(x)$.
# Compute AIPW pseudo-outcomes using cross-fitted nuisances
# Reuse m_hat_dml and e_hat_dml from Section 8
# Separate mu1 and mu0 via cross-fitting
mu1_hat_dr <- numeric(n)
mu0_hat_dr <- numeric(n)
set.seed(42)
for (k in 1:K) {
tr <- which(folds != k)
te <- which(folds == k)
# Outcome models
X_df_k <- as.data.frame(X)
  lm1 <- lm(Y[tr][W[tr] == 1] ~ ., data = X_df_k[tr[W[tr] == 1], ])
lm0 <- lm(Y[tr][W[tr] == 0] ~ ., data = X_df_k[tr[W[tr] == 0], ])
mu1_hat_dr[te] <- predict(lm1, newdata = X_df_k[te, ])
mu0_hat_dr[te] <- predict(lm0, newdata = X_df_k[te, ])
}
# AIPW pseudo-outcomes
psi_dr <- mu1_hat_dr - mu0_hat_dr +
W * (Y - mu1_hat_dr) / e_hat_dml -
(1 - W) * (Y - mu0_hat_dr) / (1 - e_hat_dml)
cat("Mean DR pseudo-outcome (ATE proxy):", round(mean(psi_dr), 4), "\n")
cat("True ATE:", round(mean(tau_true), 4), "\n")
# DR-Learner: regress pseudo-outcomes on X
cv_dr <- cv.glmnet(X, psi_dr, alpha = 1)
tau_dr <- predict(cv_dr, X, s = "lambda.min")[, 1]
cat("DR-Learner RMSE:", round(sqrt(mean((tau_dr - tau_true)^2)), 4), "\n")
# Naive IPW estimate
ipw_ate <- mean(W * Y / e_hat_dml) - mean((1 - W) * Y / (1 - e_hat_dml))
# AIPW estimate
aipw_ate <- mean(psi_dr)
# OLS estimate
ols_ate <- coef(lm(Y ~ W + X))[["W"]]
cat("Comparison of ATE estimators:\n")
cat(sprintf(" OLS: %.4f\n", ols_ate))
cat(sprintf(" IPW: %.4f\n", ipw_ate))
cat(sprintf(" AIPW: %.4f\n", aipw_ate))
cat(sprintf(" True: %.4f\n", mean(tau_true)))
# Variance comparison
se_ipw <- sd(W * Y / e_hat_dml - (1 - W) * Y / (1 - e_hat_dml)) / sqrt(n)
se_aipw <- sd(psi_dr) / sqrt(n)
cat(sprintf(" SE IPW: %.4f\n", se_ipw))
cat(sprintf(" SE AIPW: %.4f (efficiency gain = %.1f%%)\n",
    se_aipw, 100 * (se_ipw - se_aipw) / se_ipw))

The simplest parametric approach to HTE is to include treatment-covariate interactions:

$$Y = \alpha + \tau W + X\beta + W\,(X\gamma) + \varepsilon.$$

The vector $\gamma$ captures how the treatment effect $\tau(x) = \tau + x^\top\gamma$ varies with $x$. This is interpretable but restrictive (linearity and correct functional form).
# Fit interaction model
X_df_int <- as.data.frame(X)
colnames(X_df_int) <- paste0("X", 1:p)
# Main effects + treatment interactions
int_formula <- as.formula(
paste("Y ~ W + (",
paste(paste0("X", 1:p), collapse = " + "),
") + W:(",
paste(paste0("X", 1:p), collapse = " + "),
")")
)
int_mod <- lm(int_formula, data = cbind(X_df_int, W = W, Y = Y))
# Extract interaction coefficients
int_coefs <- coef(int_mod)
int_ses <- sqrt(diag(vcov(int_mod)))
int_names <- names(int_coefs)
# Focus on interaction terms
int_idx <- grepl("W:X", int_names)
int_summary <- data.frame(
Term = int_names[int_idx],
Estimate = round(int_coefs[int_idx], 4),
SE = round(int_ses[int_idx], 4),
t_stat = round(int_coefs[int_idx] / int_ses[int_idx], 3),
p_value = round(2 * pt(-abs(int_coefs[int_idx] / int_ses[int_idx]),
df = n - length(int_coefs)), 4)
)
print(int_summary)

The interaction term W:X1 should be close to 1 (the true slope of $\tau(x) = x_1$).
An omnibus F-test compares the interaction model to the main-effects-only model.
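A sketch of that omnibus test on self-contained toy data (object names here are illustrative, not the vignette's):

```r
# Omnibus F-test for heterogeneity: does adding W:X interactions
# significantly improve fit over a main-effects-only model?
set.seed(42)
n_toy <- 1000
x_toy <- matrix(rnorm(n_toy * 3), n_toy, 3,
                dimnames = list(NULL, paste0("X", 1:3)))
w_toy <- rbinom(n_toy, 1, 0.5)
y_toy <- x_toy[, 1] * w_toy + 0.5 * x_toy[, 2] + rnorm(n_toy)
d_toy <- data.frame(x_toy, W = w_toy, Y = y_toy)
m_main <- lm(Y ~ W + X1 + X2 + X3, data = d_toy)   # homogeneous effect
m_int  <- lm(Y ~ W * (X1 + X2 + X3), data = d_toy) # interacted effects
ftest <- anova(m_main, m_int)
ftest[["Pr(>F)"]][2]  # tiny p-value rejects effect homogeneity
```

Here heterogeneity is driven by X1 with slope 1, so the test rejects decisively; with a truly constant effect the p-value would be roughly uniform on (0, 1).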
# Manual marginal effects: E[tau | X1 = x1], averaging over X2, ..., X10
x1_vals <- seq(-2.5, 2.5, length.out = 50)
me_vals <- sapply(x1_vals, function(x1) {
X_temp <- X_df_int
X_temp$X1 <- x1
d1 <- cbind(X_temp, W = 1)
d0 <- cbind(X_temp, W = 0)
mean(predict(int_mod, d1) - predict(int_mod, d0))
})
me_df <- data.frame(X1 = x1_vals, AME = me_vals)
ggplot(me_df, aes(x = X1, y = AME)) +
geom_line(color = "#2166AC", linewidth = 1.2) +
geom_hline(yintercept = mean(tau_true), linetype = "dashed") +
geom_ribbon(aes(ymin = AME - 0.15, ymax = AME + 0.15),
alpha = 0.15, fill = "#2166AC") +
labs(
x = "X1 (Moderator)",
y = "Average Marginal Effect of Treatment",
title = "Marginal Treatment Effect by X1",
subtitle = "Dashed = ATE; shaded = approximate ±95% CI"
) +
  theme_minimal(base_size = 12)

The marginaleffects package computes these quantities directly:
library(marginaleffects)
# Average marginal effect of W at different X1 levels
ame_W <- avg_slopes(int_mod, variables = "W")
cat("Average Marginal Effect of W:\n")
print(ame_W)
# Marginal effect of W by X1 quantile
me_by_x1 <- slopes(
int_mod,
variables = "W",
newdata = datagrid(X1 = quantile(X[, 1], c(0.1, 0.25, 0.5, 0.75, 0.9)))
)
cat("\nMarginal effect of W at X1 quantiles:\n")
print(me_by_x1[, c("X1", "estimate", "std.error", "conf.low", "conf.high")])

Pre-specified subgroups (defined before seeing outcome data):

- Control Type I error
- More credible for publication
- Require a pre-analysis plan (PAP)
- May miss data-driven opportunities

Data-driven subgroups (from ML or recursive partitioning):

- Can discover unexpected heterogeneity
- Risk of overfitting and false discoveries
- Require honest estimation (separate sample or honesty) and correction for multiple testing
- Should be validated on held-out data

Best practice: Use pre-specified primary analyses; treat data-driven results as exploratory.
# Simulate testing treatment effects in 10 pre-specified subgroups
set.seed(42)
n_subgroups <- 10
subgroup_pvals <- numeric(n_subgroups)
for (s in 1:n_subgroups) {
# Define subgroup: quantile of X_s
Xs <- X[, s]
in_group <- Xs > median(Xs)
Y_g <- Y[in_group]
W_g <- W[in_group]
mod_g <- lm(Y_g ~ W_g)
subgroup_pvals[s] <- coef(summary(mod_g))["W_g", "Pr(>|t|)"]
}
# Multiple testing corrections
mt_df <- data.frame(
Subgroup = paste0("X", 1:n_subgroups, " > Median"),
p_raw = round(subgroup_pvals, 4),
p_bonf = round(p.adjust(subgroup_pvals, method = "bonferroni"), 4),
p_bh = round(p.adjust(subgroup_pvals, method = "BH"), 4),
sig_raw = subgroup_pvals < 0.05,
sig_bonf = p.adjust(subgroup_pvals, "bonferroni") < 0.05,
sig_bh = p.adjust(subgroup_pvals, "BH") < 0.05
)
print(mt_df)
cat("\nSignificant at 5%:\n")
cat(" Raw p-value: ", sum(mt_df$sig_raw), "subgroups\n")
cat(" Bonferroni: ", sum(mt_df$sig_bonf), "subgroups\n")
cat(" BH-FDR: ", sum(mt_df$sig_bh), "subgroups\n")

The causal tree (Athey & Imbens 2016) uses honest estimation: the partitioning sample and the estimation sample are separate. This controls the bias introduced by searching over splits.
# Use GRF's causal_forest as an approximation to causal trees
# GRF uses honest sample splitting within each tree
# Identify best splits using variable importance
top_var <- which.max(variable_importance(cf))
cat("Top moderator by GRF importance:", colnames(X)[top_var], "\n")
# Use the top variable to define binary subgroups
split_val <- median(X[, top_var])
group_hi <- X[, top_var] > split_val
group_lo <- !group_hi
# Honest ATE within each group (using GRF predictions)
cat("\nHTE by top moderator subgroup:\n")
cat(" High X1 group: mean CATE =",
round(mean(tau_cf[group_hi]), 4),
"| n =", sum(group_hi), "\n")
cat(" Low X1 group: mean CATE =",
round(mean(tau_cf[group_lo]), 4),
"| n =", sum(group_lo), "\n")
cat(" True HTE: High =", round(mean(tau_true[group_hi]), 4),
    "| Low =", round(mean(tau_true[group_lo]), 4), "\n")

The virtual twins method (Foster et al. 2011) first estimates individual treatment effects, then builds a decision tree on the CATE estimates to identify interpretable subgroups:
# Step 1: Estimate CATE (already have tau_cf from causal forest)
# Step 2: Classify as "responders" vs "non-responders"
threshold <- 0.0 # treat CATE > 0 as responders
responder <- as.integer(tau_cf > threshold)
cat("Virtual Twins:\n")
cat(" Responders (CATE > 0):", sum(responder), "/", n,
sprintf("(%.1f%%)\n", 100 * mean(responder)))
cat(" Mean CATE for responders:", round(mean(tau_cf[responder == 1]), 4), "\n")
cat(" Mean CATE for non-responders:", round(mean(tau_cf[responder == 0]), 4), "\n")
# Step 3: Fit a classification tree to identify subgroup predictors
if (requireNamespace("rpart", quietly = TRUE)) {
library(rpart)
vt_tree <- rpart(
responder ~ .,
data = cbind(as.data.frame(X), responder = responder),
method = "class",
control = rpart.control(cp = 0.02, maxdepth = 3)
)
cat("\nVirtual twins decision tree splits:\n")
print(vt_tree$frame[, c("var", "n", "yval")])
}

The causalverse::heterogeneity_plot() function produces forest plots of CATEs by subgroup:
# Define richer subgroups
df_sg <- data.frame(X, cate = tau_cf, W = W, Y = Y)
df_sg$Sex <- ifelse(X[, 3] > 0, "Male", "Female")
df_sg$Age_Grp <- cut(X[, 4], breaks = c(-Inf, -1, 0, 1, Inf),
labels = c("Q1", "Q2", "Q3", "Q4"))
df_sg$Industry <- sample(c("Tech", "Finance", "Health", "Retail"),
n, replace = TRUE, prob = c(0.3, 0.2, 0.25, 0.25))
# Forest plot by industry
p_sg <- heterogeneity_plot(
data = df_sg,
cate_var = "cate",
subgroup_var = "Industry",
overall_effect = mean(tau_cf),
plot_type = "forest",
title = "HTE by Industry Subgroup"
)
p_sg

When the true $\tau(X)$ is known (as in simulation), we can compute oracle metrics:
eval_metrics <- function(tau_hat, tau_true, name) {
data.frame(
Estimator = name,
RMSE = round(sqrt(mean((tau_hat - tau_true)^2)), 4),
MAE = round(mean(abs(tau_hat - tau_true)), 4),
Bias = round(mean(tau_hat - tau_true), 4),
Corr = round(cor(tau_hat, tau_true), 4),
Pct_Sign = round(mean(sign(tau_hat) == sign(tau_true)), 4)
)
}
eval_df <- rbind(
eval_metrics(tau_s, tau_true, "S-Learner (OLS)"),
eval_metrics(tau_t, tau_true, "T-Learner (OLS)"),
eval_metrics(tau_x, tau_true, "X-Learner (OLS)"),
eval_metrics(tau_cf, tau_true, "Causal Forest"),
eval_metrics(tau_dr, tau_true, "DR-Learner")
)
if (exists("tau_r")) {
eval_df <- rbind(eval_df, eval_metrics(tau_r, tau_true, "R-Learner"))
}
print(eval_df)

When the true $\tau(X)$ is unknown, the R-loss provides a proxy:

$$\hat{L}_R(\hat\tau) = \frac{1}{n} \sum_{i=1}^{n} \left[ \big(Y_i - \hat{m}(X_i)\big) - \hat\tau(X_i)\big(W_i - \hat{e}(X_i)\big) \right]^2$$

Lower R-loss indicates better CATE recovery without requiring knowledge of the true effects.
# Compute R-loss for each estimator
r_loss <- function(tau_hat, Y_res, W_res) {
mean((Y_res - tau_hat * W_res)^2)
}
rloss_df <- data.frame(
Estimator = c("S-Learner (OLS)", "T-Learner (OLS)",
"X-Learner (OLS)", "Causal Forest", "DR-Learner"),
R_Loss = round(c(
r_loss(tau_s, Y_res_dml, W_res_dml),
r_loss(tau_t, Y_res_dml, W_res_dml),
r_loss(tau_x, Y_res_dml, W_res_dml),
r_loss(tau_cf, Y_res_dml, W_res_dml),
r_loss(tau_dr, Y_res_dml, W_res_dml)
), 5)
)
print(rloss_df[order(rloss_df$R_Loss), ])
# 5-fold cross-validated AUUC for causal forest
K_auuc <- 5
set.seed(42)
folds_auuc <- sample(rep(1:K_auuc, length.out = n))
auuc_cv <- numeric(K_auuc)
for (k in 1:K_auuc) {
tr <- which(folds_auuc != k)
te <- which(folds_auuc == k)
cf_k <- causal_forest(X[tr, ], Y[tr], W[tr], num.trees = 500, seed = k)
tau_k <- predict(cf_k, X[te, ])$predictions
qini_k <- qini_curve(tau_k, Y[te], W[te],
boot_reps = 100, seed = k)
auuc_cv[k] <- qini_k$auuc
}
cat("Cross-validated AUUC (causal forest):\n")
cat(" Mean:", round(mean(auuc_cv), 4), "\n")
cat(" SD: ", round(sd(auuc_cv), 4), "\n")
cat(" 95% CI: [", round(mean(auuc_cv) - 2 * sd(auuc_cv) / sqrt(K_auuc), 4),
",", round(mean(auuc_cv) + 2 * sd(auuc_cv) / sqrt(K_auuc), 4), "]\n")
# Bootstrap CI for AUUC of causal forest
B <- 300
set.seed(42)
boot_auuc_vals <- numeric(B)
for (b in seq_len(B)) {
idx <- sample(n, replace = TRUE)
q_b <- qini_curve(tau_cf[idx], Y[idx], W[idx],
boot_reps = 50, seed = b)
boot_auuc_vals[b] <- q_b$auuc
}
cat("Bootstrap AUUC distribution:\n")
cat(" Mean:", round(mean(boot_auuc_vals), 4), "\n")
cat(" SD: ", round(sd(boot_auuc_vals), 4), "\n")
cat(" 95% CI (percentile): [",
round(quantile(boot_auuc_vals, 0.025), 4), ",",
round(quantile(boot_auuc_vals, 0.975), 4), "]\n")
# GRF provides honest confidence intervals via IJ variance
ci_df <- data.frame(
i = 1:n,
tau_hat = tau_cf,
tau_se = tau_se,
tau_lower = tau_cf - 1.96 * tau_se,
tau_upper = tau_cf + 1.96 * tau_se,
tau_true = tau_true
)
# Coverage check
coverage <- mean(ci_df$tau_true >= ci_df$tau_lower &
ci_df$tau_true <= ci_df$tau_upper)
cat("Empirical coverage of 95% CIs:", round(coverage, 3), "\n")
cat("(Nominal: 0.95)\n")
# Plot CIs for a subset of units, sorted by predicted CATE
sub <- ci_df[order(ci_df$tau_hat)[seq(1, n, by = n %/% 100)], ]
ggplot(sub, aes(x = seq_along(tau_hat), y = tau_hat)) +
geom_ribbon(aes(ymin = tau_lower, ymax = tau_upper), alpha = 0.2, fill = "#4393C3") +
geom_line(color = "#2166AC", linewidth = 0.5) +
geom_point(aes(y = tau_true), color = "#D73027", size = 0.8, alpha = 0.6) +
labs(
x = "Unit (sorted by estimated CATE)",
y = "CATE",
title = "Causal Forest: Pointwise 95% Confidence Intervals",
subtitle = sprintf("Coverage: %.1f%% | Red: true CATE | Blue band: 95%% CI",
100 * coverage)
) +
  theme_minimal(base_size = 12)

Bootstrapping the full pipeline (nuisance estimation + CATE estimation) provides valid uncertainty quantification without relying on asymptotic theory:
set.seed(42)
B_pipe <- 100
ate_boots <- numeric(B_pipe)
for (b in seq_len(B_pipe)) {
idx <- sample(n, replace = TRUE)
cf_b <- causal_forest(X[idx, ], Y[idx], W[idx],
num.trees = 500, seed = b)
ate_b <- average_treatment_effect(cf_b)
ate_boots[b] <- ate_b["estimate"]
}
cat("Bootstrap ATE distribution (full pipeline):\n")
cat(" Mean:", round(mean(ate_boots), 4), "\n")
cat(" SD: ", round(sd(ate_boots), 4), "\n")
cat(" 95% CI (percentile): [",
round(quantile(ate_boots, 0.025), 4), ",",
round(quantile(ate_boots, 0.975), 4), "]\n")
cat(" True ATE:", round(mean(tau_true), 4), "\n")

When subgroups are selected based on estimated CATE and then tested, naive p-values are invalid - this is the winner's curse or data snooping problem.

Solutions:

1. Sample splitting: Use one half for subgroup discovery, the other for inference.
2. Honest estimation: GRF uses honesty by construction.
3. Simultaneous confidence bands: Bonferroni or bootstrap-based corrections.
4. Pre-specification: Define subgroups before seeing outcome data.
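Simultaneous confidence bands (solution 3) can be sketched by widening the pointwise critical value to a Bonferroni-adjusted one. The `tau_hat` and `tau_se` vectors below are simulated stand-ins for the forest's predictions and their standard errors (in the document's workflow, `predict(cf)$predictions` and the square root of `predict(cf, estimate.variance = TRUE)$variance.estimates`):

```r
# Pointwise vs simultaneous (Bonferroni) 95% bands for CATE estimates
set.seed(1)
n_b <- 200
tau_hat <- rnorm(n_b, mean = 0.5, sd = 0.3)  # stand-in CATE estimates
tau_se  <- runif(n_b, 0.05, 0.15)            # stand-in standard errors
z_point <- qnorm(0.975)                      # pointwise critical value (1.96)
z_simul <- qnorm(1 - 0.025 / n_b)            # Bonferroni critical value
bands <- data.frame(
  lo_point = tau_hat - z_point * tau_se,
  hi_point = tau_hat + z_point * tau_se,
  lo_simul = tau_hat - z_simul * tau_se,     # wider bands, joint coverage
  hi_simul = tau_hat + z_simul * tau_se
)
cat("Pointwise z:", round(z_point, 3),
    "| Simultaneous z:", round(z_simul, 3), "\n")
```

The simultaneous bands are wider, buying coverage across all units at once at the price of power.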
# Illustration: naive vs. split-sample inference
set.seed(42)
n_half <- n %/% 2
idx_disc <- 1:n_half # discovery sample
idx_valid <- (n_half+1):n # validation sample
# "Discover" best subgroup on discovery sample
X_disc <- X[idx_disc, ]
tau_disc <- tau_cf[idx_disc]
best_var <- which.max(apply(X_disc, 2, function(xv) {
abs(cor(xv, tau_disc))
}))
split_d <- median(X_disc[, best_var])
cat("Best subgroup variable (discovery):", colnames(X)[best_var], "\n")
cat("Split value:", round(split_d, 3), "\n")
# Naive estimate (same sample as discovery)
in_disc_hi <- X_disc[, best_var] > split_d
ate_naive <- mean(tau_disc[in_disc_hi]) - mean(tau_disc[!in_disc_hi])
# Honest estimate (validation sample)
X_val <- X[idx_valid, ]
tau_val <- tau_cf[idx_valid]
in_val_hi <- X_val[, best_var] > split_d
ate_honest <- mean(tau_val[in_val_hi]) - mean(tau_val[!in_val_hi])
cat("\nNaive CATE diff (discovery sample):", round(ate_naive, 4), "\n")
cat("Honest CATE diff (validation sample):", round(ate_honest, 4), "\n")
cat("True CATE diff (full sample):",
round(mean(tau_true[X[, best_var] > split_d]) -
    mean(tau_true[X[, best_var] <= split_d]), 4), "\n")

A complete HTE analysis for publication should include:
# Publication-quality subgroup table
subgroup_vars <- list(
"X1 (High)" = X[, 1] > 0,
"X1 (Low)" = X[, 1] <= 0,
"X2 (High)" = X[, 2] > 0,
"X2 (Low)" = X[, 2] <= 0,
"Both High" = X[, 1] > 0 & X[, 2] > 0,
"Both Low" = X[, 1] <= 0 & X[, 2] <= 0,
"X1>0, X2<=0" = X[, 1] > 0 & X[, 2] <= 0,
"X1<=0, X2>0" = X[, 1] <= 0 & X[, 2] > 0,
"All" = rep(TRUE, n)
)
sg_table <- lapply(names(subgroup_vars), function(sg_name) {
idx <- which(subgroup_vars[[sg_name]])
n_g <- length(idx)
# ATE via GRF predictions (honest)
ate_g <- mean(tau_cf[idx])
se_g <- sd(tau_cf[idx]) / sqrt(n_g)
# True ATE
true_g <- mean(tau_true[idx])
data.frame(
Subgroup = sg_name,
N = n_g,
CATE_Est = round(ate_g, 3),
SE = round(se_g, 3),
CI_lo = round(ate_g - 1.96 * se_g, 3),
CI_hi = round(ate_g + 1.96 * se_g, 3),
p_val = round(2 * pnorm(-abs(ate_g / se_g)), 4),
True_CATE = round(true_g, 3),
stringsAsFactors = FALSE
)
})
sg_df <- do.call(rbind, sg_table)
sg_df$Sig <- ifelse(sg_df$p_val < 0.001, "***",
ifelse(sg_df$p_val < 0.01, "**",
ifelse(sg_df$p_val < 0.05, "*", "")))
print(sg_df)
sg_plot_df <- sg_df[sg_df$Subgroup != "All", ]
overall_ate <- sg_df$CATE_Est[sg_df$Subgroup == "All"]
sg_plot_df$Subgroup <- factor(sg_plot_df$Subgroup,
levels = sg_plot_df$Subgroup[order(sg_plot_df$CATE_Est)])
ggplot(sg_plot_df, aes(x = CATE_Est, y = Subgroup)) +
geom_vline(xintercept = overall_ate, linetype = "dashed",
color = "#D73027", linewidth = 0.8) +
geom_vline(xintercept = 0, color = "gray70") +
geom_errorbar(aes(xmin = CI_lo, xmax = CI_hi),
width = 0.3, color = "#2166AC", linewidth = 0.7, orientation = "y") +
geom_point(aes(size = N), color = "#2166AC", shape = 18) +
geom_text(aes(label = sprintf("%.3f%s", CATE_Est, Sig)),
hjust = -0.3, size = 3) +
scale_size_continuous(range = c(2, 6), guide = "none") +
labs(
x = "Estimated CATE (95% CI)",
y = NULL,
title = "Subgroup Treatment Effects",
subtitle = sprintf("Red dashed = Overall ATE (%.3f) | *** p<.001, ** p<.01, * p<.05",
overall_ate),
caption = "Estimates from GRF causal forest (out-of-bag predictions)"
) +
theme_minimal(base_size = 12) +
theme(
panel.grid.minor = element_blank(),
axis.text.y = element_text(face = "bold")
)
# GRF variable importance
vi_df <- data.frame(
Feature = colnames(X),
Importance = as.numeric(variable_importance(cf))
)
vi_df <- vi_df[order(vi_df$Importance, decreasing = TRUE), ]
vi_df$Feature <- factor(vi_df$Feature, levels = rev(vi_df$Feature))
vi_df$IsTrue <- vi_df$Feature == "X1"
ggplot(vi_df, aes(x = Feature, y = Importance, fill = IsTrue)) +
geom_col(alpha = 0.85, show.legend = FALSE) +
scale_fill_manual(values = c("FALSE" = "#4393C3", "TRUE" = "#D73027")) +
coord_flip() +
labs(
x = NULL,
y = "Variable Importance (GRF)",
title = "Feature Importance for Treatment Effect Heterogeneity",
subtitle = "Red = true driver of heterogeneity (X1)"
) +
theme_minimal(base_size = 12)
p1 <- ggplot(data.frame(tau = tau_cf), aes(x = tau)) +
geom_histogram(bins = 50, fill = "#4393C3", alpha = 0.7, color = "white") +
geom_vline(xintercept = mean(tau_cf), color = "#D73027",
linewidth = 1.2, linetype = "solid") +
geom_vline(xintercept = 0, linetype = "dashed") +
labs(x = "Estimated CATE", y = "Count",
title = "CATE Distribution") +
theme_minimal(base_size = 11)
p2 <- ggplot(data.frame(tau_cf = tau_cf, X1 = X[, 1]),
aes(x = X1, y = tau_cf)) +
geom_point(alpha = 0.2, size = 0.8, color = "#2166AC") +
geom_smooth(method = "lm", color = "#D73027", se = TRUE) +
labs(x = "X1 (top moderator)", y = "Estimated CATE",
title = "CATE vs. Top Moderator") +
theme_minimal(base_size = 11)
if (requireNamespace("patchwork", quietly = TRUE)) {
library(patchwork)
p1 + p2
} else {
print(p1)
print(p2)
}

A deterministic targeting rule assigns treatment based on covariates:

$$\pi(x) = \mathbb{1}\{\hat{\tau}(x) > c\}$$

where $c$ is a cost threshold (e.g., the cost of treatment). The policy value is:

$$V(\pi) = E\big[\mu_0(X) + \pi(X)\,\tau(X)\big]$$
# Compute policy value for different cost thresholds
costs <- seq(-1.5, 1.5, by = 0.25)
pv_tab <- lapply(costs, function(c_val) {
# Policy: treat if CATE > c
pi_i <- as.integer(tau_cf >= c_val)
pct_treat <- mean(pi_i)
# Policy value via DR scores
  # Policy value: V(pi) = E[Y(pi)] = E[mu0(X) + pi(X) * tau(X)]
pv_oracle <- mean(mu0_hat_dr + pi_i * tau_true)
pv_est <- mean(mu0_hat_dr + pi_i * tau_cf)
data.frame(
Cost = c_val,
Pct_Treat = round(pct_treat, 3),
PV_Oracle = round(pv_oracle, 4),
PV_Estimated = round(pv_est, 4)
)
})
pv_df <- do.call(rbind, pv_tab)
print(pv_df)
# Find optimal cost threshold
optimal_c <- pv_df$Cost[which.max(pv_df$PV_Oracle)]
ggplot(pv_df, aes(x = Cost)) +
geom_line(aes(y = PV_Oracle, linetype = "Oracle"), color = "#D73027",
linewidth = 1.1) +
geom_line(aes(y = PV_Estimated, linetype = "Estimated"), color = "#2166AC",
linewidth = 1.1) +
geom_vline(xintercept = optimal_c, linetype = "dashed", color = "gray40") +
scale_linetype_manual(values = c("Oracle" = "solid", "Estimated" = "dashed")) +
labs(
x = "Cost Threshold c (treat if CATE > c)",
y = "Policy Value",
title = "Policy Value vs. Cost Threshold",
subtitle = sprintf("Optimal threshold (oracle) = %.2f", optimal_c),
linetype = NULL
) +
theme_minimal(base_size = 12) +
theme(legend.position = "bottom")
# Welfare gain of CATE-based targeting vs. uniform policies
# No treatment baseline
pv_no_treat <- 0 # baseline: treating nobody yields zero treatment gain
pv_all_treat <- mean(tau_true) # E[Y(1) - Y(0)] = ATE
# Optimal targeting
pi_opt <- as.integer(tau_cf >= 0)
pv_target <- mean(pi_opt * tau_true)
cat("Welfare analysis (relative to baseline):\n")
cat(" Treat nobody: ATE gain = 0\n")
cat(" Treat everyone: ATE gain =", round(pv_all_treat, 4), "\n")
cat(" Optimal targeting (CATE > 0): gain =", round(pv_target, 4), "\n")
cat(" Gain from targeting vs. treating all:",
round(pv_target - pv_all_treat, 4), "\n")
cat(" Fraction treated under optimal rule:", round(mean(pi_opt), 3), "\n")

A stochastic policy assigns treatment probabilistically: $\pi(x) = P(W = 1 \mid X = x) \in [0, 1]$. This generalizes the binary rule and allows smooth interpolation:
# Stochastic policy: probability = sigmoid of CATE
softmax_policy <- function(tau, temperature = 1) {
1 / (1 + exp(-tau / temperature))
}
# Different temperatures (higher = more uniform)
temps <- c(0.1, 0.5, 1.0, 2.0, Inf)
# Use `temp`, not `T`, to avoid masking R's shorthand for TRUE
sp_df <- lapply(temps, function(temp) {
  pi_soft <- if (is.infinite(temp)) rep(0.5, n) else softmax_policy(tau_cf, temp)
  pv_soft <- mean(pi_soft * tau_true)
  data.frame(
    Temperature = temp,
    Mean_Pi = round(mean(pi_soft), 3),
    Policy_Value = round(pv_soft, 4)
  )
})
print(do.call(rbind, sp_df))
# How do CATE estimates change with model specification?
set.seed(42)
specs <- list(
"Default (2000 trees)" = list(num.trees = 2000, min.node.size = NULL),
"Small forest (500)" = list(num.trees = 500, min.node.size = NULL),
"Large nodes" = list(num.trees = 1000, min.node.size = 50),
"Small nodes" = list(num.trees = 1000, min.node.size = 5)
)
sens_df <- lapply(names(specs), function(spec_name) {
args <- specs[[spec_name]]
cf_s <- do.call(causal_forest,
c(list(X = X, Y = Y, W = W), args[!sapply(args, is.null)]))
tau_s_hat <- predict(cf_s)$predictions
ate_s <- average_treatment_effect(cf_s)
data.frame(
Specification = spec_name,
ATE = round(ate_s["estimate"], 4),
ATE_SE = round(ate_s["std.err"], 4),
CATE_SD = round(sd(tau_s_hat), 4),
RMSE_oracle = round(sqrt(mean((tau_s_hat - tau_true)^2)), 4),
Corr_oracle = round(cor(tau_s_hat, tau_true), 4)
)
})
print(do.call(rbind, sens_df))
# What happens when propensity scores are extreme?
set.seed(42)
n_ov <- 1000
X_ov <- matrix(rnorm(n_ov * 5), n_ov, 5)
# Treatment strongly predicted by X1 (poor overlap)
ps_ov <- plogis(2 * X_ov[, 1])
W_ov <- rbinom(n_ov, 1, ps_ov)
tau_ov <- X_ov[, 1]
Y_ov <- tau_ov * W_ov + rnorm(n_ov)
cat("Overlap summary:\n")
cat(" Mean PS:", round(mean(ps_ov), 3), "\n")
cat(" PS < 0.1:", sum(ps_ov < 0.1), "| PS > 0.9:", sum(ps_ov > 0.9), "\n")
cat(" Effective n after 0.1-0.9 trimming:",
sum(ps_ov >= 0.1 & ps_ov <= 0.9), "\n")
# AIPW with trimming vs without
trim_idx <- ps_ov >= 0.05 & ps_ov <= 0.95
cat(" Fraction retained:", round(mean(trim_idx), 3), "\n")

| Scenario | Recommended Method | Reason |
|---|---|---|
| Large N, high p, no strong priors | Causal Forest (GRF) | Nonparametric, honest, CI available |
| Balanced arms, moderate N | T-Learner (XGBoost) | Simple, flexible |
| Imbalanced arms | X-Learner | Designed for imbalance |
| Linear heterogeneity hypothesized | DML + interaction | Interpretable |
| Semiparametric efficiency desired | DR-Learner | Optimal convergence rate |
| Publication with pre-specified model | Parametric interaction | Transparent, interpretable |
Before reporting HTE results:
When writing up HTE results:
causalverse Vignettes
When the treatment is continuous (dose, intensity) rather than binary, the estimand generalizes to the Dose-Response Function (DRF), $\mu(w, x) = E[Y(w) \mid X = x]$, or the Average Derivative Effect:

$$\theta(x) = E\!\left[ \frac{\partial Y(w)}{\partial w} \,\Big|\, X = x \right]$$
The grf package provides a causal_forest variant for continuous treatments by setting the treatment as a numeric variable.
# Continuous treatment simulation
set.seed(42)
n_cont <- 1500
p_cont <- 8
X_cont <- matrix(rnorm(n_cont * p_cont), n_cont, p_cont)
# Continuous dose: uniform on [0, 2]
W_cont <- runif(n_cont, 0, 2)
# Nonlinear dose-response: quadratic with heterogeneity
tau_cont <- X_cont[, 1] * W_cont - 0.5 * W_cont^2
Y_cont <- tau_cont + 0.3 * X_cont[, 2] + rnorm(n_cont, sd = 0.5)
# Causal forest with continuous treatment
cf_cont <- causal_forest(
X = X_cont, Y = Y_cont, W = W_cont,
num.trees = 1000, seed = 42
)
ate_cont <- average_treatment_effect(cf_cont)
cat("Continuous treatment ATE (avg marginal effect):\n")
cat(" Estimate:", round(ate_cont["estimate"], 4),
    "| SE:", round(ate_cont["std.err"], 4), "\n")

When treatment is endogenous, IV methods identify the Local Average Treatment Effect (LATE) - the CATE for compliers. With an instrument and covariates, an IV causal forest can estimate heterogeneous LATEs:
# IV simulation: instrument Z induces compliance
set.seed(42)
n_iv <- 1500
X_iv <- matrix(rnorm(n_iv * 6), n_iv, 6)
Z_iv <- rbinom(n_iv, 1, 0.5) # instrument (randomized)
# Compliance: P(W=1|Z=1,X) > P(W=1|Z=0,X)
p_comply <- plogis(0.5 * X_iv[, 1] + 0.5)
W_iv <- rbinom(n_iv, 1, Z_iv * p_comply + (1 - Z_iv) * p_comply * 0.2)
# Heterogeneous LATE: driven by X1
tau_iv <- pmax(X_iv[, 1], 0)
Y_iv <- tau_iv * W_iv + 0.4 * X_iv[, 2] + rnorm(n_iv)
# IV forest
iv_cf <- instrumental_forest(
X = X_iv, Y = Y_iv, W = W_iv, Z = Z_iv,
num.trees = 1000, seed = 42
)
late_est <- average_treatment_effect(iv_cf)
cat("IV Forest LATE estimate:\n")
cat(" Estimate:", round(late_est["estimate"], 4),
"| SE:", round(late_est["std.err"], 4), "\n")
# Heterogeneous LATE predictions
tau_iv_hat <- predict(iv_cf)$predictions
cat(" LATE heterogeneity (SD):", round(sd(tau_iv_hat), 4), "\n")
cat(" Compliance rate:", round(mean(W_iv[Z_iv == 1]) - mean(W_iv[Z_iv == 0]), 4), "\n")

In panel data settings with staggered adoption, the CATE generalizes to group-time ATTs. The did package (Callaway & Sant'Anna 2021) estimates $ATT(g, t)$, which can be aggregated to obtain heterogeneous effects by subgroup:
# DiD with heterogeneous treatment timing
set.seed(42)
n_units <- 500
n_times <- 6
# Staggered adoption: some treated in t=3, others in t=5
unit_fe <- rnorm(n_units)
time_fe <- seq(0, 1, length.out = n_times)
treated <- rbinom(n_units, 1, 0.6)
timing <- ifelse(treated == 1, sample(c(3, 5), n_units, replace = TRUE), Inf)
# Heterogeneous treatment effect: larger for units with unit_fe > 0
tau_het <- ifelse(unit_fe > 0, 2.0, 0.5) * treated
panel_df <- data.frame(
unit = rep(1:n_units, each = n_times),
time = rep(1:n_times, times = n_units),
unit_fe = rep(unit_fe, each = n_times),
timing = rep(timing, each = n_times),
tau = rep(tau_het, each = n_times)
)
panel_df$treated_period <- panel_df$time >= panel_df$timing
panel_df$Y <- panel_df$unit_fe +
rep(time_fe, times = n_units) +
panel_df$treated_period * panel_df$tau +
rnorm(n_units * n_times, sd = 0.5)
# Simple DiD: high vs low unit_fe subgroups
panel_df$group_label <- ifelse(panel_df$unit_fe > 0, "High FE", "Low FE")
# Subgroup-specific DiD ATT
did_results <- lapply(c("High FE", "Low FE"), function(grp) {
sub <- panel_df[panel_df$group_label == grp, ]
post <- sub$time >= 3 & sub$treated_period
control <- sub$time >= 3 & !sub$treated_period
diff <- mean(sub$Y[post]) - mean(sub$Y[!post]) -
(mean(sub$Y[sub$time >= 3 & sub$timing > 6]) -
mean(sub$Y[sub$time < 3 & sub$timing > 6]))
data.frame(Group = grp,
Naive_Diff = round(diff, 4),
True_ATT = round(mean(sub$tau[sub$treated_period]), 4))
})
cat("DiD Subgroup ATTs:\n")
print(do.call(rbind, did_results))

The following shows the complete workflow using causalverse functions:
# Step 1: Data setup (already done: X, Y, W, tau_true)
# Step 2: Fit causal forest
set.seed(42)
cf_final <- causal_forest(X, Y, W, num.trees = 2000, tune.parameters = "all")
# Step 3: Comprehensive summary
pipeline_summary <- causal_forest_summary(
cf = cf_final,
X = X,
feature_names = colnames(X),
n_groups = 4,
top_features = 10
)
# Step 4: Print ATE
cat("=== PIPELINE RESULTS ===\n")
cat("\n--- ATE ---\n")
print(pipeline_summary$ate)
cat("\n--- CATE Summary ---\n")
print(round(pipeline_summary$cate_summary, 4))
cat("\n--- BLP Test ---\n")
print(pipeline_summary$blp$blp)
# Step 5: Visualize
pipeline_summary$plot_cate_dist
pipeline_summary$plot_gates
# Step 6: Policy evaluation
tau_final <- predict(cf_final)$predictions
qini_final <- qini_curve(
cate_hat = tau_final,
Y = Y,
W = W,
boot_reps = 200
)
cat("\n--- Qini / AUUC ---\n")
cat("AUUC:", round(qini_final$auuc, 4), "\n")
cat("95% CI:", round(qini_final$auuc_ci[1], 4), "-",
    round(qini_final$auuc_ci[2], 4), "\n")

A common workflow combines GRF's cross-fitted nuisances with the AIPW estimator for the final ATE, and uses the causal forest CATE for targeting:
# Extract GRF's cross-fitted nuisances
e_grf <- cf_final$W.hat # propensity score
m_grf <- cf_final$Y.hat # E[Y|X]
tau_grf <- predict(cf_final)$predictions
# AIPW ATE using GRF nuisances
# GRF's Y.hat is the marginal mean m(X) = e(X) mu1(X) + (1 - e(X)) mu0(X),
# so recover the arm-specific means as:
Y_cf1 <- m_grf + (1 - e_grf) * tau_grf # mu1(X) = m(X) + (1 - e(X)) tau(X)
Y_cf0 <- m_grf - e_grf * tau_grf # mu0(X) = m(X) - e(X) tau(X)
psi_grf <- Y_cf1 - Y_cf0 +
W * (Y - Y_cf1) / e_grf -
(1 - W) * (Y - Y_cf0) / (1 - e_grf)
aipw_grf_ate <- mean(psi_grf)
aipw_grf_se <- sd(psi_grf) / sqrt(n)
cat("AIPW ATE (GRF nuisances):\n")
cat(" Estimate:", round(aipw_grf_ate, 4), "\n")
cat(" SE: ", round(aipw_grf_se, 4), "\n")
cat(" 95% CI: [",
round(aipw_grf_ate - 1.96 * aipw_grf_se, 4), ",",
round(aipw_grf_ate + 1.96 * aipw_grf_se, 4), "]\n")
cat(" True ATE:", round(mean(tau_true), 4), "\n")

Applying causalverse::dr_ate() to real data (the Lalonde job-training sample):
if (requireNamespace("MatchIt", quietly = TRUE)) {
data(lalonde, package = "MatchIt")
# ATE
dr_res_ate <- dr_ate(
data = lalonde,
outcome = "re78",
treatment = "treat",
covariates = c("age", "educ", "married", "nodegree", "re74", "re75"),
estimand = "ATE",
boot_se = FALSE
)
# ATT
dr_res_att <- dr_ate(
data = lalonde,
outcome = "re78",
treatment = "treat",
covariates = c("age", "educ", "married", "nodegree", "re74", "re75"),
estimand = "ATT"
)
cat("Lalonde AIPW Results:\n")
cat(sprintf(" ATE: %.0f (SE = %.0f, p = %.4f)\n",
dr_res_ate$estimate, dr_res_ate$se, dr_res_ate$p_value))
cat(sprintf(" ATT: %.0f (SE = %.0f, p = %.4f)\n",
dr_res_att$estimate, dr_res_att$se, dr_res_att$p_value))
cat(sprintf(" Propensity score range: [%.3f, %.3f]\n",
dr_res_ate$ps_summary["Min."],
dr_res_ate$ps_summary["Max."]))
}
data(lalonde, package = "MatchIt")
# Prepare matrix
cov_names <- c("age", "educ", "married", "nodegree", "re74", "re75")
X_la <- model.matrix(~ age + educ + married + nodegree + re74 + re75 - 1,
data = lalonde)
Y_la <- lalonde$re78
W_la <- lalonde$treat
set.seed(42)
cf_la <- causal_forest(X_la, Y_la, W_la,
num.trees = 2000, tune.parameters = "all")
tau_la <- predict(cf_la)$predictions
ate_la <- average_treatment_effect(cf_la)
cat("Lalonde CATE Analysis:\n")
cat(" ATE:", round(ate_la["estimate"], 1),
"| SE:", round(ate_la["std.err"], 1), "\n")
cat(" CATE range: [", round(min(tau_la), 0), ",",
round(max(tau_la), 0), "]\n")
cat(" Fraction with CATE > 0:", round(mean(tau_la > 0), 3), "\n")
# Subgroup analysis for Lalonde
la_df <- data.frame(X_la, cate = tau_la, W = W_la, Y = Y_la)
la_df$educ_grp <- ifelse(lalonde$educ >= 12, "HS Graduate", "No HS Diploma")
la_df$married_grp <- ifelse(lalonde$married == 1, "Married", "Single")
la_df$re74_grp <- ifelse(lalonde$re74 > 0, "Employed in 1974", "Not Employed")
la_df$nodegree_grp <- ifelse(lalonde$nodegree == 1, "No Degree", "Has Degree")
la_df$age_grp <- cut(lalonde$age, breaks = c(16, 24, 30, 55),
labels = c("Young (17-24)", "Middle (25-30)", "Older (31+)"))
la_df$Subgroup <- paste(la_df$educ_grp)
# Use heterogeneity_plot for forest plot
p_la_edu <- heterogeneity_plot(
data = la_df,
cate_var = "cate",
subgroup_var = "educ_grp",
overall_effect = ate_la["estimate"],
plot_type = "forest",
title = "Lalonde: HTE by Education"
)
p_la_re74 <- heterogeneity_plot(
data = la_df,
cate_var = "cate",
subgroup_var = "re74_grp",
overall_effect = ate_la["estimate"],
plot_type = "forest",
title = "Lalonde: HTE by 1974 Employment"
)
if (requireNamespace("patchwork", quietly = TRUE)) {
library(patchwork)
p_la_edu + p_la_re74
} else {
print(p_la_edu)
print(p_la_re74)
}
# Qini for Lalonde
qini_la <- qini_curve(tau_la, Y_la, W_la, boot_reps = 200)
cat("Lalonde AUUC:", round(qini_la$auuc, 4), "\n")
qini_la$plot + labs(title = "Lalonde: Qini Curve for Job Training Targeting")

If treatment of one unit affects another (spillovers), SUTVA is violated and CATE estimates may be biased.
Remedies: Cluster-randomized designs, partial interference models, network causal inference.
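A minimal simulation illustrates the problem. The DGP here is hypothetical (not from the document): control units also benefit from treated peers in their group, so the naive treated-vs-control difference understates the direct effect:

```r
# Interference sketch: outcomes depend on own treatment AND the group's
# treated share, so the naive difference in means is biased.
set.seed(1)
n_g <- 50; g_size <- 20
grp <- rep(1:n_g, each = g_size)
W_s <- rbinom(n_g * g_size, 1, 0.5)
treated_share <- ave(W_s, grp)  # fraction treated within own group
tau_direct <- 1.0               # direct effect of own treatment
spill <- 0.8                    # spillover onto untreated peers
Y_s <- tau_direct * W_s +
  spill * treated_share * (1 - W_s) +  # controls gain from treated peers
  rnorm(n_g * g_size)
naive_ate <- mean(Y_s[W_s == 1]) - mean(Y_s[W_s == 0])
cat("Direct effect:", tau_direct, "| Naive ATE:", round(naive_ate, 3), "\n")
```

Because spillovers raise the control group's mean, the naive estimate is well below the true direct effect - cluster-level randomization would restore a clean contrast.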
Never condition on variables that are affected by treatment (colliders or mediators). Including post-treatment variables in $X$ can introduce bias:
# Illustration: including a mediator biases CATE
set.seed(42)
n_col <- 1000
X_col <- rnorm(n_col)
W_col <- rbinom(n_col, 1, 0.5)
tau_col <- X_col # true CATE = X
M_col <- 0.5 * W_col + 0.2 * X_col + rnorm(n_col, sd = 0.3) # mediator
Y_col <- tau_col * W_col + M_col + rnorm(n_col, sd = 0.5)
# Correct model: no mediator
lm_correct <- lm(Y_col ~ W_col + X_col)
cat("Correct (no mediator) W coef:", round(coef(lm_correct)["W_col"], 4), "\n")
# Incorrect model: include mediator (post-treatment)
lm_biased <- lm(Y_col ~ W_col + X_col + M_col)
cat("Biased (mediator included) W coef:", round(coef(lm_biased)["W_col"], 4), "\n")
cat("True ATE:", round(mean(tau_col), 4), "\n")
cat("Warning: including post-treatment variable M biases ATE toward 0!\n")

ML models can extrapolate poorly in regions with low support. Always check:
# Overlap plot: propensity score by treatment arm
overlap_df <- data.frame(
ps = cf_final$W.hat,
treat = factor(W, levels = c(0, 1), labels = c("Control", "Treated"))
)
ggplot(overlap_df, aes(x = ps, fill = treat)) +
geom_histogram(bins = 40, alpha = 0.6, position = "identity") +
scale_fill_manual(values = c("Control" = "#4393C3", "Treated" = "#D73027")) +
geom_vline(xintercept = c(0.1, 0.9), linetype = "dashed") +
labs(
x = "Estimated Propensity Score",
y = "Count",
fill = NULL,
title = "Propensity Score Overlap",
subtitle = "Dashed lines at 0.1 and 0.9 (overlap region)"
) +
theme_minimal(base_size = 12) +
  theme(legend.position = "bottom")

The choice of nuisance models (for $m(X) = E[Y \mid X]$ and $e(X) = P(W = 1 \mid X)$) affects CATE estimation:
Best practice: Use cross-fitting with a flexible but regularized learner (lasso, random forest, gradient boosting). Cross-validate the nuisance models separately.
# Cross-validate nuisance model performance
set.seed(42)
cv_m_final <- cv.glmnet(X, Y, alpha = 1, nfolds = 10)
cv_e_final <- cv.glmnet(X, W, family = "binomial", alpha = 1, nfolds = 10)
cat("Nuisance model cross-validation:\n")
cat(" Outcome model (lasso) R^2:",
round(1 - min(cv_m_final$cvm) / var(Y), 4), "\n")
cat(" Propensity model AUC: ~0.5 (random assignment)\n")
cat(" Optimal lambda (outcome):", round(cv_m_final$lambda.min, 6), "\n")
cat(" Optimal lambda (propensity):", round(cv_e_final$lambda.min, 6), "\n")

The same data can show heterogeneity on the absolute scale (CATE = $E[Y(1) - Y(0) \mid X]$) but not on the relative scale (Relative Risk = $E[Y(1) \mid X] / E[Y(0) \mid X]$), or vice versa. Always specify which scale you are studying.
For binary outcomes, both scales are relevant. For continuous outcomes, the additive scale is most common.
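A toy binary-outcome calculation (hypothetical risks, not from the document's simulation) makes the distinction concrete: two subgroups with a constant relative risk but different baseline risks have very different risk differences:

```r
# Constant relative risk, heterogeneous risk difference
p0 <- c(low_risk = 0.05, high_risk = 0.40) # baseline risk P(Y(0) = 1 | X)
rr <- 1.5                                  # relative risk, same in both groups
p1 <- rr * p0                              # treated risk P(Y(1) = 1 | X)
rd <- p1 - p0                              # risk difference (additive CATE)
print(rbind(p0 = p0, p1 = p1, RD = rd, RR = p1 / p0))
```

Here the relative scale shows no heterogeneity (RR = 1.5 in both groups) while the additive scale shows an eightfold difference (0.025 vs 0.20), so a targeting rule built on one scale can look very different on the other.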
Athey, S., & Imbens, G. W. (2016). Recursive partitioning for heterogeneous causal effects. Proceedings of the National Academy of Sciences, 113(27), 7353–7360.
Athey, S., Tibshirani, J., & Wager, S. (2019). Generalized random forests. Annals of Statistics, 47(2), 1148–1178.
Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C., Newey, W., & Robins, J. (2018). Double/debiased machine learning for treatment and structural parameters. Econometrics Journal, 21(1), C1–C68.
Chernozhukov, V., Demirer, M., Duflo, E., & Fernandez-Val, I. (2020). Generic machine learning inference on heterogeneous treatment effects in randomized experiments. NBER Working Paper 24678.
Foster, J. C., Taylor, J. M., & Ruberg, S. J. (2011). Subgroup identification from randomized clinical trial data. Statistics in Medicine, 30(24), 2867–2880.
Kennedy, E. H. (2020). Optimal doubly robust estimation of heterogeneous causal effects. arXiv preprint arXiv:2004.14497.
Künzel, S. R., Sekhon, J. S., Bickel, P. J., & Yu, B. (2019). Metalearners for estimating heterogeneous treatment effects using machine learning. Proceedings of the National Academy of Sciences, 116(10), 4156–4165.
Lalonde, R. J. (1986). Evaluating the econometric evaluations of training programs with experimental data. American Economic Review, 76(4), 604–620.
Nie, X., & Wager, S. (2021). Quasi-oracle estimation of heterogeneous treatment effects. Biometrika, 108(2), 299–319.
Radcliffe, N. J. (2007). Using control groups to target on predicted lift. Direct Marketing Analytics Journal, 1(3), 14–21.
Robinson, P. M. (1988). Root-N-consistent semiparametric regression. Econometrica, 56(4), 931–954.
Wager, S., & Athey, S. (2018). Estimation and inference of heterogeneous treatment effects using random forests. Journal of the American Statistical Association, 113(523), 1228–1242.