Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cais/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
input_parser_tool,
dataset_analyzer_tool,
query_interpreter_tool,
iv_discovery_tool,
method_selector_tool,
method_validator_tool,
method_executor_tool,
Expand Down
41 changes: 39 additions & 2 deletions cais/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from cais.tools.input_parser_tool import input_parser_tool
from cais.tools.dataset_analyzer_tool import dataset_analyzer_tool
from cais.tools.query_interpreter_tool import query_interpreter_tool
from cais.tools.iv_discovery_tool import iv_discovery_tool
from cais.tools.method_selector_tool import method_selector_tool
from cais.tools.controls_selector_tool import controls_selector_tool
from cais.tools.method_validator_tool import method_validator_tool
Expand Down Expand Up @@ -49,6 +50,11 @@

# Set up basic logging
os.makedirs('./logs/', exist_ok=True)
logging.basicConfig(
filename='./logs/agent_debug.log',
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


Comment on lines 51 to 60
Copy link

Copilot AI Mar 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logging.basicConfig(...) at import time reconfigures global logging for any library consumer and for the whole test suite. Prefer leaving logging configuration to the application entrypoint/CLI; here, set a module logger and emit logs without calling basicConfig (or gate it behind if __name__ == "__main__").

Suggested change
# Set up basic logging
os.makedirs('./logs/', exist_ok=True)
logging.basicConfig(
filename='./logs/agent_debug.log',
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Set up module logger
logger = logging.getLogger(__name__)
if __name__ == "__main__":
# Configure basic logging only when this module is executed as a script
os.makedirs('./logs/', exist_ok=True)
logging.basicConfig(
filename='./logs/agent_debug.log',
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -153,6 +159,23 @@ def select_method(self, query=None, llm_decision=True):
self.selected_method = self.method_info.selected_method
return self.selected_method

def discover_instruments(self, query=None):
    """Run IV discovery and refresh ``self.variables`` with the result.

    Args:
        query: Optional causal query; normalized via ``self.checkq``.

    Returns:
        The updated ``Variables`` object (also stored on ``self.variables``).
    """
    query = self.checkq(query)

    raw_output = iv_discovery_tool.func(
        variables=self.variables,
        dataset_analysis=self.dataset_analysis,
        dataset_description=self.dataset_description,
        original_query=query,
    )

    # Pydantic models expose model_dump(); plain dicts pass through unchanged.
    payload = raw_output.model_dump() if hasattr(raw_output, "model_dump") else raw_output

    self.variables = Variables(**payload["variables"])
    return self.variables

def validate_method(self, query=None):
'''
Expand All @@ -176,7 +199,7 @@ def select_controls(self, query=None) -> list:

query = self.checkq(query)

controls_selector_output = controls_selector_tool(
controls_selector_output = controls_selector_tool.func(
method_name=self.selected_method,
variables=self.variables,
dataset_analysis=self.dataset_analysis,
Expand All @@ -197,7 +220,16 @@ def clean_dataset(self, query=None):
original_query=query,
causal_method=self.selected_method
)
self.cleaned_dataset_path = cleaning_output.get("cleaned_dataset_path", self.dataset_path)
self.cleaned_dataset_path = cleaning_output.get("cleaned_dataset_path")

# Check if file was actually created/returned
if not self.cleaned_dataset_path or not os.path.exists(self.cleaned_dataset_path):
stderr = cleaning_output.get("stderr", "No stderr available.")
logger.error(f"Dataset cleaning failed to produce a file at {self.cleaned_dataset_path}. Stderr: {stderr}")
# Fallback to original dataset if cleaning failed but we want to attempt execution?
# Or raise error. Let's raise for now to be safe.
raise FileNotFoundError(f"Cleaned dataset NOT found at {self.cleaned_dataset_path}. Cleaning stderr: {stderr}")

return self.cleaned_dataset_path

def execute_method(self, query=None, remove_cleaned=True):
Expand Down Expand Up @@ -253,6 +285,11 @@ def run_analysis(self, query, llm_method_selection: Optional[bool] = True):
query=query,
llm_decision=llm_method_selection
)
if self.selected_method == INSTRUMENTAL_VARIABLE:
logger.info("Instrumental Variable method selected. Running IV Discovery...")
self.discover_instruments(
query=query
)
self.select_controls(
query=query
)
Expand Down
19 changes: 15 additions & 4 deletions cais/components/dataset_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@
Input: (dataset_path, Transformation Spec JSON).
Output: A SINGLE Python script as text that:
- imports only: json, os, pandas as pd, numpy as np
- loads the dataset from dataset_path (infer CSV/Parquet by extension)
- loads the dataset from the path provided in global variable `__DATASET_PATH__`
- applies ONLY what the Spec asks for (row_filters, column_ops, method_constructs, etc.)
- keeps all original columns unless Spec explicitly drops them
- creates any new columns explicitly; suffix where needed; no silent overwrite
- produces a dataframe named clean_df
- writes:
- cleaned_df.csv (same directory as dataset_path)
- clean_df.csv to the path provided in global variable `__CLEANED_PATH__`
- preprocessing_manifest.json (the Spec actually executed)
- derived_columns.json (list of new columns with one-line descriptions)
- prints a concise, human-readable summary report to stdout
Expand Down Expand Up @@ -181,7 +181,15 @@ def _run_script_text(script: str, dataset_path: str, cleaned_path: str) -> Tuple
with contextlib.redirect_stdout(stdout_io), contextlib.redirect_stderr(stderr_io):
# Provide dataset_path as a global the script can read (it should anyway use the passed JSON)
gbls["__DATASET_PATH__"] = dataset_path
script=script.replace('cleaned_df.csv', cleaned_path)
gbls["__CLEANED_PATH__"] = cleaned_path

# We restore the replacements but use json.dumps for safe quoting on Windows
# This handles models that hardcode the path despite instructions
for placeholder in ["cleaned_df.csv", "clean_df.csv", "manifest.json", "derived_columns.json"]:
if placeholder in script:
script = script.replace(f'"{placeholder}"', json.dumps(cleaned_path if "csv" in placeholder else placeholder))
script = script.replace(f"'{placeholder}'", json.dumps(cleaned_path if "csv" in placeholder else placeholder))

exec(script, gbls, lcls)
except Exception as e:
tb = traceback.format_exc()
Expand Down Expand Up @@ -246,7 +254,10 @@ def run_cleaning_stage(dataset_path: str,
"""
llm = get_llm_client()

cleaned_path = os.path.join(os.path.dirname(os.path.abspath(dataset_path)) or ".", f"{dataset_path.split('/')[-1][:-4]}_cleaned_{os.getpid()}.csv")
dataset_path = dataset_path.replace("\\", "/")
base_name = os.path.basename(dataset_path)
file_stem = os.path.splitext(base_name)[0]
cleaned_path = os.path.join(os.path.dirname(os.path.abspath(dataset_path)) or ".", f"{file_stem}_cleaned_{os.getpid()}.csv").replace("\\", "/")

# 1) PLAN
method = causal_method or variables.get("method") or ""
Expand Down
99 changes: 99 additions & 0 deletions cais/components/iv_discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import logging
from typing import Dict, List, Any, Optional

from cais.config import get_llm_client
from cais.iv_llm.src.agents.hypothesizer import Hypothesizer
from cais.iv_llm.src.agents.confounder_miner import ConfounderMiner
from cais.iv_llm.src.critics.exclusion_critic import ExclusionCritic
from cais.iv_llm.src.critics.independence_critic import IndependenceCritic


logger = logging.getLogger(__name__)

class IVDiscovery:
    """Pipeline that proposes candidate instrumental variables (IVs) and
    validates them with exclusion/independence critics.

    Steps performed by :meth:`discover_instruments`:
      1. Hypothesize candidate IVs for (treatment, outcome).
      2. Mine confounders when none are supplied by the caller.
      3. Run the exclusion critic, then the independence critic, over every
         candidate; only candidates passing BOTH are reported as valid.
    """

    def __init__(self, llm=None):
        """Build the sub-agents and critics.

        Args:
            llm: Optional pre-configured chat-model client shared by every
                pipeline step. When omitted, falls back to
                ``get_llm_client()`` so existing ``IVDiscovery()`` callers
                keep working; callers configured for a non-default
                provider/model can now inject their own client instead of
                always getting a fresh default one.
        """
        if llm is None:
            llm = get_llm_client()
        self.hypothesizer = Hypothesizer(llm, k=5)
        self.confounder_miner = ConfounderMiner(llm, j=5)
        self.exclusion_critic = ExclusionCritic(llm)
        self.independence_critic = IndependenceCritic(llm)

    def discover_instruments(self, treatment: str, outcome: str, context: str = "", confounders: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Discover valid instrumental variables for the given treatment and outcome.

        Args:
            treatment: Name of the treatment variable
            outcome: Name of the outcome variable
            context: Additional context about the dataset/query
            confounders: List of known confounders (optional)

        Returns:
            Dict with keys ``proposed_ivs``, ``valid_ivs``,
            ``validation_results`` (one entry per proposed IV) and
            ``confounders``.
        """
        logger.info(f"Discovering instruments for treatment: {treatment}, outcome: {outcome}")

        # Step 1: Hypothesize IVs
        proposed_ivs = self.hypothesizer.propose_ivs(treatment, outcome, context=context)
        logger.info(f"Proposed IVs: {proposed_ivs}")

        # No candidates -> short-circuit with an empty but well-formed result.
        if not proposed_ivs:
            return {
                'proposed_ivs': [],
                'valid_ivs': [],
                'validation_results': [],
                'confounders': confounders or []
            }

        # Step 2: Identify confounders only when the caller did not supply any.
        if confounders is None:
            confounders = self.confounder_miner.identify_confounders(treatment, outcome, context=context)
            logger.info(f"Identified confounders: {confounders}")

        # Step 3: Validate IVs with critics
        validation_results = []
        valid_ivs = []

        exclusion_results = {}
        independence_results = {}

        # First pass: exclusion critic for all IVs
        for iv in proposed_ivs:
            exclusion_results[iv] = self.exclusion_critic.validate_exclusion(iv, treatment, outcome, confounders)

        # Second pass: independence critic for all IVs
        for iv in proposed_ivs:
            independence_results[iv] = self.independence_critic.validate_independence(iv, treatment, outcome, confounders)

        # Combine results: an IV is kept only if it passes both critics.
        for iv in proposed_ivs:
            exclusion_valid = exclusion_results[iv]
            independence_valid = independence_results[iv]

            validation_results.append({
                'iv': iv,
                'exclusion_valid': exclusion_valid,
                'independence_valid': independence_valid,
                'overall_valid': exclusion_valid and independence_valid
            })

            if exclusion_valid and independence_valid:
                valid_ivs.append(iv)

        logger.info(f"Valid IVs found: {valid_ivs}")

        return {
            'proposed_ivs': proposed_ivs,
            'valid_ivs': valid_ivs,
            'validation_results': validation_results,
            'confounders': confounders
        }

def discover_instruments(treatment: str, outcome: str, context: str = "", confounders: Optional[List[str]] = None, config_path: Optional[str] = None) -> Dict[str, Any]:
    """Module-level convenience wrapper around :class:`IVDiscovery`.

    ``config_path`` is accepted but ignored; it is retained only for
    backwards compatibility with older call sites.
    """
    return IVDiscovery().discover_instruments(treatment, outcome, context, confounders)
38 changes: 38 additions & 0 deletions cais/iv_llm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# IV-LLM package

import logging
import os
from pathlib import Path


def _find_project_root() -> Path:
"""Walk up from this file to find the directory containing pyproject.toml."""
current = Path(__file__).resolve().parent
for parent in [current, *current.parents]:
if (parent / "pyproject.toml").exists():
return parent
return Path.cwd()


def _get_log_path() -> Path:
env_dir = os.getenv("IV_LLM_OUTPUT_DIR")
if env_dir:
return Path(env_dir) / "iv_llm.jsonl"
return _find_project_root() / "logs" / "iv_llm.jsonl"


# Configure a file handler on the "cais.iv_llm" logger so that every child
# logger (agents, critics, etc.) automatically writes to the IV-LLM log file.
_logger = logging.getLogger(__name__)  # "cais.iv_llm"
# Guard: only attach once, so repeated imports/reloads don't duplicate handlers.
if not _logger.handlers:
    _logger.setLevel(logging.DEBUG)
    _logger.propagate = True  # still propagate to root for console output
    try:
        _log_path = _get_log_path()
        # Create the logs directory before opening the file handler.
        _log_path.parent.mkdir(parents=True, exist_ok=True)
        _handler = logging.FileHandler(str(_log_path), encoding="utf-8")
        _handler.setLevel(logging.INFO)  # file records INFO+ even though the logger is DEBUG
        _handler.setFormatter(logging.Formatter("%(message)s"))  # bare message: callers emit JSON lines
        _logger.addHandler(_handler)
    except Exception:
        # Best-effort: logging setup must never break package import
        # (e.g. unwritable log directory).
        pass
1 change: 1 addition & 0 deletions cais/iv_llm/src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# IV-LLM src package
1 change: 1 addition & 0 deletions cais/iv_llm/src/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Agents package
63 changes: 63 additions & 0 deletions cais/iv_llm/src/agents/confounder_miner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import json
import logging
from typing import List

from langchain_core.language_models import BaseChatModel

from cais.utils.llm_helpers import invoke_llm
from ..prompts.prompt_loader import PromptLoader
from ..variable_utils import extract_available_columns, filter_to_available

logger = logging.getLogger(__name__)

class ConfounderMiner:
    """LLM-backed agent that proposes up to ``j`` candidate confounders
    for a (treatment, outcome) pair."""

    def __init__(self, llm: BaseChatModel, j: int = 5) -> None:
        self.llm = llm
        self.j = j  # maximum number of confounders to return
        self.prompt_loader = PromptLoader()

    def identify_confounders(self, treatment: str, outcome: str, context: str = "") -> List[str]:
        """Ask the LLM for candidate confounders, parse them from the
        response, and keep only names present in the dataset when the
        context yields column information.

        Args:
            treatment: Name of the treatment variable.
            outcome: Name of the outcome variable.
            context: Free-text dataset/query context passed to the prompt
                and used to extract available column names.

        Returns:
            At most ``self.j`` confounder names (possibly empty).
        """
        prompt = self.prompt_loader.format_confounder_prompt(treatment, outcome, self.j, context=context)
        response = invoke_llm(self.llm, prompt)
        confounders_raw = self._parse_confounders(response)

        # Restrict to known columns only when column info could be extracted;
        # otherwise trust the raw LLM output.
        available_cols = extract_available_columns(context)
        confounders = (
            filter_to_available(confounders_raw, available_cols)
            if available_cols
            else confounders_raw
        )

        # Re-cap at j: filtering may not have reduced the list.
        confounders = confounders[: self.j]

        # Structured JSONL record for the iv_llm log file.
        logger.info(json.dumps({
            'name': 'confounder_miner',
            'inputs': {'treatment': treatment, 'outcome': outcome, 'j': self.j},
            'outputs': {'confounders': confounders},
            'raw_response': response,
        }, default=str))

        return confounders

    def _parse_confounders(self, response: str) -> List[str]:
        """Extract a list of confounder names from the raw LLM response.

        Tries the ``<Answer>[a, b]</Answer>`` format first, then any bare
        bracketed list. Returns an empty list when nothing parses.
        """
        import re

        def _clean(name: str) -> str:
            # Strip surrounding whitespace, quotes, backticks and markdown stars.
            return name.strip().strip('"\'').strip('`').strip('*').strip()

        # Try XML format first
        match = re.search(r'<Answer>\[(.*?)\]</Answer>', response)
        if match:
            confounders_str = match.group(1)
            confounders = [_clean(c) for c in confounders_str.split(',')]
            return confounders[:self.j]

        # Fallback: look for bracket format without XML
        bracket_match = re.search(r'\[([^\]]+)\]', response)
        if bracket_match:
            confounders_str = bracket_match.group(1)
            confounders = [_clean(c) for c in confounders_str.split(',')]
            return confounders[:self.j]

        # Fix: route the parse-failure warning through the module logger
        # (defined at module scope) instead of print(), so it reaches the
        # configured iv_llm log handlers rather than raw stdout.
        logger.warning("Could not parse confounders from response: %s...", response[:200])
        return []
Loading