This tutorial demonstrates g-methods, comparing traditional parametric approaches with machine learning methods. We’ll show when flexible models like random forests provide advantages over linear models in causal inference, particularly when the underlying relationships are highly non-linear.
We examine the effect of HIV treatment (\(A\)) on CD4 count (\(Y\)) using two datasets to demonstrate g-computation performance:

1. a simple cross-sectional table where relationships are approximately linear, and
2. a complex simulated dataset with strong non-linearities and heterogeneous treatment effects.

Our variables are:

- \(Z\): viral load (confounder; 0 = low, 1 = high)
- \(A\): HIV treatment (0 = untreated, 1 = treated)
- \(Y\): CD4 count (outcome)
- \(N\): number of patients in each stratum
In this setup, the true average treatment effect (ATE) is 50. First, let’s reproduce the basic example with simple relationships:
```r
# Simple tabular data from original tutorial
z <- c(0, 0, 1, 1)
a <- c(0, 1, 0, 1)
y <- c(100, 150, 80, 130)
n <- c(300, 200, 150, 350)

data_simple <- data.frame(z, a, y, n)

kable(data_simple,
      caption = "Table 1: Simple Cross-sectional Data",
      col.names = c("Z (Viral Load)", "A (Treatment)", "Y (CD4 Count)", "N (Sample Size)")) %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed"))
```
Z (Viral Load) | A (Treatment) | Y (CD4 Count) | N (Sample Size) |
---|---|---|---|
0 | 0 | 100 | 300 |
0 | 1 | 150 | 200 |
1 | 0 | 80 | 150 |
1 | 1 | 130 | 350 |
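Before touching any code, we can verify the true ATE by hand. The non-parametric g-formula standardizes the stratum-specific mean outcomes over the marginal distribution of \(Z\):

\[
\text{ATE} = \sum_{z}\Big[E(Y \mid A = 1, Z = z) - E(Y \mid A = 0, Z = z)\Big]\Pr(Z = z)
\]

From Table 1, \(\Pr(Z = 0) = (300 + 200)/1000 = 0.5\) and \(\Pr(Z = 1) = (150 + 350)/1000 = 0.5\), so \(\text{ATE} = (150 - 100)(0.5) + (130 - 80)(0.5) = 50\). The code below computes the same quantity: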
```r
# Calculate marginal distribution of Z
total_n <- sum(n)
z_marginal <- data_simple %>%
  group_by(z) %>%
  summarise(total = sum(n), prob = sum(n) / total_n, .groups = 'drop')

# Non-parametric G-formula
treated_outcomes <- data_simple %>%
  filter(a == 1) %>%
  left_join(z_marginal %>% select(z, prob), by = "z") %>%
  mutate(weighted_y = y * prob)

untreated_outcomes <- data_simple %>%
  filter(a == 0) %>%
  left_join(z_marginal %>% select(z, prob), by = "z") %>%
  mutate(weighted_y = y * prob)

ate_simple_nonparam <- sum(treated_outcomes$weighted_y) - sum(untreated_outcomes$weighted_y)
```
```r
# Parametric G-formula (expand data and fit models)
data_expanded <- data_simple %>% slice(rep(row_number(), n))

# Fit models
y_model_simple <- lm(y ~ a + z, data = data_simple, weights = n)

# Simulate and predict
set.seed(123)
sim_n <- 100000
z_sim <- sample(z_marginal$z, size = sim_n, replace = TRUE, prob = z_marginal$prob)
sim_data <- data.frame(z = z_sim)

y_treated_simple <- predict(y_model_simple, newdata = sim_data %>% mutate(a = 1))
y_untreated_simple <- predict(y_model_simple, newdata = sim_data %>% mutate(a = 0))
ate_simple_param <- mean(y_treated_simple) - mean(y_untreated_simple)

# Random Forest
rf_model_simple <- randomForest(y ~ a + z, data = data_expanded, ntree = 500)
y_rf_treated_simple <- predict(rf_model_simple, newdata = sim_data %>% mutate(a = 1))
y_rf_untreated_simple <- predict(rf_model_simple, newdata = sim_data %>% mutate(a = 0))
ate_simple_rf <- mean(y_rf_treated_simple) - mean(y_rf_untreated_simple)
```
cat("Simple Data Results:\n")
## Simple Data Results:
cat("Non-parametric ATE:", round(ate_simple_nonparam, 2), "\n")
## Non-parametric ATE: 50
cat("Linear model ATE:", round(ate_simple_param, 2), "\n")
## Linear model ATE: 50
cat("Random Forest ATE:", round(ate_simple_rf, 2), "\n")
## Random Forest ATE: 36.38
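The random forest’s deviation from the true value of 50 deserves a second look before moving on. As a quick check (not part of the original tutorial; it reuses `data_expanded` and `sim_data` from above), we can refit the forest under several seeds to see how much of the deviation is seed-to-seed variance versus systematic smoothing across the four \((z, a)\) cells:

```r
# Refit the simple-data forest under several seeds and collect the ATE
# estimates; a tight spread away from 50 indicates systematic smoothing
# across cells rather than random variation.
ates_rf <- sapply(1:20, function(s) {
  set.seed(s)
  rf <- randomForest(y ~ a + z, data = data_expanded, ntree = 500)
  mean(predict(rf, newdata = sim_data %>% mutate(a = 1))) -
    mean(predict(rf, newdata = sim_data %>% mutate(a = 0)))
})
summary(ates_rf)  # compare against the non-parametric answer of 50
```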
Our variables are:

- `z1`: viral load (continuous, 0–100)
- `z2`: age in years (continuous, 20–70)
- `z3`: rural residence (binary, 0/1)
- `a`: HIV treatment and `y`: CD4 count, as before
Now let’s create data with strong non-linear relationships to demonstrate when Random Forest excels:
```r
set.seed(42)
n_obs <- 10000  # Enough sample for ML methods

# Generate multiple confounders for richer interactions
z1 <- runif(n_obs, 0, 100)   # Viral load
z2 <- runif(n_obs, 20, 70)   # Age
z3 <- rbinom(n_obs, 1, 0.5)  # Rural (0/1)

# Complex non-linear treatment assignment with smooth relationships
treatment_logit <- -1.5 +
  # Smooth S-curve for viral load (sicker patients more likely to get treatment)
  3 * plogis((z1 - 50) / 20) - 1.5 +
  # Smooth inverse-U for age (middle-aged more likely to get treatment)
  2 * exp(-((z2 - 45) / 15)^2) - 0.5 +
  # Rural effect that varies smoothly with other variables
  z3 * (0.5 + 0.02 * z1 - 0.01 * z2) +
  # Smooth interaction surfaces
  0.015 * z1 * (z2 - 45) / 25 +                  # Viral load × Age interaction
  0.3 * sin(z1 * pi / 50) * cos(z2 * pi / 40) +  # Trigonometric interaction
  z3 * 0.2 * cos((z1 + z2) * pi / 80)            # Three-way interaction

treatment_prob <- plogis(treatment_logit)
a <- rbinom(n_obs, 1, treatment_prob)

# Complex outcome model with baseline health declining with viral load and age
baseline_outcome <- 200 - 1.5 * z1 + 0.01 * z1^2 +  # Viral load effect
  50 * exp(-((z2 - 40) / 12)^2) +                   # Age effect (peak health ~40)
  z3 * (-30 - 0.3 * z1 + 0.5 * z2)                  # Rural penalty

# Highly heterogeneous treatment effects - the key for demonstrating RF advantage
treatment_effect <-
  # Base effect varies smoothly with viral load
  30 + 0.8 * z1 + 20 * tanh((z1 - 40) / 20) +  # Smooth transition around z1 = 40
  # Age modifies effectiveness (peak around age 45)
  25 * exp(-((z2 - 45) / 18)^2) +
  # Rural creates complex interactions
  z3 * (20 - 0.4 * z1 + 0.3 * z2) +
  # Smooth interactions that linear models struggle with
  0.008 * z1 * z2 +  # Linear interaction
  0.15 * z1 * z3 +   # z1 × z3 interaction
  # Non-linear patterns
  15 * sin(z1 * pi / 60) * (1 + 0.5 * z3) +  # Sine modulation
  10 * cos(z2 * pi / 50) * exp(-z1 / 100) +  # Decaying cosine
  # Regional "sweet spots" where treatment works exceptionally well
  20 * exp(-((z1 - 30)^2 + (z2 - 40)^2) / 400) +     # Gaussian sweet spot
  15 * exp(-((z1 - 70)^2 + (z2 - 50)^2) / 500) * z3  # Another for rural patients

# Final outcome with moderate noise
y <- baseline_outcome + a * treatment_effect + rnorm(n_obs, 0, 25)

# Create dataset
data_complex <- data.frame(
  z1 = z1, z2 = z2, z3 = z3,
  a = a, y = y,
  true_te = treatment_effect
)

# Summary statistics
summary_stats <- data_complex %>%
  summarise(
    n = n(),
    z1_mean = mean(z1), z1_sd = sd(z1),
    z2_mean = mean(z2), z2_sd = sd(z2),
    z3_mean = mean(z3),
    prop_treated = mean(a),
    y_mean = mean(y), y_sd = sd(y),
    te_mean = mean(true_te), te_sd = sd(true_te)
  )

kable(summary_stats,
      caption = "Table 2: Complex Data Summary Statistics",
      col.names = c("N", "Z1 Mean", "Z1 SD", "Z2 Mean", "Z2 SD", "Z3 Mean",
                    "Prop. Treated", "Y Mean", "Y SD", "TE Mean", "TE SD"),
      digits = 2) %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed"))
```
N | Z1 Mean | Z1 SD | Z2 Mean | Z2 SD | Z3 Mean | Prop. Treated | Y Mean | Y SD | TE Mean | TE SD |
---|---|---|---|---|---|---|---|---|---|---|
10000 | 49.89 | 29.07 | 44.85 | 14.47 | 0.49 | 0.41 | 227.73 | 72.93 | 121.54 | 39.12 |
```r
# Treatment probability by viral load and age
p1 <- ggplot(data_complex, aes(x = z1, y = z2, color = factor(a))) +
  geom_point(alpha = 0.4, size = 0.8) +
  labs(title = "Treatment Assignment Pattern",
       subtitle = "Complex non-linear assignment based on viral load and age",
       x = "Viral Load (Z1)", y = "Age (Z2)", color = "Treatment") +
  scale_color_manual(values = c("red", "blue"), labels = c("Untreated", "Treated")) +
  theme_minimal()

# Treatment effect surface (for non-rural patients)
grid_size <- 40
z1_grid <- seq(0, 100, length.out = grid_size)
z2_grid <- seq(20, 70, length.out = grid_size)

te_surface <- expand.grid(z1 = z1_grid, z2 = z2_grid) %>%
  mutate(z3 = 0) %>%  # Show for z3 = 0
  mutate(
    te = 30 + 0.8 * z1 + 20 * tanh((z1 - 40) / 20) +
      25 * exp(-((z2 - 45) / 18)^2) +
      z3 * (20 - 0.4 * z1 + 0.3 * z2) +
      0.008 * z1 * z2 +
      0.15 * z1 * z3 +
      15 * sin(z1 * pi / 60) * (1 + 0.5 * z3) +
      10 * cos(z2 * pi / 50) * exp(-z1 / 100) +
      20 * exp(-((z1 - 30)^2 + (z2 - 40)^2) / 400) +
      15 * exp(-((z1 - 70)^2 + (z2 - 50)^2) / 500) * z3
  )

p2 <- ggplot(te_surface, aes(x = z1, y = z2, fill = te)) +
  geom_tile() +
  scale_fill_gradient2(low = "blue", mid = "white", high = "red",
                       midpoint = median(te_surface$te), name = "Treatment\nEffect") +
  labs(title = "True Treatment Effect Surface (No Rural)",
       subtitle = "Complex non-linear heterogeneity with interaction hotspots",
       x = "Viral Load (Z1)", y = "Age (Z2)") +
  theme_minimal()

grid.arrange(p1, p2, ncol = 2)
```
```r
# Linear model (using all three confounders)
outcome_model_linear <- lm(y ~ a + z1 + z2 + z3, data = data_complex)

# Polynomial model with modest flexibility (doesn't mirror DGP structure)
outcome_model_poly <- lm(y ~ a * (z1 + I(z1^3) + z2 + z3) +
                           z1:z2 + z2:z3, data = data_complex)

# Random Forest model with proper hyperparameters
set.seed(123)
outcome_model_rf <- randomForest(y ~ a + z1 + z2 + z3, data = data_complex,
                                 ntree = 2000,    # More trees for stability
                                 nodesize = 20,   # Larger nodes (less overfitting)
                                 mtry = 2,        # Fewer variables per split
                                 maxnodes = 200,  # Limit complexity
                                 importance = TRUE)
```
```r
# Model performance summaries
cat("Model Performance Summary:\n")
## Model Performance Summary:
cat("Linear outcome R²:", round(summary(outcome_model_linear)$r.squared, 3), "\n")
## Linear outcome R²: 0.756
cat("Polynomial outcome R²:", round(summary(outcome_model_poly)$r.squared, 3), "\n")
## Polynomial outcome R²: 0.807
cat("RF outcome % Var Explained:", round(tail(outcome_model_rf$rsq, 1), 3), "\n")
## RF outcome % Var Explained: 0.873
```
```r
# Simulation setup
set.seed(456)
sim_n <- 10000

# Sample from marginal distributions of confounders
z1_sim <- sample(data_complex$z1, size = sim_n, replace = TRUE)
z2_sim <- sample(data_complex$z2, size = sim_n, replace = TRUE)
z3_sim <- sample(data_complex$z3, size = sim_n, replace = TRUE)

sim_data <- data.frame(z1 = z1_sim, z2 = z2_sim, z3 = z3_sim)

# Calculate true ATE on simulation sample
true_te_sim <- 30 + 0.8 * z1_sim + 20 * tanh((z1_sim - 40) / 20) +
  25 * exp(-((z2_sim - 45) / 18)^2) +
  z3_sim * (20 - 0.4 * z1_sim + 0.3 * z2_sim) +
  0.008 * z1_sim * z2_sim +
  0.15 * z1_sim * z3_sim +
  15 * sin(z1_sim * pi / 60) * (1 + 0.5 * z3_sim) +
  10 * cos(z2_sim * pi / 50) * exp(-z1_sim / 100) +
  20 * exp(-((z1_sim - 30)^2 + (z2_sim - 40)^2) / 400) +
  15 * exp(-((z1_sim - 70)^2 + (z2_sim - 50)^2) / 500) * z3_sim

true_ate <- mean(true_te_sim)

# Linear model predictions
y_linear_treated <- predict(outcome_model_linear, newdata = sim_data %>% mutate(a = 1))
y_linear_untreated <- predict(outcome_model_linear, newdata = sim_data %>% mutate(a = 0))
ate_linear <- mean(y_linear_treated) - mean(y_linear_untreated)

# Polynomial model predictions
y_poly_treated <- predict(outcome_model_poly, newdata = sim_data %>% mutate(a = 1))
y_poly_untreated <- predict(outcome_model_poly, newdata = sim_data %>% mutate(a = 0))
ate_poly <- mean(y_poly_treated) - mean(y_poly_untreated)

# Random Forest predictions
y_rf_treated <- predict(outcome_model_rf, newdata = sim_data %>% mutate(a = 1))
y_rf_untreated <- predict(outcome_model_rf, newdata = sim_data %>% mutate(a = 0))
ate_rf <- mean(y_rf_treated) - mean(y_rf_untreated)

# Crude estimate
crude_treated <- mean(data_complex$y[data_complex$a == 1])
crude_untreated <- mean(data_complex$y[data_complex$a == 0])
crude_ate <- crude_treated - crude_untreated
```
```r
# Create comprehensive results table
results_comparison <- data.frame(
  Method = c("True ATE", "Crude (Unadjusted)", "Linear G-Formula",
             "Polynomial G-Formula", "Random Forest G-Formula"),
  ATE_Estimate = c(round(true_ate, 2), round(crude_ate, 2), round(ate_linear, 2),
                   round(ate_poly, 2), round(ate_rf, 2)),
  Bias = c(0, round(crude_ate - true_ate, 2), round(ate_linear - true_ate, 2),
           round(ate_poly - true_ate, 2), round(ate_rf - true_ate, 2)),
  Abs_Bias = c(0, round(abs(crude_ate - true_ate), 2), round(abs(ate_linear - true_ate), 2),
               round(abs(ate_poly - true_ate), 2), round(abs(ate_rf - true_ate), 2)),
  Notes = c("Oracle truth", "Ignores confounding", "Simple linear model",
            "Cubic + interactions", "Flexible ML approach")
)

kable(results_comparison,
      caption = "Table 3: G-Formula Results Comparison",
      col.names = c("Method", "ATE Estimate", "Bias", "Absolute Bias", "Notes")) %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed"))
```
Method | ATE Estimate | Bias | Absolute Bias | Notes |
---|---|---|---|---|
True ATE | 120.93 | 0.00 | 0.00 | Oracle truth |
Crude (Unadjusted) | 127.05 | 6.12 | 6.12 | Ignores confounding |
Linear G-Formula | 138.60 | 17.67 | 17.67 | Simple linear model |
Polynomial G-Formula | 130.19 | 9.26 | 9.26 | Cubic + interactions |
Random Forest G-Formula | 121.27 | 0.35 | 0.35 | Flexible ML approach |
```r
# Performance insights
linear_bias <- results_comparison$Abs_Bias[3]  # Linear only
poly_bias <- results_comparison$Abs_Bias[4]    # Polynomial only
rf_bias <- results_comparison$Abs_Bias[5]      # Random Forest
best_parametric_bias <- min(linear_bias, poly_bias)  # Best parametric approach

cat("\n*** KEY FINDINGS ***\n")
## 
## *** KEY FINDINGS ***
cat("True ATE:", round(true_ate, 2), "\n")
## True ATE: 120.93
cat("Linear model bias:", round(linear_bias, 2), "\n")
## Linear model bias: 17.67
cat("Polynomial model bias:", round(poly_bias, 2), "\n")
## Polynomial model bias: 9.26
cat("Random Forest bias:", round(rf_bias, 2), "\n")
## Random Forest bias: 0.35

if (rf_bias < linear_bias) {
  rf_vs_linear <- round(linear_bias / rf_bias, 1)
  cat("✓ Random Forest performs", rf_vs_linear, "times better than linear model!\n")
} else {
  cat("Linear model performs similarly to or better than Random Forest.\n")
}
## ✓ Random Forest performs 50.5 times better than linear model!

if (rf_bias < best_parametric_bias) {
  rf_vs_best_param <- round(best_parametric_bias / rf_bias, 1)
  cat("✓ Random Forest performs", rf_vs_best_param, "times better than best parametric method!\n")
} else {
  cat("Parametric methods perform similarly to or better than Random Forest.\n")
}
## ✓ Random Forest performs 26.5 times better than best parametric method!
```
```r
# Cross-validation comparison
set.seed(789)
cv_folds <- createFolds(data_complex$y, k = 5)

# Function to calculate RMSE
calculate_rmse <- function(actual, predicted) {
  sqrt(mean((actual - predicted)^2, na.rm = TRUE))
}

# Initialize RMSE vectors
rmse_linear <- rmse_poly <- rmse_rf <- numeric(5)

# Cross-validation loop
for (i in 1:5) {
  train_idx <- unlist(cv_folds[-i])
  test_idx <- cv_folds[[i]]

  train_data <- data_complex[train_idx, ]
  test_data <- data_complex[test_idx, ]

  # Fit models
  linear_cv <- lm(y ~ a + z1 + z2 + z3, data = train_data)
  poly_cv <- lm(y ~ a * (z1 + I(z1^3) + z2 + z3) + z1:z2 + z2:z3, data = train_data)
  rf_cv <- randomForest(y ~ a + z1 + z2 + z3, data = train_data,
                        ntree = 1000, nodesize = 20, mtry = 2)

  # Predictions
  pred_linear <- predict(linear_cv, test_data)
  pred_poly <- predict(poly_cv, test_data)
  pred_rf <- predict(rf_cv, test_data)

  # RMSE calculation
  rmse_linear[i] <- calculate_rmse(test_data$y, pred_linear)
  rmse_poly[i] <- calculate_rmse(test_data$y, pred_poly)
  rmse_rf[i] <- calculate_rmse(test_data$y, pred_rf)
}

# Performance summary
performance_summary <- data.frame(
  Model = c("Linear", "Polynomial", "Random Forest"),
  Mean_RMSE = c(mean(rmse_linear), mean(rmse_poly), mean(rmse_rf)),
  SD_RMSE = c(sd(rmse_linear), sd(rmse_poly), sd(rmse_rf)),
  Improvement_vs_Linear = c("—",
    paste0(round((mean(rmse_linear) - mean(rmse_poly)) / mean(rmse_linear) * 100, 1), "%"),
    paste0(round((mean(rmse_linear) - mean(rmse_rf)) / mean(rmse_linear) * 100, 1), "%"))
)

kable(performance_summary,
      caption = "Table 4: Cross-Validation Performance",
      col.names = c("Model", "Mean RMSE", "SD RMSE", "Improvement"),
      digits = 3) %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed"))
```
Model | Mean RMSE | SD RMSE | Improvement |
---|---|---|---|
Linear | 36.065 | 0.690 | — |
Polynomial | 32.106 | 0.309 | 11% |
Random Forest | 26.076 | 0.266 | 27.7% |
```r
# Compare model predictions across viral load (fixing age = 40, rural = 0)
z1_plot_range <- seq(0, 100, length.out = 100)
plot_data_fixed <- data.frame(z1 = z1_plot_range, z2 = 40, z3 = 0)

# Calculate true treatment effect for this slice
true_te_plot <- 30 + 0.8 * z1_plot_range + 20 * tanh((z1_plot_range - 40) / 20) +
  25 * exp(-((40 - 45) / 18)^2) +
  0 * (20 - 0.4 * z1_plot_range + 0.3 * 40) +
  0.008 * z1_plot_range * 40 +
  0.15 * z1_plot_range * 0 +
  15 * sin(z1_plot_range * pi / 60) * (1 + 0.5 * 0) +
  10 * cos(40 * pi / 50) * exp(-z1_plot_range / 100) +
  20 * exp(-((z1_plot_range - 30)^2 + (40 - 40)^2) / 400) +
  15 * exp(-((z1_plot_range - 70)^2 + (40 - 50)^2) / 500) * 0

# Model predictions
pred_linear_treated <- predict(outcome_model_linear, newdata = plot_data_fixed %>% mutate(a = 1))
pred_linear_untreated <- predict(outcome_model_linear, newdata = plot_data_fixed %>% mutate(a = 0))
te_linear <- pred_linear_treated - pred_linear_untreated

pred_poly_treated <- predict(outcome_model_poly, newdata = plot_data_fixed %>% mutate(a = 1))
pred_poly_untreated <- predict(outcome_model_poly, newdata = plot_data_fixed %>% mutate(a = 0))
te_poly <- pred_poly_treated - pred_poly_untreated

pred_rf_treated <- predict(outcome_model_rf, newdata = plot_data_fixed %>% mutate(a = 1))
pred_rf_untreated <- predict(outcome_model_rf, newdata = plot_data_fixed %>% mutate(a = 0))
te_rf <- pred_rf_treated - pred_rf_untreated

# Create plotting dataframe
te_df <- data.frame(
  z1 = rep(z1_plot_range, 4),
  treatment_effect = c(true_te_plot, te_linear, te_poly, te_rf),
  model = rep(c("True Effect", "Linear", "Polynomial", "Random Forest"), each = 100)
)

# Plot treatment effects
p_te <- ggplot(te_df, aes(x = z1, y = treatment_effect, color = model, linetype = model)) +
  geom_line(size = 1.2) +
  labs(title = "Treatment Effect by Viral Load (Age=40, No Rural)",
       subtitle = "Model comparison showing Random Forest's ability to capture non-linear patterns",
       x = "Viral Load (Z1)", y = "Treatment Effect",
       color = "Model", linetype = "Model") +
  theme_minimal() +
  scale_color_manual(values = c("black", "red", "blue", "darkgreen")) +
  scale_linetype_manual(values = c("solid", "dashed", "dotted", "solid")) +
  theme(legend.position = "bottom")

print(p_te)
```
Simple data: Linear models outperform Random Forest when relationships are approximately linear. In this setting the forest adds variance without any payoff in flexibility, and its ATE estimate is the most biased of the three (36.38 against the true 50).

Complex data: Random Forest captures highly non-linear, heterogeneous treatment effects far more accurately than the linear or polynomial models. Its absolute bias (0.35) is roughly 50 times smaller than the linear model’s (17.67) and 26 times smaller than the best parametric alternative’s (9.26), demonstrating the value of non-parametric approaches when relationships are complex.

Generally, ML models perform well when:

- the sample is large enough to support flexible fits (here, 10,000 observations);
- relationships between confounders and the outcome are strongly non-linear; and
- treatment effects are heterogeneous, with interactions that are hard to specify in advance.

Whichever outcome model these conditions favor, the standardization step of g-computation is identical; the sketch below makes it explicit.
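As a minimal sketch (the function name and arguments are illustrative, not part of the original tutorial), the helper below assumes only that the fitted outcome model works with `predict()` and that `confounders` is a data frame of confounder draws to standardize over:

```r
# g-computation ATE from any fitted outcome model: set treatment to 1 for
# everyone, then to 0 for everyone, predict, and average the difference.
gcomp_ate <- function(fit, confounders, treatment_col = "a") {
  d1 <- confounders; d1[[treatment_col]] <- 1  # counterfactual: all treated
  d0 <- confounders; d0[[treatment_col]] <- 0  # counterfactual: all untreated
  mean(predict(fit, newdata = d1)) - mean(predict(fit, newdata = d0))
}

# Usage with objects defined above (should reproduce the earlier estimates):
# gcomp_ate(outcome_model_rf, sim_data)      # ~ ate_rf
# gcomp_ate(outcome_model_linear, sim_data)  # ~ ate_linear
```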