diff --git a/README.md b/README.md index 72a7810..d1d30b9 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,9 @@ [![Tests](https://github.com/gcskoenig/fippy/actions/workflows/python-package.yml/badge.svg)](https://github.com/gcskoenig/fippy/actions/workflows/python-package.yml) [![PyPI](https://img.shields.io/pypi/v/fippy)](https://pypi.org/project/fippy/) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) -[![Python](https://img.shields.io/pypi/pyversions/fippy)](https://pypi.org/project/fippy/) +[![Python](https://img.shields.io/badge/python-3.11%20--%203.13-blue)](https://pypi.org/project/fippy/) -A Python package for model-agnostic feature importance with statistical inference. Fippy implements a unified framework where feature importance methods are composed from three orthogonal axes: +A Python package for model-agnostic feature importance with statistical inference. Fippy implements a unified framework where feature importance methods are composed from three axes: | Axis | Options | Description | |---|---|---| diff --git a/src/fippy/explainer.py b/src/fippy/explainer.py index 74cde76..6e0334f 100644 --- a/src/fippy/explainer.py +++ b/src/fippy/explainer.py @@ -67,6 +67,12 @@ def loo( y = np.asarray(y) groups = self._resolve_features(features) + if distribution is not None: + sampler = self._get_sampler(distribution) + if hasattr(sampler, 'check_compatibility'): + requires_multi = any(len(g.columns) > 1 for g in groups) + sampler.check_compatibility(requires_multivariate=requires_multi) + n_obs = len(X) scores = np.empty((1, n_repeats, n_obs, len(groups))) baseline_loss = self.loss(y, self.predict(X)) @@ -121,6 +127,11 @@ def shapley( y = np.asarray(y) groups = self._resolve_features(features) + if distribution is not None: + sampler = self._get_sampler(distribution) + if hasattr(sampler, 'check_compatibility'): + sampler.check_compatibility(requires_multivariate=True) + n_obs = len(X) scores = np.empty((1, n_repeats, n_obs, len(groups))) baseline_loss = self.loss(y, self.predict(X)) diff --git a/src/fippy/samplers/__init__.py b/src/fippy/samplers/__init__.py index 5d0cea9..2cde152 100644 --- a/src/fippy/samplers/__init__.py +++ b/src/fippy/samplers/__init__.py @@ -1,2 +1,5 @@ +from fippy.samplers.base import Sampler +from fippy.samplers.dispatch import TypeDispatchSampler +from fippy.samplers.forest import RFClassificationSampler, RFResidualSampler from fippy.samplers.gaussian import GaussianSampler from fippy.samplers.permutation import PermutationSampler diff --git a/src/fippy/samplers/base.py b/src/fippy/samplers/base.py new file mode 100644 index 0000000..d804659 --- /dev/null +++ b/src/fippy/samplers/base.py @@ -0,0 +1,97 @@ +"""Sampler ABC: base class for all conditional distribution samplers.""" +from abc import ABC, abstractmethod + +import numpy as np +import pandas as pd + + +class Sampler(ABC): + """Base class for all conditional distribution samplers. + + Subclasses declare three capability flags as class attributes: + multivariate: Can sample len(J) > 1 natively. + supports_categorical_target: Can produce samples for categorical features in J. + supports_categorical_context: Can condition on categorical features in S. + + Categorical columns are detected from DataFrame dtypes (CategoricalDtype, object). + """ + + multivariate: bool = False + supports_categorical_target: bool = False + supports_categorical_context: bool = False + + def __init__(self, X_train: pd.DataFrame): + self.X_train = X_train + self.feature_names = list(X_train.columns) + self._categorical_cols = self._detect_categorical(X_train) + + @staticmethod + def _detect_categorical(X: pd.DataFrame) -> set[str]: + """Detect categorical columns from dtype.""" + return { + col + for col in X.columns + if isinstance(X[col].dtype, pd.CategoricalDtype) + or pd.api.types.is_object_dtype(X[col]) + } + + def check_compatibility(self, requires_multivariate: bool = False): + """Pre-flight validation: check sampler compatibility with the dataset. + + Called by the Explainer at the start of loo()/shapley(), before the + feature loop. The Explainer determines what properties are required + and passes them in. + + Args: + requires_multivariate: Whether the computation requires sampling + len(J) > 1 (e.g., Shapley always, LOO with multi-column groups). + + Raises: + ValueError: Comprehensive error listing all incompatibilities. + """ + errors = [] + + cat_cols = self._categorical_cols + if cat_cols: + if not self.supports_categorical_target: + errors.append( + f"Categorical columns {cat_cols} will appear as targets, " + f"but {type(self).__name__} does not support categorical " + f"targets. Use a sampler that supports categorical targets " + f"(e.g., PermutationSampler, ARFSampler) or wrap with " + f"TypeDispatchSampler." + ) + if not self.supports_categorical_context: + errors.append( + f"Categorical columns {cat_cols} will appear in conditioning " + f"sets, but {type(self).__name__} does not support categorical " + f"context features." + ) + + if requires_multivariate and not self.multivariate: + errors.append( + f"{type(self).__name__} is univariate (multivariate=False) " + f"and cannot sample multiple features jointly. " + f"Wrap it in SequentialSampler." + ) + + if errors: + raise ValueError( + f"Sampler {type(self).__name__} is incompatible with the " + f"requested computation:\n" + + "\n".join(f" - {e}" for e in errors) + ) + + def fit(self, J, S): + """Fit P(X_J | X_S).""" + self._fit(J, S) + + def sample(self, X, J, S, n_samples=1): + """Sample from P(X_J | X_S). Shape: (n_obs, n_samples, len(J)).""" + return self._sample(X, J, S, n_samples) + + @abstractmethod + def _fit(self, J, S): ... + + @abstractmethod + def _sample(self, X, J, S, n_samples) -> np.ndarray: ... diff --git a/src/fippy/samplers/dispatch.py b/src/fippy/samplers/dispatch.py new file mode 100644 index 0000000..978da13 --- /dev/null +++ b/src/fippy/samplers/dispatch.py @@ -0,0 +1,48 @@ +"""TypeDispatchSampler: routes to continuous or categorical sub-sampler.""" +import pandas as pd + +from fippy.samplers.base import Sampler + + +class TypeDispatchSampler(Sampler): + """Routes to continuous or categorical sub-sampler based on target dtype. + + Univariate only (len(J) = 1). For multivariate use, wrap in SequentialSampler. + + Example: + sampler = TypeDispatchSampler( + X_train, + continuous_sampler=RFResidualSampler(X_train), + categorical_sampler=RFClassificationSampler(X_train), + ) + """ + + multivariate = False + supports_categorical_target = True + + def __init__( + self, + X_train: pd.DataFrame, + continuous_sampler: Sampler, + categorical_sampler: Sampler, + ): + super().__init__(X_train) + self._continuous = continuous_sampler + self._categorical = categorical_sampler + + @property + def supports_categorical_context(self): + return (self._continuous.supports_categorical_context + and self._categorical.supports_categorical_context) + + def _select_sampler(self, J): + assert len(J) == 1 + if J[0] in self._categorical_cols: + return self._categorical + return self._continuous + + def _fit(self, J, S): + self._select_sampler(J)._fit(J, S) + + def _sample(self, X, J, S, n_samples): + return self._select_sampler(J)._sample(X, J, S, n_samples) diff --git a/src/fippy/samplers/forest.py b/src/fippy/samplers/forest.py new file mode 100644 index 0000000..48464ac --- /dev/null +++ b/src/fippy/samplers/forest.py @@ -0,0 +1,187 @@ +"""Forest-based univariate conditional samplers.""" +import numpy as np +import pandas as pd +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.model_selection import RandomizedSearchCV + +from fippy.samplers.base import Sampler + +_RF_PARAM_GRID = { + "n_estimators": [100, 200, 500], + "max_depth": [5, 10, 20, None], + "min_samples_leaf": [1, 2, 5, 10], + "max_features": ["sqrt", "log2", 0.5, 1.0], +} + + +class _ForestSamplerBase(Sampler): + """Shared logic for RF-based univariate samplers.""" + + def __init__(self, X_train: pd.DataFrame, *, tune: bool = True, + n_iter: int = 20, cv: int = 5, random_state=None): + super().__init__(X_train) + self.tune = tune + self.n_iter = n_iter + self.cv = cv + self.random_state = random_state + self._cache: dict[tuple, tuple] = {} + # Build encoding maps for categorical context features. + self._cat_maps: dict[str, dict] = { + col: {v: i for i, v in enumerate(X_train[col].unique())} + for col in self._categorical_cols + } + + def _context_array(self, X, S): + """Convert context columns to numeric numpy array for sklearn.""" + if not S: + return np.zeros((len(X), 0)) + cols = sorted(S) + arrays = [] + for col in cols: + if col in self._categorical_cols: + mapping = self._cat_maps[col] + codes = np.array([mapping.get(v, -1) for v in X[col]], dtype=float) + arrays.append(codes) + else: + arrays.append(X[col].values.astype(float)) + return np.column_stack(arrays) + + def _tune_and_fit(self, estimator, X_context, y, scoring=None): + """Fit estimator, optionally with hyperparameter tuning.""" + if self.tune and len(y) >= self.cv: + search = RandomizedSearchCV( + estimator, + _RF_PARAM_GRID, + n_iter=self.n_iter, + cv=self.cv, + scoring=scoring, + random_state=self.random_state, + n_jobs=-1, + ) + search.fit(X_context, y) + return search.best_estimator_ + estimator.fit(X_context, y) + return estimator + + @staticmethod + def _key(J, S): + return (tuple(sorted(J)), tuple(sorted(S))) + + +class RFResidualSampler(_ForestSamplerBase): + """Univariate conditional sampler for continuous targets. + + Fits a Random Forest regressor to estimate E[X_j | X_S] and stores training + residuals. At sample time, predicts the conditional expectation and adds a + randomly resampled training residual. + + Assumes homoscedastic residuals: the residual distribution does not depend + on X_S. + """ + + multivariate = False + supports_categorical_target = False + supports_categorical_context = True + + def _fit(self, J, S): + key = self._key(J, S) + if key in self._cache: + return + + assert len(J) == 1 + y_train = self.X_train[J[0]].values.astype(float) + + if not S: + mean = y_train.mean() + self._cache[key] = ("marginal", mean, y_train - mean) + return + + X_ctx = self._context_array(self.X_train, S) + model = self._tune_and_fit( + RandomForestRegressor(random_state=self.random_state), + X_ctx, y_train, + ) + residuals = y_train - model.predict(X_ctx) + self._cache[key] = ("fitted", model, residuals) + + def _sample(self, X, J, S, n_samples): + key = self._key(J, S) + if key not in self._cache: + self._fit(J, S) + + entry = self._cache[key] + n_obs = len(X) + result = np.empty((n_obs, n_samples, 1)) + + if entry[0] == "marginal": + mean, residuals = entry[1], entry[2] + for k in range(n_samples): + idx = np.random.randint(0, len(residuals), size=n_obs) + result[:, k, 0] = mean + residuals[idx] + else: + model, residuals = entry[1], entry[2] + preds = model.predict(self._context_array(X, S)) + for k in range(n_samples): + idx = np.random.randint(0, len(residuals), size=n_obs) + result[:, k, 0] = preds + residuals[idx] + + return result + + +class RFClassificationSampler(_ForestSamplerBase): + """Univariate conditional sampler for categorical targets. + + Fits a Random Forest classifier to estimate P(X_j | X_S). At sample time, + predicts class probabilities and samples from them. + """ + + multivariate = False + supports_categorical_target = True + supports_categorical_context = True + + def _fit(self, J, S): + key = self._key(J, S) + if key in self._cache: + return + + assert len(J) == 1 + y_train = self.X_train[J[0]].values + + if not S: + classes, counts = np.unique(y_train, return_counts=True) + self._cache[key] = ("marginal", classes, counts / counts.sum()) + return + + X_ctx = self._context_array(self.X_train, S) + model = self._tune_and_fit( + RandomForestClassifier(random_state=self.random_state), + X_ctx, y_train, + scoring="neg_log_loss", + ) + self._cache[key] = ("fitted", model) + + def _sample(self, X, J, S, n_samples): + key = self._key(J, S) + if key not in self._cache: + self._fit(J, S) + + entry = self._cache[key] + n_obs = len(X) + result = np.empty((n_obs, n_samples, 1), dtype=object) + + if entry[0] == "marginal": + classes, probs = entry[1], entry[2] + for k in range(n_samples): + idx = np.random.choice(len(classes), size=n_obs, p=probs) + result[:, k, 0] = classes[idx] + else: + model = entry[1] + probs = model.predict_proba(self._context_array(X, S)) + classes = model.classes_ + cumprobs = np.cumsum(probs, axis=1) + for k in range(n_samples): + u = np.random.rand(n_obs, 1) + idx = (u < cumprobs).argmax(axis=1) + result[:, k, 0] = classes[idx] + + return result diff --git a/src/fippy/samplers/gaussian.py b/src/fippy/samplers/gaussian.py index 4f23e56..cd2a863 100644 --- a/src/fippy/samplers/gaussian.py +++ b/src/fippy/samplers/gaussian.py @@ -1,10 +1,12 @@ """GaussianSampler: conditional sampling via Gaussian conditioning formulas.""" import numpy as np import pandas as pd + from fippy.backend.estimators import GaussianConditionalEstimator +from fippy.samplers.base import Sampler -class GaussianSampler: +class GaussianSampler(Sampler): """Second-order Gaussian conditional sampler. Computes P(X_J | X_S) using standard multivariate normal conditioning: @@ -15,13 +17,14 @@ class GaussianSampler: """ multivariate = True + supports_categorical_target = False + supports_categorical_context = False def __init__(self, X_train: pd.DataFrame): - self.X_train = X_train - self.feature_names = list(X_train.columns) + super().__init__(X_train) self._cache: dict[tuple, object] = {} - def fit(self, J, S): + def _fit(self, J, S): """Fit P(X_J | X_S).""" key = self._key(J, S) if key in self._cache: @@ -41,7 +44,7 @@ def fit(self, J, S): ) self._cache[key] = (estimator, J_only) - def sample(self, X, J, S, n_samples=1): + def _sample(self, X, J, S, n_samples=1): """Sample from P(X_J | X_S). Returns: np.ndarray of shape (n_obs, n_samples, len(J)). @@ -49,7 +52,7 @@ def sample(self, X, J, S, n_samples=1): J, S = list(J), list(S) key = self._key(J, S) if key not in self._cache: - self.fit(J, S) + self._fit(J, S) entry = self._cache[key] n_obs = len(X) diff --git a/src/fippy/samplers/permutation.py b/src/fippy/samplers/permutation.py index 65cc65f..680b747 100644 --- a/src/fippy/samplers/permutation.py +++ b/src/fippy/samplers/permutation.py @@ -2,8 +2,10 @@ import numpy as np import pandas as pd +from fippy.samplers.base import Sampler -class PermutationSampler: + +class PermutationSampler(Sampler): """Samples from the marginal distribution P(X_J) by randomly drawing rows from training data. @@ -11,16 +13,14 @@ class PermutationSampler: """ multivariate = True + supports_categorical_target = True + supports_categorical_context = True - def __init__(self, X_train: pd.DataFrame): - self.X_train = X_train - self.feature_names = list(X_train.columns) - - def fit(self, J, S): + def _fit(self, J, S): """No fitting needed for permutation sampling.""" pass - def sample(self, X, J, S, n_samples=1): + def _sample(self, X, J, S, n_samples=1): """Sample by drawing from training data (ignores S). Args: diff --git a/tests/test_suite3_samplers_and_basic.py b/tests/test_suite3_samplers_and_basic.py index 06ca08c..2551eb2 100644 --- a/tests/test_suite3_samplers_and_basic.py +++ b/tests/test_suite3_samplers_and_basic.py @@ -6,6 +6,10 @@ 3.3 Two-sample distribution smoke tests 3.4 ExplanationResult: array logic 3.5 Validation error tests + 3.6 Categorical compatibility tests + 3.7 RFResidualSampler tests + 3.8 RFClassificationSampler tests + 3.9 TypeDispatchSampler + LOO integration """ import numpy as np import pandas as pd @@ -15,7 +19,13 @@ from fippy import Explainer, ExplanationResult from fippy.losses import squared_error -from fippy.samplers import GaussianSampler, PermutationSampler +from fippy.samplers import ( + GaussianSampler, + PermutationSampler, + RFClassificationSampler, + RFResidualSampler, + TypeDispatchSampler, +) # =========================================================================== @@ -474,3 +484,440 @@ def test_invalid_restriction_raises(self, basic_explainer): explainer, X, y = basic_explainer with pytest.raises(ValueError, match="Invalid restriction"): explainer.loo(X, y, "invalid_restriction", distribution="marginal") + + +# =========================================================================== +# 3.6 Categorical compatibility tests +# =========================================================================== + +class TestCategoricalCompatibility: + """Test that samplers fail fast on categorical data when unsupported.""" + + @pytest.fixture + def categorical_data(self): + rng = np.random.RandomState(42) + n = 200 + X = pd.DataFrame({ + "x1": rng.randn(n), + "x2": rng.randn(n), + "x3": pd.Categorical(rng.choice(["a", "b", "c"], n)), + }) + y = X["x1"] + rng.randn(n) * 0.1 + return X, y + + def test_gaussian_sampler_rejects_categorical_target(self, categorical_data): + X, y = categorical_data + sampler = GaussianSampler(X) + predict = lambda x: x["x1"].values + explainer = Explainer(predict, X, loss=squared_error, sampler=sampler) + with pytest.raises(ValueError, match="categorical targets"): + explainer.cfi(X, y) + + def test_gaussian_sampler_rejects_categorical_context(self, categorical_data): + X, y = categorical_data + sampler = GaussianSampler(X) + predict = lambda x: x["x1"].values + explainer = Explainer(predict, X, loss=squared_error, sampler=sampler) + with pytest.raises(ValueError, match="categorical context"): + explainer.cfi(X, y) + + def test_gaussian_sampler_rejects_object_dtype(self): + rng = np.random.RandomState(42) + n = 200 + X = pd.DataFrame({ + "x1": rng.randn(n), + "x2": rng.choice(["a", "b"], n), # object dtype + }) + y = X["x1"] + rng.randn(n) * 0.1 + sampler = GaussianSampler(X) + predict = lambda x: x["x1"].values + explainer = Explainer(predict, X, loss=squared_error, sampler=sampler) + with pytest.raises(ValueError, match="categorical targets"): + explainer.cfi(X, y) + + def test_permutation_sampler_accepts_categorical(self, categorical_data): + X, _ = categorical_data + sampler = PermutationSampler(X) + # PermutationSampler supports categorical in both roles — check passes + sampler.check_compatibility(requires_multivariate=True) + + def test_check_compatibility_reports_both_errors(self, categorical_data): + """When both target and context are unsupported, both errors are reported.""" + X, _ = categorical_data + sampler = GaussianSampler(X) + with pytest.raises(ValueError, match="(?s)categorical targets.*categorical context"): + sampler.check_compatibility() + + def test_no_error_on_numeric_only_data(self): + rng = np.random.RandomState(42) + X = pd.DataFrame({"x1": rng.randn(100), "x2": rng.randn(100)}) + sampler = GaussianSampler(X) + # Should not raise + sampler.check_compatibility() + + def test_integer_columns_not_detected_as_categorical(self): + rng = np.random.RandomState(42) + X = pd.DataFrame({"x1": rng.randn(100), "x2": rng.randint(0, 3, 100)}) + sampler = GaussianSampler(X) + assert sampler._categorical_cols == set() + + def test_astype_category_detected(self): + rng = np.random.RandomState(42) + X = pd.DataFrame({"x1": rng.randn(100), "x2": rng.randint(0, 3, 100)}) + X["x2"] = X["x2"].astype("category") + sampler = GaussianSampler(X) + assert sampler._categorical_cols == {"x2"} + + def test_gaussian_rejects_categorical_in_shapley(self, categorical_data): + X, y = categorical_data + sampler = GaussianSampler(X) + predict = lambda x: x["x1"].values + explainer = Explainer(predict, X, loss=squared_error, sampler=sampler) + with pytest.raises(ValueError, match="categorical targets"): + explainer.sage(X, y, distribution="conditional", n_samples=2, + n_permutations=2) + + def test_marginal_without_sampler_skips_check(self): + """PFI with no sampler uses internal _PermutationSampler — no compatibility check.""" + rng = np.random.RandomState(42) + n = 50 + X = pd.DataFrame({"x1": rng.randn(n), "x2": rng.randn(n)}) + y = X["x1"].values + predict = lambda x: x["x1"].values + explainer = Explainer(predict, X, loss=squared_error) + # No sampler provided → internal fallback, no check_compatibility call + result = explainer.pfi(X, y) + assert result.scores.shape[2] == n + + +# =========================================================================== +# 3.7 RFResidualSampler: continuous conditional sampling +# =========================================================================== + +class TestRFResidualSampler: + + @pytest.fixture(scope="class") + def linear_data(self): + """x2 = 2*x1 + noise, so E[x2|x1] ≈ 2*x1.""" + rng = np.random.RandomState(42) + n = 500 + x1 = rng.randn(n) + noise = rng.randn(n) * 0.3 + x2 = 2.0 * x1 + noise + return pd.DataFrame({"x1": x1, "x2": x2}) + + def test_output_shape(self, linear_data): + X = linear_data + sampler = RFResidualSampler(X, tune=False, random_state=42) + result = sampler.sample(X.head(10), J=["x2"], S=["x1"], n_samples=5) + assert result.shape == (10, 5, 1) + + def test_conditional_mean_recovered(self, linear_data): + """Conditional mean E[x2|x1] ≈ 2*x1 should be recovered.""" + X = linear_data + sampler = RFResidualSampler(X, tune=False, random_state=42) + # Condition on x1=1.0 → expect mean ≈ 2.0 + X_test = pd.DataFrame({"x1": [1.0], "x2": [0.0]}) + samples = sampler.sample(X_test, J=["x2"], S=["x1"], n_samples=5000) + sampled_mean = samples[0, :, 0].mean() + np.testing.assert_allclose(sampled_mean, 2.0, atol=0.3) + + def test_residual_variance_preserved(self, linear_data): + """Variance of samples should reflect residual variance ≈ 0.3^2.""" + X = linear_data + sampler = RFResidualSampler(X, tune=False, random_state=42) + X_test = pd.DataFrame({"x1": [0.0], "x2": [0.0]}) + samples = sampler.sample(X_test, J=["x2"], S=["x1"], n_samples=5000) + sampled_std = samples[0, :, 0].std() + # Residual std ≈ 0.3, but RF training residuals are smaller (overfitting) + # Just check it's in a reasonable range + assert 0.05 < sampled_std < 0.6 + + def test_empty_S_gives_marginal(self, linear_data): + """With S=[], samples should match the marginal of x2.""" + X = linear_data + sampler = RFResidualSampler(X, tune=False, random_state=42) + X_test = pd.DataFrame({"x1": [0.0], "x2": [0.0]}) + samples = sampler.sample(X_test, J=["x2"], S=[], n_samples=5000) + sampled_mean = samples[0, :, 0].mean() + np.testing.assert_allclose(sampled_mean, X["x2"].mean(), atol=0.2) + + def test_categorical_context(self): + """RF can condition on categorical features.""" + rng = np.random.RandomState(42) + n = 300 + cat = rng.choice(["a", "b"], n) + # x1 depends on category: a → mean 0, b → mean 5 + x1 = np.where(cat == "a", rng.randn(n), rng.randn(n) + 5.0) + X = pd.DataFrame({"x1": x1, "cat": pd.Categorical(cat)}) + + sampler = RFResidualSampler(X, tune=False, random_state=42) + X_a = pd.DataFrame({"x1": [0.0], "cat": pd.Categorical(["a"], categories=["a", "b"])}) + X_b = pd.DataFrame({"x1": [0.0], "cat": pd.Categorical(["b"], categories=["a", "b"])}) + + samples_a = sampler.sample(X_a, J=["x1"], S=["cat"], n_samples=3000) + samples_b = sampler.sample(X_b, J=["x1"], S=["cat"], n_samples=3000) + + mean_a = samples_a[0, :, 0].mean() + mean_b = samples_b[0, :, 0].mean() + # Means should be clearly separated + assert mean_b - mean_a > 2.0 + + def test_with_tuning(self, linear_data): + """Tuned sampler should produce reasonable samples.""" + X = linear_data + sampler = RFResidualSampler(X, tune=True, n_iter=5, cv=3, random_state=42) + X_test = pd.DataFrame({"x1": [1.0], "x2": [0.0]}) + samples = sampler.sample(X_test, J=["x2"], S=["x1"], n_samples=3000) + sampled_mean = samples[0, :, 0].mean() + np.testing.assert_allclose(sampled_mean, 2.0, atol=0.4) + + def test_caching(self, linear_data): + """Second call with same (J, S) should use cache.""" + X = linear_data + sampler = RFResidualSampler(X, tune=False, random_state=42) + sampler.fit(["x2"], ["x1"]) + assert len(sampler._cache) == 1 + sampler.fit(["x2"], ["x1"]) # should hit cache + assert len(sampler._cache) == 1 + sampler.fit(["x1"], ["x2"]) # different (J, S) + assert len(sampler._cache) == 2 + + def test_capability_flags(self): + rng = np.random.RandomState(42) + X = pd.DataFrame({"x1": rng.randn(50), "x2": rng.randn(50)}) + sampler = RFResidualSampler(X, tune=False) + assert sampler.multivariate is False + assert sampler.supports_categorical_target is False + assert sampler.supports_categorical_context is True + + def test_rejects_categorical_target(self): + rng = np.random.RandomState(42) + n = 100 + X = pd.DataFrame({ + "x1": rng.randn(n), + "cat": pd.Categorical(rng.choice(["a", "b"], n)), + }) + sampler = RFResidualSampler(X, tune=False) + with pytest.raises(ValueError, match="categorical targets"): + sampler.check_compatibility() + + +# =========================================================================== +# 3.8 RFClassificationSampler: categorical conditional sampling +# =========================================================================== + +class TestRFClassificationSampler: + + @pytest.fixture(scope="class") + def categorical_data(self): + """cat depends on x1: x1 > 0 → mostly 'a', x1 < 0 → mostly 'b'.""" + rng = np.random.RandomState(42) + n = 500 + x1 = rng.randn(n) + probs_a = 1 / (1 + np.exp(-3 * x1)) # sigmoid + cat = np.where(rng.rand(n) < probs_a, "a", "b") + return pd.DataFrame({"x1": x1, "cat": pd.Categorical(cat)}) + + def test_output_shape(self, categorical_data): + X = categorical_data + sampler = RFClassificationSampler(X, tune=False, random_state=42) + result = sampler.sample(X.head(10), J=["cat"], S=["x1"], n_samples=5) + assert result.shape == (10, 5, 1) + + def test_samples_are_valid_categories(self, categorical_data): + X = categorical_data + sampler = RFClassificationSampler(X, tune=False, random_state=42) + result = sampler.sample(X.head(20), J=["cat"], S=["x1"], n_samples=10) + unique_vals = set(result.ravel()) + assert unique_vals <= {"a", "b"} + + def test_conditional_probabilities(self, categorical_data): + """For x1=2 (large positive), P(cat='a'|x1) should be high.""" + X = categorical_data + sampler = RFClassificationSampler(X, tune=False, random_state=42) + X_pos = pd.DataFrame({"x1": [2.0], "cat": pd.Categorical(["a"], categories=["a", "b"])}) + X_neg = pd.DataFrame({"x1": [-2.0], "cat": pd.Categorical(["a"], categories=["a", "b"])}) + + samples_pos = sampler.sample(X_pos, J=["cat"], S=["x1"], n_samples=3000) + samples_neg = sampler.sample(X_neg, J=["cat"], S=["x1"], n_samples=3000) + + frac_a_pos = (samples_pos[0, :, 0] == "a").mean() + frac_a_neg = (samples_neg[0, :, 0] == "a").mean() + + # x1=2 should give much more 'a' than x1=-2 + assert frac_a_pos > 0.7 + assert frac_a_neg < 0.3 + + def test_empty_S_gives_marginal(self, categorical_data): + """With S=[], samples should match marginal class frequencies.""" + X = categorical_data + sampler = RFClassificationSampler(X, tune=False, random_state=42) + X_test = pd.DataFrame({"x1": [0.0], "cat": pd.Categorical(["a"], categories=["a", "b"])}) + samples = sampler.sample(X_test, J=["cat"], S=[], n_samples=5000) + frac_a = (samples[0, :, 0] == "a").mean() + expected_frac_a = (X["cat"] == "a").mean() + np.testing.assert_allclose(frac_a, expected_frac_a, atol=0.05) + + def test_with_tuning(self, categorical_data): + """Tuned sampler should produce valid samples.""" + X = categorical_data + sampler = RFClassificationSampler(X, tune=True, n_iter=5, cv=3, random_state=42) + result = sampler.sample(X.head(5), J=["cat"], S=["x1"], n_samples=10) + assert result.shape == (5, 10, 1) + assert set(result.ravel()) <= {"a", "b"} + + def test_multiclass(self): + """Works with more than 2 classes.""" + rng = np.random.RandomState(42) + n = 300 + x1 = rng.randn(n) + # 3 classes with different means + cat = np.where(x1 > 0.5, "high", np.where(x1 < -0.5, "low", "mid")) + X = pd.DataFrame({"x1": x1, "cat": pd.Categorical(cat)}) + + sampler = RFClassificationSampler(X, tune=False, random_state=42) + X_high = pd.DataFrame({ + "x1": [2.0], + "cat": pd.Categorical(["high"], categories=["high", "low", "mid"]), + }) + samples = sampler.sample(X_high, J=["cat"], S=["x1"], n_samples=3000) + frac_high = (samples[0, :, 0] == "high").mean() + assert frac_high > 0.5 + + def test_categorical_context(self): + """Classifier can condition on categorical context features.""" + rng = np.random.RandomState(42) + n = 300 + ctx = rng.choice(["x", "y"], n) + # target depends on context: ctx='x' → mostly 'a', ctx='y' → mostly 'b' + target = np.where( + ctx == "x", + np.where(rng.rand(n) < 0.9, "a", "b"), + np.where(rng.rand(n) < 0.1, "a", "b"), + ) + X = pd.DataFrame({ + "ctx": pd.Categorical(ctx), + "target": pd.Categorical(target), + }) + + sampler = RFClassificationSampler(X, tune=False, random_state=42) + X_x = pd.DataFrame({ + "ctx": pd.Categorical(["x"], categories=["x", "y"]), + "target": pd.Categorical(["a"], categories=["a", "b"]), + }) + samples = sampler.sample(X_x, J=["target"], S=["ctx"], n_samples=3000) + frac_a = (samples[0, :, 0] == "a").mean() + assert frac_a > 0.7 + + def test_capability_flags(self): + rng = np.random.RandomState(42) + X = pd.DataFrame({ + "x1": rng.randn(50), + "cat": pd.Categorical(rng.choice(["a", "b"], 50)), + }) + sampler = RFClassificationSampler(X, tune=False) + assert sampler.multivariate is False + assert sampler.supports_categorical_target is True + assert sampler.supports_categorical_context is True + + def test_passes_compatibility_with_categoricals(self): + rng = np.random.RandomState(42) + X = pd.DataFrame({ + "x1": rng.randn(50), + "cat": pd.Categorical(rng.choice(["a", "b"], 50)), + }) + sampler = RFClassificationSampler(X, tune=False) + # Should not raise — supports both categorical roles + sampler.check_compatibility() + + +# =========================================================================== +# 3.9 TypeDispatchSampler + LOO integration +# =========================================================================== + +class TestTypeDispatchSamplerLOO: + """End-to-end test: TypeDispatchSampler wrapping RF samplers for LOO/CFI + on a mixed continuous + categorical dataset.""" + + @pytest.fixture(scope="class") + def mixed_data(self): + rng = np.random.RandomState(42) + n = 300 + x1 = rng.randn(n) + cat = pd.Categorical(rng.choice(["a", "b"], n)) + # y = x1 + effect_of_cat + noise + y = x1 + np.where(cat == "a", 2.0, -2.0) + rng.randn(n) * 0.1 + X = pd.DataFrame({"x1": x1, "cat": cat}) + return X, y + + @pytest.fixture(scope="class") + def dispatch_sampler(self, mixed_data): + X, _ = mixed_data + return TypeDispatchSampler( + X, + continuous_sampler=RFResidualSampler(X, tune=False, random_state=42), + categorical_sampler=RFClassificationSampler(X, tune=False, random_state=42), + ) + + def test_cfi_runs_and_returns_correct_shape(self, mixed_data, dispatch_sampler): + X, y = mixed_data + X_test, y_test = X.head(30), y[:30] + + def predict(X_df): + x1_val = X_df["x1"].values.astype(float) + cat_val = np.where(X_df["cat"] == "a", 2.0, -2.0) + return x1_val + cat_val + + explainer = Explainer(predict, X, loss=squared_error, sampler=dispatch_sampler) + result = explainer.cfi(X_test, y_test, n_repeats=3) + + assert result.scores.shape == (1, 3, 30, 2) + assert result.feature_names == ["x1", "cat"] + + def test_both_features_have_positive_importance(self, mixed_data, dispatch_sampler): + """Both x1 and cat contribute to y, so both should have positive importance.""" + X, y = mixed_data + + def predict(X_df): + x1_val = X_df["x1"].values.astype(float) + cat_val = np.where(X_df["cat"] == "a", 2.0, -2.0) + return x1_val + cat_val + + explainer = Explainer(predict, X, loss=squared_error, sampler=dispatch_sampler) + result = explainer.cfi(X, y, n_repeats=5) + imp = result.importance() + + assert imp.loc["x1", "importance"] > 0 + assert imp.loc["cat", "importance"] > 0 + + def test_dispatches_to_correct_sub_sampler(self, mixed_data, dispatch_sampler): + """Continuous target → RFResidualSampler, categorical → RFClassificationSampler.""" + X, _ = mixed_data + assert isinstance(dispatch_sampler._select_sampler(["x1"]), RFResidualSampler) + assert isinstance(dispatch_sampler._select_sampler(["cat"]), RFClassificationSampler) + + def test_compatibility_check_passes(self, dispatch_sampler): + """TypeDispatchSampler supports both categorical roles.""" + dispatch_sampler.check_compatibility(requires_multivariate=False) + + def test_compatibility_check_rejects_multivariate(self, dispatch_sampler): + """TypeDispatchSampler is univariate — multivariate request fails.""" + with pytest.raises(ValueError, match="univariate"): + dispatch_sampler.check_compatibility(requires_multivariate=True) + + def test_n_repeats_produce_different_scores(self, mixed_data, dispatch_sampler): + """Different repeats should give different scores (randomness in sampling).""" + X, y = mixed_data + X_test, y_test = X.head(20), y[:20] + + def predict(X_df): + x1_val = X_df["x1"].values.astype(float) + cat_val = np.where(X_df["cat"] == "a", 2.0, -2.0) + return x1_val + cat_val + + explainer = Explainer(predict, X, loss=squared_error, sampler=dispatch_sampler) + result = explainer.cfi(X_test, y_test, n_repeats=2) + + # Repeat 0 and repeat 1 should differ (different random draws) + assert not np.allclose(result.scores[0, 0, :, :], result.scores[0, 1, :, :])