Commit 5fd4776

Revise README for clarity and additional details
Updated README to improve clarity and structure, added installation instructions, example usage, and detailed steps for nuisance estimation and R-Loss CATE learning.
1 parent 77c232b commit 5fd4776

File tree: 1 file changed (+194, −2 lines)

README.md

Lines changed: 194 additions & 2 deletions

# Python-based R Learner

A Python package named `rlearner` that implements the R learner ([Nie and Wager, 2021](#ref-nie-wager-2021)) for heterogeneous treatment effect estimation and validation, with flexible model choices at each stage.

## Setup

Install the package via pip:

```bash
pip install "git+https://github.com/andyjiayuwang/Python-based-R-Learner.git"
```

A full example workflow is available in [`demo.ipynb`](demo.ipynb).

A minimal import example is:

```python
from rlearner import (
    CrossFittedNuisanceEstimator,
    RLearner,
    RLossStacking,
    SuperLearnerClassifier,
    SuperLearnerRegressor,
)
```

## Step 1: Nuisance Estimation

The first step estimates the nuisance functions needed by the R learner:

- `m(X) = E[Y | X]`, the outcome regression
- `e(X) = E[W | X]`, the propensity score

These nuisance estimates are used to build the residualized quantities

- `Y_tilde = Y - m_hat(X)`
- `W_tilde = W - e_hat(X)`

which are then passed into the second-stage R-loss optimization.
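
The residualization itself is plain arithmetic on the out-of-fold predictions; a minimal sketch with illustrative values (not the package API):

```python
# Hypothetical data: outcomes y, binary treatments w, and out-of-fold
# nuisance predictions m_hat ~ E[Y|X] and e_hat ~ E[W|X] (illustrative values).
y     = [3.0, 1.5, 2.0, 4.0]
w     = [1,   0,   1,   0]
m_hat = [2.5, 1.0, 2.5, 3.5]
e_hat = [0.6, 0.3, 0.5, 0.4]

# Residualize: these quantities feed the second-stage R-loss.
y_tilde = [yi - mi for yi, mi in zip(y, m_hat)]
w_tilde = [wi - ei for wi, ei in zip(w, e_hat)]
```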

The package provides two ways to handle step 1.

### Built-in nuisance estimation

Use `CrossFittedNuisanceEstimator` when you want the package to fit nuisance models directly. It supports:

- K-fold cross-fitting for both the outcome model and the treatment model
- Any sklearn-style regressor for the outcome model
- Any sklearn-style binary classifier with `predict_proba` for the treatment model
- Optional grid search on the full nuisance model object through `outcome_param_grid` and `treatment_param_grid`
- Full-sample refitting after cross-fitting so the fitted nuisance models can be reused for prediction

Default settings for `CrossFittedNuisanceEstimator` are:

- `n_folds=10`
- `shuffle=True`
- `random_state=42`
- `propensity_clip=1e-6`
- `stratify_treatment=True`
- `refit_full=True`
- `outcome_search_cv=5`
- `treatment_search_cv=5`
- `treatment_scoring="neg_log_loss"`
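
The cross-fitting idea is that each observation's nuisance prediction comes from models fit on the other folds only. A toy sketch of the mechanics with a trivial mean predictor (an illustration, not the package's implementation):

```python
# Minimal K-fold cross-fitting sketch with a trivial "predict the training
# mean" model; real usage plugs in any sklearn-style estimator instead.
def cross_fit_mean(y, n_folds=2):
    n = len(y)
    out_of_fold = [None] * n
    folds = [list(range(i, n, n_folds)) for i in range(n_folds)]
    for fold in folds:
        train = [y[i] for i in range(n) if i not in fold]
        fit = sum(train) / len(train)      # "model" fit on the other folds
        for i in fold:
            out_of_fold[i] = fit           # out-of-fold prediction
    return out_of_fold

m_hat = cross_fit_mean([1.0, 2.0, 3.0, 4.0], n_folds=2)

# Propensity clipping (cf. propensity_clip=1e-6) keeps e_hat away from 0 and 1.
clip = lambda p, eps=1e-6: min(max(p, eps), 1 - eps)
```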

### Manual nuisance inputs

Use manual nuisance inputs when you already have trusted out-of-fold nuisance predictions from an external workflow. In that case, pass:

- `y_hat`, the out-of-fold estimate of `m(X)`
- `d_hat`, the out-of-fold estimate of `e(X)`

through `ManualNuisanceEstimator` or directly through `RLearner.fit(..., y_hat=..., d_hat=...)`.

### Constrained super learner for step 1

The package also provides constrained super learners for nuisance prediction:

- `SuperLearnerRegressor`
- `SuperLearnerClassifier`

These models support:

- Multiple base learners
- Nonnegative stacking weights
- Optional normalization of weights to sum to 1 through `normalize_weights=True`
- Separate grid search for each base learner via `estimator_param_grids`
- Stable internal sample splitting for hyperparameter tuning
- Weight inspection through `get_weights()`
- Best-parameter inspection through `get_best_params()`

Default settings for the super learners are:

- `search_cv=5`
- `search_shuffle=True`
- `random_state=42`
- `normalize_weights=False`
- `tolerance=1e-10`
- `max_iter=1000`
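
The nonnegative-weight constraint can be sketched as a tiny projected-gradient nonnegative least-squares solve (an illustration of the idea only; the package's solver and objective may differ):

```python
# Solve min_w ||y - P w||^2 subject to w >= 0 by projected gradient descent.
# P holds base-learner predictions column-wise; illustrative data only.
def nnls_pg(P, y, lr=0.05, n_iter=5000):
    n, k = len(y), len(P[0])
    w = [0.0] * k
    for _ in range(n_iter):
        resid = [sum(P[i][j] * w[j] for j in range(k)) - y[i] for i in range(n)]
        grad = [2 * sum(P[i][j] * resid[i] for i in range(n)) for j in range(k)]
        w = [max(0.0, w[j] - lr * grad[j]) for j in range(k)]  # project onto w >= 0
    return w

# y is an exact 0.7 / 0.3 mix of the two base-learner prediction columns.
P = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]
y = [0.7, 0.3, 1.0, 1.7]
w = nnls_pg(P, y)
```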

For treatment prediction, the built-in step 1 implementation currently assumes a binary treatment indicator.

## Step 2: R-Loss CATE Learning

The second step learns the conditional average treatment effect `tau(X)` using the residualized outcome and treatment from step 1.

The package provides two main components for this stage.

### Single second-stage learner

Use `RLossWrapper` to fit a single sklearn-style regressor under the R-loss construction. This is the simplest way to estimate a single CATE model once `Y_tilde` and `W_tilde` are available.
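
The R-loss minimized here is `sum((Y_tilde - tau(X) * W_tilde)^2)` (Nie and Wager, 2021). For a constant effect it has a closed form, which makes a quick sanity check (illustrative values):

```python
# For a constant tau, minimizing sum((y_tilde - tau * w_tilde)^2) gives
# tau_hat = sum(y_tilde * w_tilde) / sum(w_tilde^2).
w_tilde = [0.4, -0.3, 0.5, -0.4]
y_tilde = [0.8, -0.6, 1.0, -0.8]   # constructed as exactly 2 * w_tilde

tau_hat = (sum(a * b for a, b in zip(y_tilde, w_tilde))
           / sum(b * b for b in w_tilde))
```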

### Multiple second-stage learners plus stacking

Use `RLearner` with `cate_learners={...}` when you want to fit multiple second-stage learners and combine them. The package then:

- Fits one `RLossWrapper` per learner
- Produces one CATE estimate from each learner
- Optionally combines them with `RLossStacking`

`RLossStacking` follows the positive linear-combination idea used in the R-loss stacking step. The fitted object reports:

- `a_hat`, the constant shift term
- `b_hat`, the scale of the coefficient vector
- `alpha_hat`, the nonnegative relative weights of the second-stage learners

Default settings for `RLossStacking` are:

- `lambda_reg=1.0`
- `tolerance=1e-10`
- `max_iter=1000`

In step 2, the stacking weights are constrained to be nonnegative, but they are not required to sum to 1.
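
One plausible reading of the reported parameters, for illustration only (the exact combination rule is defined by the package): the stacked prediction is the shift `a_hat` plus `b_hat` times the `alpha_hat`-weighted mix of the per-learner predictions.

```python
# Hypothetical fitted parameters and per-learner CATE predictions; the
# values and the combination rule below are illustrative assumptions.
a_hat, b_hat = 0.1, 2.0
alpha_hat = [0.75, 0.25]     # nonnegative, need not sum to 1
tau_preds = [
    [0.4, 0.8],              # learner 1 predictions
    [0.0, 0.4],              # learner 2 predictions
]
tau_stacked = [
    a_hat + b_hat * sum(a * preds[i] for a, preds in zip(alpha_hat, tau_preds))
    for i in range(2)
]
```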

## Step 3: Validation and Diagnostics

The third step validates the fitted treatment-effect model using the out-of-fold nuisance estimates and the fitted CATE predictions. The validation routines implemented here follow the discussions in [Chernozhukov et al. (2024)](#ref-chernozhukov-et-al-2024).

All validation routines are available in two ways:

- As standalone functions in `rlearner`
- As convenience methods on a fitted `RLearner` instance

### BLP test

The BLP test runs the no-intercept regression

- `Y_tilde = alpha * W_tilde + beta * W_tilde * tau_hat(X)`

and reports:

- Point estimates for `alpha` and `beta`
- HC2 standard errors
- Normal-based z statistics
- p-values
- Confidence intervals

Default setting:

- `confidence_level=0.95`
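
On noiseless synthetic data, the point estimates of that no-intercept regression can be reproduced from the 2×2 normal equations (a sketch only; the package's routine additionally computes HC2 errors and inference):

```python
# No-intercept OLS for y_t = alpha * w_t + beta * (w_t * tau): solve the
# 2x2 normal equations X'X b = X'y with X = [w_t, w_t * tau].
w_t = [0.4, -0.3, 0.5, -0.4]
tau = [1.0, 2.0, 0.5, 1.5]
alpha_true, beta_true = 0.2, 1.0
x2 = [w * t for w, t in zip(w_t, tau)]
y_t = [alpha_true * w + beta_true * z for w, z in zip(w_t, x2)]  # noiseless

s11 = sum(w * w for w in w_t)
s12 = sum(w * z for w, z in zip(w_t, x2))
s22 = sum(z * z for z in x2)
r1 = sum(w * y for w, y in zip(w_t, y_t))
r2 = sum(z * y for z, y in zip(x2, y_t))
det = s11 * s22 - s12 * s12
alpha_hat = (s22 * r1 - s12 * r2) / det
beta_hat = (s11 * r2 - s12 * r1) / det
```

A `beta` estimate significantly above zero is evidence that `tau_hat(X)` captures real effect heterogeneity.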
160+
161+
### Calibration test
162+
163+
The calibration test bins observations by predicted treatment effect and compares:
164+
165+
- The average predicted treatment effect within each bin
166+
- The doubly robust bin-level treatment effect estimate
167+
168+
It returns both:
169+
170+
- `CAL_1`, the weighted L1 calibration criterion
171+
- `CAL_2`, the weighted L2 calibration criterion
172+
173+
and also exposes the full bin-level table.
174+
175+
Default setting:
176+
177+
- `n_bins=5`
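
A sketch of how such criteria can be computed, assuming per-observation doubly robust scores `psi` are already available (illustrative values; the package's exact DR construction may differ):

```python
# Group-size-weighted calibration criteria over bins of predicted effects.
# psi holds doubly robust per-observation effect scores (assumed given).
tau_hat = [0.1, 0.2, 0.8, 0.9]
psi     = [0.0, 0.3, 0.7, 1.2]
bins    = [[0, 1], [2, 3]]      # e.g. low / high predicted-effect halves

n = len(tau_hat)
cal_1 = cal_2 = 0.0
for b in bins:
    weight = len(b) / n                              # bin weight
    pred = sum(tau_hat[i] for i in b) / len(b)       # mean prediction in bin
    dr   = sum(psi[i] for i in b) / len(b)           # DR bin-level estimate
    cal_1 += weight * abs(pred - dr)
    cal_2 += weight * (pred - dr) ** 2
```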

### Uplift test

The uplift test performs ranking-based validation using a DR uplift curve. Observations are sorted by `tau_hat(X)` from high to low, top-fraction subgroups are formed, and a DR subgroup effect is computed for each fraction.

The output includes:

- The uplift curve table `(fraction, subgroup size, theta_dr)`
- `AUUC`, the area under the uplift curve

Default setting:

- `fractions = 0.1, 0.2, ..., 1.0`
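
A sketch of the ranking mechanics, again assuming precomputed DR scores `psi`; the AUUC here is a simple average over the evaluated fractions, which may differ from the package's integration rule:

```python
# Rank by predicted effect, then estimate the DR effect within each
# top-fraction subgroup; AUUC averages the curve over the fractions.
tau_hat = [0.9, 0.1, 0.5, 0.3]
psi     = [1.0, 0.0, 0.6, 0.2]   # DR per-observation scores (assumed given)
order = sorted(range(len(tau_hat)), key=lambda i: -tau_hat[i])

fractions = [0.25, 0.5, 0.75, 1.0]
curve = []
for q in fractions:
    top = order[: max(1, round(q * len(order)))]     # top-q subgroup
    theta_dr = sum(psi[i] for i in top) / len(top)   # subgroup DR effect
    curve.append((q, len(top), theta_dr))

auuc = sum(t for _, _, t in curve) / len(curve)
```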

## Notes

- The import name is `rlearner`, even though the GitHub repository is named `Python-based-R-Learner`.
- The package currently declares support for Python `>=3.10`.

## References

- <a id="ref-nie-wager-2021"></a>Nie, X., & Wager, S. (2021). Quasi-oracle estimation of heterogeneous treatment effects. *Biometrika*, 108(2), 299-319.
- <a id="ref-chernozhukov-et-al-2024"></a>Chernozhukov, V., Hansen, C., Kallus, N., Spindler, M., & Syrgkanis, V. (2024). *Applied causal inference powered by ML and AI*. arXiv preprint arXiv:2403.02467.
