This tutorial demonstrates g-methods, comparing traditional parametric approaches with machine-learning alternatives. We'll show when flexible models like random forests provide advantages over linear models in causal inference, particularly when relationships are highly non-linear.
We examine the effect of HIV treatment (\(A\)) on CD4 count (\(Y\)) using two datasets to demonstrate g-formula performance: a simple cross-sectional dataset with linear relationships, and a complex dataset with strongly non-linear confounding and heterogeneous treatment effects.

Our variables are:

- \(A\): HIV treatment (1 = treated, 0 = untreated)
- \(Y\): CD4 count, the outcome
- \(Z\): baseline confounders (an elevated viral load indicator in the simple example; viral load \(Z_1\), age \(Z_2\), and comorbidity \(Z_3\) in the complex example)
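For a point treatment with confounder \(Z\), the g-formula standardizes the conditional outcome means over the confounder distribution (an integral replaces the sum for continuous confounders):

\[
E[Y^a] = \sum_z E[Y \mid A = a, Z = z]\,\Pr(Z = z), \qquad \text{ATE} = E[Y^1] - E[Y^0].
\]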
In the simple dataset, the true average treatment effect (ATE) is 50. First, let's reproduce the basic example with simple relationships:
# Simple tabular data from original tutorial
z <- c(0, 0, 1, 1)
a <- c(0, 1, 0, 1)
y <- c(100, 150, 80, 130)
n <- c(300, 200, 150, 350)
data_simple <- data.frame(z, a, y, n)
kable(data_simple,
      caption = "Table 1: Simple Cross-sectional Data",
      col.names = c("Z (Viral Load)", "A (Treatment)", "Y (CD4 Count)", "N (Sample Size)")) %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed"))
Z (Viral Load) | A (Treatment) | Y (CD4 Count) | N (Sample Size) |
---|---|---|---|
0 | 0 | 100 | 300 |
0 | 1 | 150 | 200 |
1 | 0 | 80 | 150 |
1 | 1 | 130 | 350 |
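Plugging the table values into the g-formula, with \(\Pr(Z = 0) = \Pr(Z = 1) = 500/1000 = 0.5\):

\[
\text{ATE} = (0.5 \cdot 150 + 0.5 \cdot 130) - (0.5 \cdot 100 + 0.5 \cdot 80) = 140 - 90 = 50.
\]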
# Calculate marginal distribution of Z
total_n <- sum(n)
z_marginal <- data_simple %>%
  group_by(z) %>%
  summarise(total = sum(n), prob = sum(n) / total_n, .groups = 'drop')

# Non-parametric G-formula: standardize each arm's cell means over P(Z)
treated_outcomes <- data_simple %>%
  filter(a == 1) %>%
  left_join(z_marginal %>% select(z, prob), by = "z") %>%
  mutate(weighted_y = y * prob)

untreated_outcomes <- data_simple %>%
  filter(a == 0) %>%
  left_join(z_marginal %>% select(z, prob), by = "z") %>%
  mutate(weighted_y = y * prob)

ate_simple_nonparam <- sum(treated_outcomes$weighted_y) - sum(untreated_outcomes$weighted_y)
# Parametric G-formula: expand the tabular data into individual rows
# (the expanded data is used by the Random Forest below)
data_expanded <- data_simple %>% slice(rep(row_number(), n))

# Fit the outcome model on the tabular data, weighted by cell size
y_model_simple <- lm(y ~ a + z, data = data_simple, weights = n)
# Simulate and predict
set.seed(123)
sim_n <- 100000
z_sim <- sample(z_marginal$z, size = sim_n, replace = TRUE, prob = z_marginal$prob)
sim_data <- data.frame(z = z_sim)
y_treated_simple <- predict(y_model_simple, newdata = sim_data %>% mutate(a = 1))
y_untreated_simple <- predict(y_model_simple, newdata = sim_data %>% mutate(a = 0))
ate_simple_param <- mean(y_treated_simple) - mean(y_untreated_simple)
# Random Forest
rf_model_simple <- randomForest(y ~ a + z, data = data_expanded, ntree = 500)
y_rf_treated_simple <- predict(rf_model_simple, newdata = sim_data %>% mutate(a = 1))
y_rf_untreated_simple <- predict(rf_model_simple, newdata = sim_data %>% mutate(a = 0))
ate_simple_rf <- mean(y_rf_treated_simple) - mean(y_rf_untreated_simple)
cat("Simple Data Results:\n")
## Simple Data Results:
## Non-parametric ATE: 50
## Linear model ATE: 50
## Random Forest ATE: 36.38
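Why does Random Forest miss here? With two binary covariates there are only four cells, and bootstrap resampling plus averaging over trees smooths the forest's predictions across them. A quick check (added here for illustration, not part of the original tutorial) compares the forest's fitted values with the exact cell means:

# Illustrative check: RF predictions vs. the known cell means
cell_grid <- expand.grid(a = c(0, 1), z = c(0, 1))
cell_grid$rf_pred <- predict(rf_model_simple, newdata = cell_grid)
cell_grid$cell_mean <- c(100, 150, 80, 130) # from Table 1
cell_grid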
Now let’s create data with strong non-linear relationships to demonstrate when Random Forest excels:
set.seed(42)
n_obs <- 10000 # Enough sample for ML methods

# Generate multiple confounders for richer interactions
z1 <- runif(n_obs, 0, 100) # Viral load
z2 <- runif(n_obs, 20, 70) # Age
z3 <- rbinom(n_obs, 1, 0.3) # Comorbidity (0/1)

# Complex non-linear treatment assignment with smooth relationships
treatment_logit <- -1.5 +
  # Smooth S-curve for viral load (sicker patients less likely to get treatment)
  3 * plogis((z1 - 50) / 20) - 1.5 +
  # Smooth inverse-U for age (middle-aged more likely to get treatment)
  2 * exp(-((z2 - 45) / 15)^2) - 0.5 +
  # Comorbidity effect that varies smoothly with other variables
  z3 * (0.5 + 0.02 * z1 - 0.01 * z2) +
  # Smooth interaction surfaces
  0.015 * z1 * (z2 - 45) / 25 + # Viral load × Age interaction
  0.3 * sin(z1 * pi / 50) * cos(z2 * pi / 40) + # Trigonometric interaction
  z3 * 0.2 * cos((z1 + z2) * pi / 80) # Three-way interaction

treatment_prob <- plogis(treatment_logit)
a <- rbinom(n_obs, 1, treatment_prob)
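# (Added diagnostic, not in the original:) check positivity, i.e. that every
# patient has a non-trivial probability of each treatment level.
summary(treatment_prob)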
# Complex outcome model with baseline health declining with viral load and age
baseline_outcome <- 200 - 1.5 * z1 + 0.01 * z1^2 + # Viral load effect
  50 * exp(-((z2 - 40) / 12)^2) + # Age effect (peak health ~40)
  z3 * (-30 - 0.3 * z1 + 0.5 * z2) # Comorbidity penalty

# Highly heterogeneous treatment effects - the key for demonstrating RF advantage.
# The true effect surface is wrapped in a function so later chunks (effect
# surface plot, simulation truth, slice plot) can reuse it instead of re-typing it.
true_te_fn <- function(z1, z2, z3) {
  # Base effect varies smoothly with viral load
  30 + 0.8 * z1 + 20 * tanh((z1 - 40) / 20) + # Smooth transition around z1 = 40
    # Age modifies effectiveness (peak around age 45)
    25 * exp(-((z2 - 45) / 18)^2) +
    # Comorbidity creates complex interactions
    z3 * (20 - 0.4 * z1 + 0.3 * z2) +
    # Smooth interactions that linear models struggle with
    0.008 * z1 * z2 + # Linear interaction
    0.15 * z1 * z3 + # z1 × z3 interaction
    # Non-linear patterns
    15 * sin(z1 * pi / 60) * (1 + 0.5 * z3) + # Sine modulation
    10 * cos(z2 * pi / 50) * exp(-z1 / 100) + # Decaying cosine
    # Regional "sweet spots" where treatment works exceptionally well
    20 * exp(-((z1 - 30)^2 + (z2 - 40)^2) / 400) + # Gaussian sweet spot
    15 * exp(-((z1 - 70)^2 + (z2 - 50)^2) / 500) * z3 # Another for comorbid patients
}
treatment_effect <- true_te_fn(z1, z2, z3)

# Final outcome with moderate noise
y <- baseline_outcome + a * treatment_effect + rnorm(n_obs, 0, 25)

# Create dataset
data_complex <- data.frame(
  z1 = z1, z2 = z2, z3 = z3,
  a = a, y = y,
  true_te = treatment_effect
)
# Summary statistics
summary_stats <- data_complex %>%
  summarise(
    n = n(),
    z1_mean = mean(z1), z1_sd = sd(z1),
    z2_mean = mean(z2), z2_sd = sd(z2),
    z3_mean = mean(z3),
    prop_treated = mean(a),
    y_mean = mean(y), y_sd = sd(y),
    te_mean = mean(true_te), te_sd = sd(true_te)
  )

kable(summary_stats,
      caption = "Table 2: Complex Data Summary Statistics",
      col.names = c("N", "Z1 Mean", "Z1 SD", "Z2 Mean", "Z2 SD", "Z3 Mean",
                    "Prop. Treated", "Y Mean", "Y SD", "TE Mean", "TE SD"),
      digits = 2) %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed"))
N | Z1 Mean | Z1 SD | Z2 Mean | Z2 SD | Z3 Mean | Prop. Treated | Y Mean | Y SD | TE Mean | TE SD |
---|---|---|---|---|---|---|---|---|---|---|
10000 | 49.89 | 29.07 | 44.85 | 14.47 | 0.3 | 0.38 | 225.94 | 71.52 | 116.45 | 39.91 |
# Treatment probability by viral load and age
p1 <- ggplot(data_complex, aes(x = z1, y = z2, color = factor(a))) +
  geom_point(alpha = 0.4, size = 0.8) +
  labs(title = "Treatment Assignment Pattern",
       subtitle = "Complex non-linear assignment based on viral load and age",
       x = "Viral Load (Z1)", y = "Age (Z2)", color = "Treatment") +
  scale_color_manual(values = c("red", "blue"), labels = c("Untreated", "Treated")) +
  theme_minimal()

# Treatment effect surface (for patients without comorbidity)
grid_size <- 40
z1_grid <- seq(0, 100, length.out = grid_size)
z2_grid <- seq(20, 70, length.out = grid_size)

te_surface <- expand.grid(z1 = z1_grid, z2 = z2_grid) %>%
  mutate(z3 = 0) %>% # Show for z3 = 0
  mutate(te = true_te_fn(z1, z2, z3))

p2 <- ggplot(te_surface, aes(x = z1, y = z2, fill = te)) +
  geom_tile() +
  scale_fill_gradient2(low = "blue", mid = "white", high = "red",
                       midpoint = median(te_surface$te), name = "Treatment\nEffect") +
  labs(title = "True Treatment Effect Surface (No Comorbidity)",
       subtitle = "Complex non-linear heterogeneity with interaction hotspots",
       x = "Viral Load (Z1)", y = "Age (Z2)") +
  theme_minimal()

grid.arrange(p1, p2, ncol = 2)
# Linear model (using all three confounders)
outcome_model_linear <- lm(y ~ a + z1 + z2 + z3 + a:z1, data = data_complex)

# Polynomial model with modest flexibility (doesn't mirror the DGP structure)
outcome_model_poly <- lm(y ~ a * (z1 + I(z1^3) + z2 + z3) +
                           z1:z2 + z2:z3, data = data_complex)

# Random Forest model with conservative hyperparameters
set.seed(123)
outcome_model_rf <- randomForest(y ~ a + z1 + z2 + z3, data = data_complex,
                                 ntree = 2000,   # More trees for stability
                                 nodesize = 20,  # Larger terminal nodes (less overfitting)
                                 mtry = 2,       # Fewer variables per split
                                 maxnodes = 200, # Limit tree complexity
                                 importance = TRUE)
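# (Optional, illustrative) with importance = TRUE set above, we can inspect
# which confounders the forest relies on most.
importance(outcome_model_rf)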
# Model performance summaries
cat("Model Performance Summary:\n")
cat("Linear outcome R²:", round(summary(outcome_model_linear)$r.squared, 3), "\n")
cat("Polynomial outcome R²:", round(summary(outcome_model_poly)$r.squared, 3), "\n")
cat("RF outcome % Var Explained:", round(tail(outcome_model_rf$rsq, 1) * 100, 1), "%\n")
## Model Performance Summary:
## Linear outcome R²: 0.789
## Polynomial outcome R²: 0.802
## RF outcome % Var Explained: 86.8 %
# Simulation setup
set.seed(456)
sim_n <- 10000

# Sample from the marginal distributions of the confounders. (Independent
# sampling is valid here because z1, z2, z3 were generated independently;
# with dependent confounders, resample whole rows to preserve the joint
# distribution.)
z1_sim <- sample(data_complex$z1, size = sim_n, replace = TRUE)
z2_sim <- sample(data_complex$z2, size = sim_n, replace = TRUE)
z3_sim <- sample(data_complex$z3, size = sim_n, replace = TRUE)
sim_data <- data.frame(z1 = z1_sim, z2 = z2_sim, z3 = z3_sim)

# Calculate the true ATE on the simulation sample
true_te_sim <- true_te_fn(z1_sim, z2_sim, z3_sim)
true_ate <- mean(true_te_sim)
# Linear model predictions
y_linear_treated <- predict(outcome_model_linear, newdata = sim_data %>% mutate(a = 1))
y_linear_untreated <- predict(outcome_model_linear, newdata = sim_data %>% mutate(a = 0))
ate_linear <- mean(y_linear_treated) - mean(y_linear_untreated)
# Polynomial model predictions
y_poly_treated <- predict(outcome_model_poly, newdata = sim_data %>% mutate(a = 1))
y_poly_untreated <- predict(outcome_model_poly, newdata = sim_data %>% mutate(a = 0))
ate_poly <- mean(y_poly_treated) - mean(y_poly_untreated)
# Random Forest predictions
y_rf_treated <- predict(outcome_model_rf, newdata = sim_data %>% mutate(a = 1))
y_rf_untreated <- predict(outcome_model_rf, newdata = sim_data %>% mutate(a = 0))
ate_rf <- mean(y_rf_treated) - mean(y_rf_untreated)
# Crude estimate
crude_treated <- mean(data_complex$y[data_complex$a == 1])
crude_untreated <- mean(data_complex$y[data_complex$a == 0])
crude_ate <- crude_treated - crude_untreated
# Create comprehensive results table
results_comparison <- data.frame(
  Method = c("True ATE", "Crude (Unadjusted)", "Linear G-Formula",
             "Polynomial G-Formula", "Random Forest G-Formula"),
  ATE_Estimate = c(round(true_ate, 2), round(crude_ate, 2), round(ate_linear, 2),
                   round(ate_poly, 2), round(ate_rf, 2)),
  Bias = c(0, round(crude_ate - true_ate, 2), round(ate_linear - true_ate, 2),
           round(ate_poly - true_ate, 2), round(ate_rf - true_ate, 2)),
  Abs_Bias = c(0, round(abs(crude_ate - true_ate), 2), round(abs(ate_linear - true_ate), 2),
               round(abs(ate_poly - true_ate), 2), round(abs(ate_rf - true_ate), 2)),
  Notes = c("Oracle truth", "Ignores confounding", "Simple linear model",
            "Cubic + interactions", "Flexible ML approach")
)

kable(results_comparison,
      caption = "Table 3: G-Formula Results Comparison",
      col.names = c("Method", "ATE Estimate", "Bias", "Absolute Bias", "Notes")) %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed"))
Method | ATE Estimate | Bias | Absolute Bias | Notes |
---|---|---|---|---|
True ATE | 115.94 | 0.00 | 0.00 | Oracle truth |
Crude (Unadjusted) | 125.83 | 9.89 | 9.89 | Ignores confounding |
Linear G-Formula | 127.52 | 11.58 | 11.58 | Simple linear model |
Polynomial G-Formula | 125.84 | 9.89 | 9.89 | Cubic + interactions |
Random Forest G-Formula | 117.27 | 1.33 | 1.33 | Flexible ML approach |
# Performance insights
linear_bias <- results_comparison$Abs_Bias[3] # Linear G-Formula
poly_bias <- results_comparison$Abs_Bias[4]   # Polynomial G-Formula
rf_bias <- results_comparison$Abs_Bias[5]     # Random Forest G-Formula
best_parametric_bias <- min(linear_bias, poly_bias) # Best parametric approach

cat("\n*** KEY FINDINGS ***\n")
cat("True ATE:", round(true_ate, 2), "\n")
cat("Linear model bias:", linear_bias, "\n")
cat("Polynomial model bias:", poly_bias, "\n")
cat("Random Forest bias:", rf_bias, "\n")
##
## *** KEY FINDINGS ***
## True ATE: 115.94
## Linear model bias: 11.58
## Polynomial model bias: 9.89
## Random Forest bias: 1.33
if (rf_bias < linear_bias) {
  rf_vs_linear <- round(linear_bias / rf_bias, 1)
  cat("✓ Random Forest performs", rf_vs_linear, "times better than linear model!\n")
} else {
  cat("Linear model performs similarly to or better than Random Forest.\n")
}
## ✓ Random Forest performs 8.7 times better than linear model!
if (rf_bias < best_parametric_bias) {
  rf_vs_best_param <- round(best_parametric_bias / rf_bias, 1)
  cat("✓ Random Forest performs", rf_vs_best_param, "times better than best parametric method!\n")
} else {
  cat("Parametric methods perform similarly to or better than Random Forest.\n")
}
## ✓ Random Forest performs 7.4 times better than best parametric method!
# Cross-validation comparison (folds via caret::createFolds)
set.seed(789)
cv_folds <- createFolds(data_complex$y, k = 5)

# Function to calculate RMSE
calculate_rmse <- function(actual, predicted) {
  sqrt(mean((actual - predicted)^2, na.rm = TRUE))
}
# Initialize RMSE vectors
rmse_linear <- rmse_poly <- rmse_rf <- numeric(5)

# Cross-validation loop
for (i in 1:5) {
  train_idx <- unlist(cv_folds[-i])
  test_idx <- cv_folds[[i]]
  train_data <- data_complex[train_idx, ]
  test_data <- data_complex[test_idx, ]

  # Fit models
  linear_cv <- lm(y ~ a + z1 + z2 + z3 + a:z1, data = train_data)
  poly_cv <- lm(y ~ a * (z1 + I(z1^3) + z2 + z3) + z1:z2 + z2:z3, data = train_data)
  rf_cv <- randomForest(y ~ a + z1 + z2 + z3, data = train_data,
                        ntree = 1000, nodesize = 20, mtry = 2)

  # Predictions
  pred_linear <- predict(linear_cv, test_data)
  pred_poly <- predict(poly_cv, test_data)
  pred_rf <- predict(rf_cv, test_data)

  # RMSE calculation
  rmse_linear[i] <- calculate_rmse(test_data$y, pred_linear)
  rmse_poly[i] <- calculate_rmse(test_data$y, pred_poly)
  rmse_rf[i] <- calculate_rmse(test_data$y, pred_rf)
}
# Performance summary
performance_summary <- data.frame(
  Model = c("Linear", "Polynomial", "Random Forest"),
  Mean_RMSE = c(mean(rmse_linear), mean(rmse_poly), mean(rmse_rf)),
  SD_RMSE = c(sd(rmse_linear), sd(rmse_poly), sd(rmse_rf)),
  Improvement_vs_Linear = c("—",
                            paste0(round((mean(rmse_linear) - mean(rmse_poly)) / mean(rmse_linear) * 100, 1), "%"),
                            paste0(round((mean(rmse_linear) - mean(rmse_rf)) / mean(rmse_linear) * 100, 1), "%"))
)

kable(performance_summary,
      caption = "Table 4: Cross-Validation Performance",
      col.names = c("Model", "Mean RMSE", "SD RMSE", "Improvement"),
      digits = 3) %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed"))
Model | Mean RMSE | SD RMSE | Improvement |
---|---|---|---|
Linear | 32.897 | 0.774 | — |
Polynomial | 31.846 | 0.759 | 3.2% |
Random Forest | 26.020 | 0.360 | 20.9% |
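Because every model was evaluated on the same five folds, a paired per-fold comparison (added here for illustration) shows the Random Forest advantage is consistent across folds rather than driven by a single split:

# Illustrative paired check: per-fold RMSE gap (linear minus Random Forest)
round(rmse_linear - rmse_rf, 2)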
# Compare model predictions across viral load (fixing age=40, comorbidity=0)
z1_plot_range <- seq(0, 100, length.out = 100)
plot_data_fixed <- data.frame(z1 = z1_plot_range, z2 = 40, z3 = 0)
# Calculate the true treatment effect for this slice
true_te_plot <- true_te_fn(z1_plot_range, z2 = 40, z3 = 0)
# Model predictions
pred_linear_treated <- predict(outcome_model_linear, newdata = plot_data_fixed %>% mutate(a = 1))
pred_linear_untreated <- predict(outcome_model_linear, newdata = plot_data_fixed %>% mutate(a = 0))
te_linear <- pred_linear_treated - pred_linear_untreated
pred_poly_treated <- predict(outcome_model_poly, newdata = plot_data_fixed %>% mutate(a = 1))
pred_poly_untreated <- predict(outcome_model_poly, newdata = plot_data_fixed %>% mutate(a = 0))
te_poly <- pred_poly_treated - pred_poly_untreated
pred_rf_treated <- predict(outcome_model_rf, newdata = plot_data_fixed %>% mutate(a = 1))
pred_rf_untreated <- predict(outcome_model_rf, newdata = plot_data_fixed %>% mutate(a = 0))
te_rf <- pred_rf_treated - pred_rf_untreated
# Create plotting dataframe
te_df <- data.frame(
  z1 = rep(z1_plot_range, 4),
  treatment_effect = c(true_te_plot, te_linear, te_poly, te_rf),
  model = rep(c("True Effect", "Linear", "Polynomial", "Random Forest"), each = 100)
)
# Plot treatment effects
p_te <- ggplot(te_df, aes(x = z1, y = treatment_effect, color = model, linetype = model)) +
  geom_line(linewidth = 1.2) +
  labs(title = "Treatment Effect by Viral Load (Age=40, No Comorbidity)",
       subtitle = "Model comparison showing Random Forest's ability to capture non-linear patterns",
       x = "Viral Load (Z1)", y = "Treatment Effect",
       color = "Model", linetype = "Model") +
  theme_minimal() +
  scale_color_manual(values = c("black", "red", "blue", "darkgreen")) +
  scale_linetype_manual(values = c("solid", "dashed", "dotted", "solid")) +
  theme(legend.position = "bottom")

print(p_te)
Simple data: When relationships are exactly additive, as here, the linear model recovers the true ATE of 50, matching the non-parametric g-formula, while Random Forest's smoothing across the four covariate cells biases its estimate (36.38 vs. 50). Flexible methods add no benefit, and can hurt, when a simple model is correctly specified.
Complex data: Random Forest captures the highly non-linear, heterogeneous treatment effects far more accurately than the linear or polynomial models, with an absolute bias of 1.33 versus 11.58 and 9.89 respectively. Notably, the misspecified linear g-formula was slightly worse than the crude estimate, a reminder that adjustment with a wrong model can add bias rather than remove it.
Prediction accuracy: Random Forest improves out-of-sample prediction accuracy by 20.9% over the linear model, based on cross-validated RMSE.
Random Forest excels when:

- confounder-outcome relationships are highly non-linear, with interactions a parametric model is unlikely to specify correctly;
- treatment effects are strongly heterogeneous across covariate regions;
- the sample is large enough (here, 10,000 observations) to support flexible, data-adaptive fitting.

Conversely, when relationships are simple or data are limited, a well-specified parametric model remains the better choice.