# Python-based R Learner

A Python package named `rlearner` that runs the R learner ([Nie and Wager, 2021](#ref-nie-wager-2021)) for heterogeneous treatment effect estimation and validation, with flexible choices of nuisance and CATE models.

## Set Up

Install the package via pip:

```bash
pip install "git+https://github.com/andyjiayuwang/Python-based-R-Learner.git"
```

A full example workflow is available in [`demo.ipynb`](demo.ipynb).

A minimal import example is:

```python
from rlearner import (
    CrossFittedNuisanceEstimator,
    RLearner,
    RLossStacking,
    SuperLearnerClassifier,
    SuperLearnerRegressor,
)
```

| 26 | + |
| 27 | +## Step 1: Nuisance Estimation |
| 28 | + |
| 29 | +The first step estimates the nuisance functions needed by the R learner: |
| 30 | + |
| 31 | +- `m(X) = E[Y | X]`, the outcome regression |
| 32 | +- `e(X) = E[W | X]`, the propensity score |
| 33 | + |
| 34 | +These nuisance estimates are used to build the residualized quantities |
| 35 | + |
| 36 | +- `Y_tilde = Y - m_hat(X)` |
| 37 | +- `W_tilde = W - e_hat(X)` |
| 38 | + |
| 39 | +which are then passed into the second-stage R-loss optimization. |
| 40 | + |
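As a self-contained illustration of this residualization step (using plain scikit-learn estimators on simulated data, independent of the package's own implementation), the cross-fitted residuals can be computed as:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import cross_val_predict

# Simulated data: 3 covariates, binary treatment, heterogeneous effect
rng = np.random.default_rng(0)
n = 500
X = rng.normal(size=(n, 3))
W = rng.binomial(1, 0.5, size=n)
Y = X.sum(axis=1) + X[:, 0] * W + rng.normal(size=n)

# Out-of-fold (cross-fitted) nuisance predictions
m_hat = cross_val_predict(RandomForestRegressor(random_state=0), X, Y, cv=5)
e_hat = cross_val_predict(
    RandomForestClassifier(random_state=0), X, W, cv=5, method="predict_proba"
)[:, 1]
e_hat = np.clip(e_hat, 1e-6, 1 - 1e-6)  # propensity clipping

Y_tilde = Y - m_hat  # residualized outcome
W_tilde = W - e_hat  # residualized treatment
```
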
The package provides two ways to handle step 1.

### Built-in nuisance estimation

Use `CrossFittedNuisanceEstimator` when you want the package to fit nuisance models directly. It supports:

- K-fold cross-fitting for both the outcome model and the treatment model
- Any sklearn-style regressor for the outcome model
- Any sklearn-style binary classifier with `predict_proba` for the treatment model
- Optional grid search on the full nuisance model object through `outcome_param_grid` and `treatment_param_grid`
- Full-sample refitting after cross-fitting so the fitted nuisance models can be reused for prediction

Default settings for `CrossFittedNuisanceEstimator` are:

- `n_folds=10`
- `shuffle=True`
- `random_state=42`
- `propensity_clip=1e-6`
- `stratify_treatment=True`
- `refit_full=True`
- `outcome_search_cv=5`
- `treatment_search_cv=5`
- `treatment_scoring="neg_log_loss"`

### Manual nuisance inputs

Use manual nuisance inputs when you already have trusted out-of-fold nuisance predictions from an external workflow. In that case, pass:

- `y_hat`, the out-of-fold estimate of `m(X)`
- `d_hat`, the out-of-fold estimate of `e(X)`

through `ManualNuisanceEstimator` or directly through `RLearner.fit(..., y_hat=..., d_hat=...)`.

### Constrained super learner for step 1

The package also provides constrained super learners for nuisance prediction:

- `SuperLearnerRegressor`
- `SuperLearnerClassifier`

These models support:

- Multiple base learners
- Nonnegative stacking weights
- Optional normalization of weights to sum to 1 through `normalize_weights=True`
- Separate grid search for each base learner via `estimator_param_grids`
- Stable internal sample splitting for hyperparameter tuning
- Weight inspection through `get_weights()`
- Best-parameter inspection through `get_best_params()`

Default settings for the super learners are:

- `search_cv=5`
- `search_shuffle=True`
- `random_state=42`
- `normalize_weights=False`
- `tolerance=1e-10`
- `max_iter=1000`

For treatment prediction, the built-in step 1 implementation currently assumes a binary treatment indicator.

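Conceptually, a constrained super learner combines out-of-fold base-learner predictions with nonnegative weights. A stripped-down sketch of that idea follows, using scipy's NNLS solver on simulated data; the package's own classes use their own solver (with the `tolerance` and `max_iter` options above), so this is only a conceptual illustration:

```python
import numpy as np
from scipy.optimize import nnls
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_predict
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(1)
n = 400
X = rng.normal(size=(n, 3))
y = X[:, 0] ** 2 + 0.5 * rng.normal(size=n)

# Out-of-fold predictions from each base learner
base_learners = [Ridge(), DecisionTreeRegressor(max_depth=4, random_state=0)]
P = np.column_stack([cross_val_predict(b, X, y, cv=5) for b in base_learners])

# Nonnegative stacking weights; optionally normalized to sum to 1
weights, _ = nnls(P, y)
weights_normalized = weights / weights.sum()
```
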
## Step 2: R-Loss CATE Learning

The second step learns the conditional average treatment effect `tau(X)` using the residualized outcome and treatment from step 1.

The package provides two main components for this stage.

### Single second-stage learner

Use `RLossWrapper` to fit a single sklearn-style regressor under the R-loss construction. This is the simplest way to estimate a single CATE model once `Y_tilde` and `W_tilde` are available.

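The R-loss `sum_i (Y_tilde_i - tau(X_i) * W_tilde_i)^2` can be minimized with any off-the-shelf regressor by rewriting it as a weighted regression of the pseudo-outcome `Y_tilde / W_tilde` on `X` with sample weights `W_tilde^2`, the construction of Nie and Wager (2021). A sketch of that transformation on simulated data, independent of `RLossWrapper`'s internals:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

# Simulated residualized data (in practice these come from step 1)
rng = np.random.default_rng(2)
n = 500
X = rng.normal(size=(n, 3))
W_tilde = rng.choice([-0.5, 0.5], size=n)  # W - e(X) with e(X) = 0.5
tau_true = X[:, 0]
Y_tilde = tau_true * W_tilde + 0.1 * rng.normal(size=n)

# R-loss: sum_i (Y_tilde_i - tau(X_i) * W_tilde_i)^2
#       = sum_i W_tilde_i^2 * (Y_tilde_i / W_tilde_i - tau(X_i))^2
pseudo_outcome = Y_tilde / W_tilde
weights = W_tilde ** 2
cate_model = GradientBoostingRegressor(random_state=0)
cate_model.fit(X, pseudo_outcome, sample_weight=weights)
tau_hat = cate_model.predict(X)
```
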
### Multiple second-stage learners plus stacking

Use `RLearner` with `cate_learners={...}` when you want to fit multiple second-stage learners and combine them. The package then:

- Fits one `RLossWrapper` per learner
- Produces one CATE estimate from each learner
- Optionally combines them with `RLossStacking`

`RLossStacking` follows the positive linear-combination idea used in the R-loss stacking step. The fitted object reports:

- `a_hat`, the constant shift term
- `b_hat`, the scale of the coefficient vector
- `alpha_hat`, the nonnegative relative weights of the second-stage learners

Default settings for `RLossStacking` are:

- `lambda_reg=1.0`
- `tolerance=1e-10`
- `max_iter=1000`

In step 2, the stacking weights are constrained to be nonnegative, but they are not required to sum to 1.

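The full `RLossStacking` objective also involves the shift `a_hat`, the scale `b_hat`, and regularization through `lambda_reg`; its core idea, a nonnegative combination of candidate CATE predictions chosen to minimize the R-loss, can be sketched as a plain nonnegative least squares problem (a simplified illustration, not the package's solver):

```python
import numpy as np
from scipy.optimize import nnls

rng = np.random.default_rng(3)
n = 400
W_tilde = rng.choice([-0.5, 0.5], size=n)
tau_true = rng.normal(size=n)
Y_tilde = tau_true * W_tilde + 0.1 * rng.normal(size=n)

# Candidate CATE predictions from three second-stage learners (simulated)
T = np.column_stack([
    tau_true + 0.3 * rng.normal(size=n),  # an accurate learner
    tau_true + 1.0 * rng.normal(size=n),  # a noisier learner
    rng.normal(size=n),                   # an uninformative learner
])

# Minimize sum_i (Y_tilde_i - W_tilde_i * (T_i @ alpha))^2 over alpha >= 0,
# i.e. NNLS in the design matrix W_tilde[:, None] * T
alpha_hat, _ = nnls(W_tilde[:, None] * T, Y_tilde)
```

As expected, the weights are nonnegative and need not sum to 1, and the accurate learner receives more weight than the uninformative one.
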
## Step 3: Validation and Diagnostics

The third step validates the fitted treatment-effect model using the out-of-fold nuisance estimates and the fitted CATE predictions. The validation routines implemented here follow the discussions in [Chernozhukov et al. (2024)](#ref-chernozhukov-et-al-2024).

All validation routines are available in two ways:

- As standalone functions in `rlearner`
- As convenience methods on a fitted `RLearner` instance

### BLP test

The BLP test runs the no-intercept regression

- `Y_tilde = alpha * W_tilde + beta * W_tilde * tau_hat(X)`

and reports:

- Point estimates for `alpha` and `beta`
- HC2 standard errors
- Normal-based z statistics
- p-values
- Confidence intervals

Default setting:

- `confidence_level=0.95`

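A self-contained numpy sketch of this regression with HC2 standard errors, on simulated inputs (the package's routine computes the same quantities from the fitted objects):

```python
import numpy as np

rng = np.random.default_rng(4)
n = 1000
W_tilde = rng.choice([-0.5, 0.5], size=n)
tau_hat = rng.normal(size=n)
Y_tilde = (1.0 + 0.8 * tau_hat) * W_tilde + rng.normal(size=n)

# No-intercept regression of Y_tilde on [W_tilde, W_tilde * tau_hat]
D = np.column_stack([W_tilde, W_tilde * tau_hat])
XtX_inv = np.linalg.inv(D.T @ D)
coef = XtX_inv @ D.T @ Y_tilde  # (alpha_hat, beta_hat)

# HC2 covariance: squared residuals scaled by 1 / (1 - leverage)
resid = Y_tilde - D @ coef
leverage = np.einsum("ij,jk,ik->i", D, XtX_inv, D)
meat = D.T @ (D * (resid**2 / (1 - leverage))[:, None])
se = np.sqrt(np.diag(XtX_inv @ meat @ XtX_inv))
z = coef / se  # normal-based z statistics
```

A `beta` that is significantly different from zero indicates heterogeneity that `tau_hat` detects.
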
### Calibration test

The calibration test bins observations by predicted treatment effect and compares:

- The average predicted treatment effect within each bin
- The doubly robust bin-level treatment effect estimate

It returns both:

- `CAL_1`, the weighted L1 calibration criterion
- `CAL_2`, the weighted L2 calibration criterion

and also exposes the full bin-level table.

Default setting:

- `n_bins=5`

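One standard doubly robust score built from the step 1 nuisances, for binary treatment, is `Gamma_i = tau_hat(X_i) + (W_i - e_i) / (e_i * (1 - e_i)) * (Y_i - m_i - (W_i - e_i) * tau_hat(X_i))`; the package's exact construction may differ. A sketch of the binned comparison and the two criteria under that assumption:

```python
import numpy as np

rng = np.random.default_rng(5)
n = 2000
tau_true = rng.normal(size=n)
tau_hat = tau_true + 0.3 * rng.normal(size=n)
e = np.full(n, 0.5)
W = rng.binomial(1, e)
Y = tau_true * W + rng.normal(size=n)
m = 0.5 * tau_true  # E[Y | X] when e(X) = 0.5

# Doubly robust score (one standard construction; assumed here)
gamma = tau_hat + (W - e) / (e * (1 - e)) * (Y - m - (W - e) * tau_hat)

n_bins = 5
edges = np.quantile(tau_hat, np.linspace(0, 1, n_bins + 1))
bins = np.clip(np.searchsorted(edges, tau_hat, side="right") - 1, 0, n_bins - 1)

cal_1 = cal_2 = 0.0
for k in range(n_bins):
    mask = bins == k
    pred_k = tau_hat[mask].mean()  # average predicted effect in the bin
    dr_k = gamma[mask].mean()      # DR estimate of the bin-level effect
    w_k = mask.mean()              # bin weight n_k / n
    cal_1 += w_k * abs(pred_k - dr_k)
    cal_2 += w_k * (pred_k - dr_k) ** 2
```
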
### Uplift test

The uplift test performs ranking-based validation using a DR uplift curve. Observations are sorted by `tau_hat(X)` from high to low, top-fraction subgroups are formed, and a DR subgroup effect is computed for each fraction.

The output includes:

- The uplift curve table `(fraction, subgroup size, theta_dr)`
- `AUUC`, the area under the uplift curve

Default setting:

- `fractions = 0.1, 0.2, ..., 1.0`

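A sketch of the uplift curve and AUUC, again using an assumed doubly robust score built from the step 1 nuisances (the package's construction may differ):

```python
import numpy as np

rng = np.random.default_rng(6)
n = 2000
tau_true = rng.normal(size=n)
tau_hat = tau_true + 0.3 * rng.normal(size=n)
e = np.full(n, 0.5)
W = rng.binomial(1, e)
Y = tau_true * W + rng.normal(size=n)
m = 0.5 * tau_true  # E[Y | X] when e(X) = 0.5

# Doubly robust score (one standard construction; assumed here)
gamma = tau_hat + (W - e) / (e * (1 - e)) * (Y - m - (W - e) * tau_hat)

order = np.argsort(-tau_hat)  # sort by predicted effect, high to low
fractions = np.arange(1, 11) / 10
theta_dr = np.array([gamma[order[: int(round(q * n))]].mean() for q in fractions])

# Trapezoidal area under the uplift curve
auuc = float(np.sum((theta_dr[1:] + theta_dr[:-1]) / 2 * np.diff(fractions)))
```

For an informative `tau_hat`, the DR effect in the top-ranked subgroups exceeds the full-sample effect, so the curve decreases toward the overall average at `fraction = 1.0`.
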
## Notes

- The import name is `rlearner`, even though the GitHub repository is named `Python-based-R-Learner`.
- The package currently declares support for Python `>=3.10`.

## References

- <a id="ref-nie-wager-2021"></a>Nie, X., & Wager, S. (2021). Quasi-oracle estimation of heterogeneous treatment effects. *Biometrika*, 108(2), 299-319.
- <a id="ref-chernozhukov-et-al-2024"></a>Chernozhukov, V., Hansen, C., Kallus, N., Spindler, M., & Syrgkanis, V. (2024). *Applied causal inference powered by ML and AI*. arXiv preprint arXiv:2403.02467.