diff --git a/cais/__init__.py b/cais/__init__.py
index cffdd70..fd447fc 100644
--- a/cais/__init__.py
+++ b/cais/__init__.py
@@ -23,6 +23,7 @@
input_parser_tool,
dataset_analyzer_tool,
query_interpreter_tool,
+ iv_discovery_tool,
method_selector_tool,
method_validator_tool,
method_executor_tool,
diff --git a/cais/agent.py b/cais/agent.py
index dee2aff..149d076 100644
--- a/cais/agent.py
+++ b/cais/agent.py
@@ -8,6 +8,7 @@
from cais.tools.input_parser_tool import input_parser_tool
from cais.tools.dataset_analyzer_tool import dataset_analyzer_tool
from cais.tools.query_interpreter_tool import query_interpreter_tool
+from cais.tools.iv_discovery_tool import iv_discovery_tool
from cais.tools.method_selector_tool import method_selector_tool
from cais.tools.controls_selector_tool import controls_selector_tool
from cais.tools.method_validator_tool import method_validator_tool
@@ -49,6 +50,11 @@
# Set up basic logging
os.makedirs('./logs/', exist_ok=True)
+logging.basicConfig(
+ filename='./logs/agent_debug.log',
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
logger = logging.getLogger(__name__)
@@ -153,6 +159,23 @@ def select_method(self, query=None, llm_decision=True):
self.selected_method = self.method_info.selected_method
return self.selected_method
+ def discover_instruments(self, query=None):
+ query = self.checkq(query)
+
+ iv_discovery_output = iv_discovery_tool.func(
+ variables=self.variables,
+ dataset_analysis=self.dataset_analysis,
+ dataset_description=self.dataset_description,
+ original_query=query
+ )
+
+ if hasattr(iv_discovery_output, "model_dump"):
+ iv_discovery_output_dict = iv_discovery_output.model_dump()
+ else:
+ iv_discovery_output_dict = iv_discovery_output
+
+ self.variables = Variables(**iv_discovery_output_dict["variables"])
+ return self.variables
def validate_method(self, query=None):
'''
@@ -176,7 +199,7 @@ def select_controls(self, query=None) -> list:
query = self.checkq(query)
- controls_selector_output = controls_selector_tool(
+ controls_selector_output = controls_selector_tool.func(
method_name=self.selected_method,
variables=self.variables,
dataset_analysis=self.dataset_analysis,
@@ -197,7 +220,16 @@ def clean_dataset(self, query=None):
original_query=query,
causal_method=self.selected_method
)
- self.cleaned_dataset_path = cleaning_output.get("cleaned_dataset_path", self.dataset_path)
+ self.cleaned_dataset_path = cleaning_output.get("cleaned_dataset_path")
+
+ # Check if file was actually created/returned
+ if not self.cleaned_dataset_path or not os.path.exists(self.cleaned_dataset_path):
+ stderr = cleaning_output.get("stderr", "No stderr available.")
+ logger.error(f"Dataset cleaning failed to produce a file at {self.cleaned_dataset_path}. Stderr: {stderr}")
+ # Fallback to original dataset if cleaning failed but we want to attempt execution?
+ # Or raise error. Let's raise for now to be safe.
+ raise FileNotFoundError(f"Cleaned dataset NOT found at {self.cleaned_dataset_path}. Cleaning stderr: {stderr}")
+
return self.cleaned_dataset_path
def execute_method(self, query=None, remove_cleaned=True):
@@ -253,6 +285,11 @@ def run_analysis(self, query, llm_method_selection: Optional[bool] = True):
query=query,
llm_decision=llm_method_selection
)
+ if self.selected_method == INSTRUMENTAL_VARIABLE:
+ logger.info("Instrumental Variable method selected. Running IV Discovery...")
+ self.discover_instruments(
+ query=query
+ )
self.select_controls(
query=query
)
diff --git a/cais/components/dataset_cleaner.py b/cais/components/dataset_cleaner.py
index e8531ce..66124d9 100644
--- a/cais/components/dataset_cleaner.py
+++ b/cais/components/dataset_cleaner.py
@@ -70,13 +70,13 @@
Input: (dataset_path, Transformation Spec JSON).
Output: A SINGLE Python script as text that:
- imports only: json, os, pandas as pd, numpy as np
-- loads the dataset from dataset_path (infer CSV/Parquet by extension)
+- loads the dataset from the path provided in global variable `__DATASET_PATH__`
- applies ONLY what the Spec asks for (row_filters, column_ops, method_constructs, etc.)
- keeps all original columns unless Spec explicitly drops them
- creates any new columns explicitly; suffix where needed; no silent overwrite
- produces a dataframe named clean_df
- writes:
- - cleaned_df.csv (same directory as dataset_path)
+ - clean_df.csv to the path provided in global variable `__CLEANED_PATH__`
- preprocessing_manifest.json (the Spec actually executed)
- derived_columns.json (list of new columns with one-line descriptions)
- prints a concise, human-readable summary report to stdout
@@ -181,7 +181,15 @@ def _run_script_text(script: str, dataset_path: str, cleaned_path: str) -> Tuple
with contextlib.redirect_stdout(stdout_io), contextlib.redirect_stderr(stderr_io):
# Provide dataset_path as a global the script can read (it should anyway use the passed JSON)
gbls["__DATASET_PATH__"] = dataset_path
- script=script.replace('cleaned_df.csv', cleaned_path)
+ gbls["__CLEANED_PATH__"] = cleaned_path
+
+ # We restore the replacements but use json.dumps for safe quoting on Windows
+ # This handles models that hardcode the path despite instructions
+ for placeholder in ["cleaned_df.csv", "clean_df.csv", "manifest.json", "derived_columns.json"]:
+ if placeholder in script:
+ script = script.replace(f'"{placeholder}"', json.dumps(cleaned_path if "csv" in placeholder else placeholder))
+ script = script.replace(f"'{placeholder}'", json.dumps(cleaned_path if "csv" in placeholder else placeholder))
+
exec(script, gbls, lcls)
except Exception as e:
tb = traceback.format_exc()
@@ -246,7 +254,10 @@ def run_cleaning_stage(dataset_path: str,
"""
llm = get_llm_client()
- cleaned_path = os.path.join(os.path.dirname(os.path.abspath(dataset_path)) or ".", f"{dataset_path.split('/')[-1][:-4]}_cleaned_{os.getpid()}.csv")
+ dataset_path = dataset_path.replace("\\", "/")
+ base_name = os.path.basename(dataset_path)
+ file_stem = os.path.splitext(base_name)[0]
+ cleaned_path = os.path.join(os.path.dirname(os.path.abspath(dataset_path)) or ".", f"{file_stem}_cleaned_{os.getpid()}.csv").replace("\\", "/")
# 1) PLAN
method = causal_method or variables.get("method") or ""
diff --git a/cais/components/iv_discovery.py b/cais/components/iv_discovery.py
new file mode 100644
index 0000000..c1fb426
--- /dev/null
+++ b/cais/components/iv_discovery.py
@@ -0,0 +1,99 @@
+import logging
+from typing import Dict, List, Any, Optional
+
+from cais.config import get_llm_client
+from cais.iv_llm.src.agents.hypothesizer import Hypothesizer
+from cais.iv_llm.src.agents.confounder_miner import ConfounderMiner
+from cais.iv_llm.src.critics.exclusion_critic import ExclusionCritic
+from cais.iv_llm.src.critics.independence_critic import IndependenceCritic
+
+
+logger = logging.getLogger(__name__)
+
+class IVDiscovery:
+ def __init__(self):
+ llm = get_llm_client()
+ self.hypothesizer = Hypothesizer(llm, k=5)
+ self.confounder_miner = ConfounderMiner(llm, j=5)
+ self.exclusion_critic = ExclusionCritic(llm)
+ self.independence_critic = IndependenceCritic(llm)
+
+ def discover_instruments(self, treatment: str, outcome: str, context: str = "", confounders: Optional[List[str]] = None) -> Dict[str, Any]:
+ """
+ Discover valid instrumental variables for the given treatment and outcome.
+
+ Args:
+ treatment: Name of the treatment variable
+ outcome: Name of the outcome variable
+ context: Additional context about the dataset/query
+ confounders: List of known confounders (optional)
+
+ Returns:
+ Dict containing proposed IVs, valid IVs, and validation results
+ """
+ logger.info(f"Discovering instruments for treatment: {treatment}, outcome: {outcome}")
+
+ # Step 1: Hypothesize IVs
+ proposed_ivs = self.hypothesizer.propose_ivs(treatment, outcome, context=context)
+ logger.info(f"Proposed IVs: {proposed_ivs}")
+
+ if not proposed_ivs:
+ return {
+ 'proposed_ivs': [],
+ 'valid_ivs': [],
+ 'validation_results': [],
+ 'confounders': confounders or []
+ }
+
+ # Step 2: Identify confounders if not provided
+ if confounders is None:
+ confounders = self.confounder_miner.identify_confounders(treatment, outcome, context=context)
+ logger.info(f"Identified confounders: {confounders}")
+
+ # Step 3: Validate IVs with critics
+ validation_results = []
+ valid_ivs = []
+
+ # Run exclusion and independence critics
+ exclusion_results = {}
+ independence_results = {}
+
+ # First pass: exclusion critic for all IVs
+ for iv in proposed_ivs:
+ exclusion_results[iv] = self.exclusion_critic.validate_exclusion(iv, treatment, outcome, confounders)
+
+ # Second pass: independence critic for all IVs
+ for iv in proposed_ivs:
+ independence_results[iv] = self.independence_critic.validate_independence(iv, treatment, outcome, confounders)
+
+ # Combine results
+ for iv in proposed_ivs:
+ exclusion_valid = exclusion_results[iv]
+ independence_valid = independence_results[iv]
+
+ validation_results.append({
+ 'iv': iv,
+ 'exclusion_valid': exclusion_valid,
+ 'independence_valid': independence_valid,
+ 'overall_valid': exclusion_valid and independence_valid
+ })
+
+ if exclusion_valid and independence_valid:
+ valid_ivs.append(iv)
+
+ logger.info(f"Valid IVs found: {valid_ivs}")
+
+ return {
+ 'proposed_ivs': proposed_ivs,
+ 'valid_ivs': valid_ivs,
+ 'validation_results': validation_results,
+ 'confounders': confounders
+ }
+
+def discover_instruments(treatment: str, outcome: str, context: str = "", confounders: Optional[List[str]] = None, config_path: Optional[str] = None) -> Dict[str, Any]:
+ """
+ Convenience function to discover instruments using default IVDiscovery instance.
+ """
+ # `config_path` is currently unused; keep it for backwards compatibility.
+ discovery = IVDiscovery()
+ return discovery.discover_instruments(treatment, outcome, context, confounders)
\ No newline at end of file
diff --git a/cais/iv_llm/__init__.py b/cais/iv_llm/__init__.py
new file mode 100644
index 0000000..0106dc4
--- /dev/null
+++ b/cais/iv_llm/__init__.py
@@ -0,0 +1,38 @@
+# IV-LLM package
+
+import logging
+import os
+from pathlib import Path
+
+
+def _find_project_root() -> Path:
+ """Walk up from this file to find the directory containing pyproject.toml."""
+ current = Path(__file__).resolve().parent
+ for parent in [current, *current.parents]:
+ if (parent / "pyproject.toml").exists():
+ return parent
+ return Path.cwd()
+
+
+def _get_log_path() -> Path:
+ env_dir = os.getenv("IV_LLM_OUTPUT_DIR")
+ if env_dir:
+ return Path(env_dir) / "iv_llm.jsonl"
+ return _find_project_root() / "logs" / "iv_llm.jsonl"
+
+
+# Configure a file handler on the "cais.iv_llm" logger so that every child
+# logger (agents, critics, etc.) automatically writes to the IV-LLM log file.
+_logger = logging.getLogger(__name__) # "cais.iv_llm"
+if not _logger.handlers:
+ _logger.setLevel(logging.DEBUG)
+ _logger.propagate = True # still propagate to root for console output
+ try:
+ _log_path = _get_log_path()
+ _log_path.parent.mkdir(parents=True, exist_ok=True)
+ _handler = logging.FileHandler(str(_log_path), encoding="utf-8")
+ _handler.setLevel(logging.INFO)
+ _handler.setFormatter(logging.Formatter("%(message)s"))
+ _logger.addHandler(_handler)
+ except Exception:
+ pass
\ No newline at end of file
diff --git a/cais/iv_llm/src/__init__.py b/cais/iv_llm/src/__init__.py
new file mode 100644
index 0000000..218cfdf
--- /dev/null
+++ b/cais/iv_llm/src/__init__.py
@@ -0,0 +1 @@
+# IV-LLM src package
\ No newline at end of file
diff --git a/cais/iv_llm/src/agents/__init__.py b/cais/iv_llm/src/agents/__init__.py
new file mode 100644
index 0000000..9ba3a4d
--- /dev/null
+++ b/cais/iv_llm/src/agents/__init__.py
@@ -0,0 +1 @@
+# Agents package
\ No newline at end of file
diff --git a/cais/iv_llm/src/agents/confounder_miner.py b/cais/iv_llm/src/agents/confounder_miner.py
new file mode 100644
index 0000000..939874b
--- /dev/null
+++ b/cais/iv_llm/src/agents/confounder_miner.py
@@ -0,0 +1,63 @@
+import json
+import logging
+from typing import List
+
+from langchain_core.language_models import BaseChatModel
+
+from cais.utils.llm_helpers import invoke_llm
+from ..prompts.prompt_loader import PromptLoader
+from ..variable_utils import extract_available_columns, filter_to_available
+
+logger = logging.getLogger(__name__)
+
+class ConfounderMiner:
+ def __init__(self, llm: BaseChatModel, j: int = 5) -> None:
+ self.llm = llm
+ self.j = j
+ self.prompt_loader = PromptLoader()
+
+ def identify_confounders(self, treatment: str, outcome: str, context: str = "") -> List[str]:
+ prompt = self.prompt_loader.format_confounder_prompt(treatment, outcome, self.j, context=context)
+ response = invoke_llm(self.llm, prompt)
+ confounders_raw = self._parse_confounders(response)
+
+ available_cols = extract_available_columns(context)
+ confounders = (
+ filter_to_available(confounders_raw, available_cols)
+ if available_cols
+ else confounders_raw
+ )
+
+ confounders = confounders[: self.j]
+
+ logger.info(json.dumps({
+ 'name': 'confounder_miner',
+ 'inputs': {'treatment': treatment, 'outcome': outcome, 'j': self.j},
+ 'outputs': {'confounders': confounders},
+ 'raw_response': response,
+ }, default=str))
+
+ return confounders
+
+ def _parse_confounders(self, response: str) -> List[str]:
+ import re
+
+ def _clean(name: str) -> str:
+ return name.strip().strip('"\'').strip('`').strip('*').strip()
+
+ # Try XML format first
+ match = re.search(r'\[(.*?)\]', response)
+ if match:
+ confounders_str = match.group(1)
+ confounders = [_clean(c) for c in confounders_str.split(',')]
+ return confounders[:self.j]
+
+ # Fallback: look for bracket format without XML
+ bracket_match = re.search(r'\[([^\]]+)\]', response)
+ if bracket_match:
+ confounders_str = bracket_match.group(1)
+ confounders = [_clean(c) for c in confounders_str.split(',')]
+ return confounders[:self.j]
+
+ print(f"WARNING: Could not parse confounders from response: {response[:200]}...")
+ return []
\ No newline at end of file
diff --git a/cais/iv_llm/src/agents/hypothesizer.py b/cais/iv_llm/src/agents/hypothesizer.py
new file mode 100644
index 0000000..93df678
--- /dev/null
+++ b/cais/iv_llm/src/agents/hypothesizer.py
@@ -0,0 +1,78 @@
+import json
+import logging
+from typing import List
+
+from langchain_core.language_models import BaseChatModel
+
+from cais.utils.llm_helpers import invoke_llm
+from ..prompts.prompt_loader import PromptLoader
+from ..variable_utils import extract_available_columns, filter_to_available, fallback_candidates
+
+logger = logging.getLogger(__name__)
+
+class Hypothesizer:
+ def __init__(self, llm: BaseChatModel, k: int = 5) -> None:
+ self.llm = llm
+ self.k = k
+ self.prompt_loader = PromptLoader()
+
+ def propose_ivs(self, treatment: str, outcome: str, context: str = "") -> List[str]:
+ prompt = self.prompt_loader.format_hypothesizer_prompt(treatment, outcome, self.k, context=context)
+ response = invoke_llm(self.llm, prompt)
+ ivs_raw = self._parse_ivs(response)
+
+ available_cols = extract_available_columns(context)
+ ivs = filter_to_available(ivs_raw, available_cols) if available_cols else ivs_raw
+
+ if available_cols and not ivs:
+ ivs = fallback_candidates(available_cols, exclude=[treatment, outcome])[: self.k]
+
+ ivs = ivs[: self.k]
+
+ logger.info(json.dumps({
+ 'name': 'hypothesizer',
+ 'inputs': {'treatment': treatment, 'outcome': outcome, 'k': self.k},
+ 'outputs': {'proposed_ivs': ivs},
+ 'raw_response': response,
+ }, default=str))
+
+ return ivs
+
+ def _parse_ivs(self, response: str) -> List[str]:
+ import re
+
+ def _clean(name: str) -> str:
+ return name.strip().strip('"\'').strip('`').strip('*').strip()
+
+ # Try XML format first
+ match = re.search(r'\[(.*?)\]', response)
+ if match:
+ ivs_str = match.group(1)
+ ivs = [_clean(iv) for iv in ivs_str.split(',')]
+ return ivs[:self.k]
+
+ # Fallback: look for bracket format
+ bracket_match = re.search(r'\[([^\]]+)\]', response)
+ if bracket_match:
+ ivs_str = bracket_match.group(1)
+ ivs = [_clean(iv) for iv in ivs_str.split(',')]
+ return ivs[:self.k]
+
+ # Fallback: look for numbered list format
+ lines = response.split('\n')
+ ivs = []
+ for line in lines:
+ line = line.strip()
+ # Match patterns like "1. Something:" or "- Something:"
+ if re.match(r'^\d+\.\s+(.+?):', line):
+ iv = _clean(re.match(r'^\d+\.\s+(.+?):', line).group(1))
+ ivs.append(iv)
+ elif re.match(r'^-\s+(.+?):', line):
+ iv = _clean(re.match(r'^-\s+(.+?):', line).group(1))
+ ivs.append(iv)
+
+ if ivs:
+ return ivs[:self.k]
+
+ print(f"WARNING: Could not parse IVs from response: {response[:200]}...")
+ return []
\ No newline at end of file
diff --git a/cais/iv_llm/src/critics/__init__.py b/cais/iv_llm/src/critics/__init__.py
new file mode 100644
index 0000000..558ec56
--- /dev/null
+++ b/cais/iv_llm/src/critics/__init__.py
@@ -0,0 +1 @@
+# critics package
\ No newline at end of file
diff --git a/cais/iv_llm/src/critics/exclusion_critic.py b/cais/iv_llm/src/critics/exclusion_critic.py
new file mode 100644
index 0000000..b2123b1
--- /dev/null
+++ b/cais/iv_llm/src/critics/exclusion_critic.py
@@ -0,0 +1,37 @@
+import json
+import logging
+from typing import List
+
+from langchain_core.language_models import BaseChatModel
+
+from cais.utils.llm_helpers import invoke_llm
+from ..prompts.prompt_loader import PromptLoader
+
+logger = logging.getLogger(__name__)
+
+class ExclusionCritic:
+ def __init__(self, llm: BaseChatModel) -> None:
+ self.llm = llm
+ self.prompt_loader = PromptLoader()
+
+ def validate_exclusion(self, iv: str, treatment: str, outcome: str, confounders: List[str]) -> bool:
+ # Exclusion restriction: does IV affect outcome only through treatment?
+ # Confounders not directly relevant here - just check direct pathways
+ prompt = self.prompt_loader.format_exclusion_prompt(iv, treatment, outcome, confounders)
+ response = invoke_llm(self.llm, prompt)
+ result = self._parse_validity(response)
+
+ # Log detailed output
+ logger.info(json.dumps({
+ 'name': f'exclusion_critic_{iv}',
+ 'inputs': {'iv': iv, 'treatment': treatment, 'outcome': outcome, 'confounders': confounders},
+ 'outputs': {'valid': result},
+ 'raw_response': response,
+ }, default=str))
+
+ return result
+
+ def _parse_validity(self, response: str) -> bool:
+ import re
+ match = re.search(r'(Valid|Invalid)', response)
+ return match.group(1) == 'Valid' if match else False
\ No newline at end of file
diff --git a/cais/iv_llm/src/critics/independence_critic.py b/cais/iv_llm/src/critics/independence_critic.py
new file mode 100644
index 0000000..390712e
--- /dev/null
+++ b/cais/iv_llm/src/critics/independence_critic.py
@@ -0,0 +1,49 @@
+import json
+import logging
+from typing import List
+
+from langchain_core.language_models import BaseChatModel
+
+from cais.utils.llm_helpers import invoke_llm
+from ..prompts.prompt_loader import PromptLoader
+
+logger = logging.getLogger(__name__)
+
+class IndependenceCritic:
+ def __init__(self, llm: BaseChatModel) -> None:
+ self.llm = llm
+ self.prompt_loader = PromptLoader()
+
+ def validate_independence(self, iv: str, treatment: str, outcome: str, confounders: List[str]) -> bool:
+ # Independence: check IV against each confounder separately
+ responses = {}
+
+ for confounder in confounders:
+ prompt = self.prompt_loader.format_independence_prompt(iv, treatment, outcome, confounder)
+ response = invoke_llm(self.llm, prompt)
+ responses[confounder] = response
+
+ if not self._parse_validity(response):
+ # Log detailed output
+ logger.info(json.dumps({
+ 'name': f'independence_critic_{iv}',
+ 'inputs': {'iv': iv, 'treatment': treatment, 'outcome': outcome, 'confounders': confounders},
+ 'outputs': {'valid': False, 'failed_on': confounder},
+ 'raw_response': responses,
+ }, default=str))
+ return False
+
+ # Log successful validation
+ logger.info(json.dumps({
+ 'name': f'independence_critic_{iv}',
+ 'inputs': {'iv': iv, 'treatment': treatment, 'outcome': outcome, 'confounders': confounders},
+ 'outputs': {'valid': True},
+ 'raw_response': responses,
+ }, default=str))
+
+ return True
+
+ def _parse_validity(self, response: str) -> bool:
+ import re
+ match = re.search(r'(Valid|Invalid)', response)
+ return match.group(1) == 'Valid' if match else False
\ No newline at end of file
diff --git a/cais/iv_llm/src/experiments/iv_co_scientist.py b/cais/iv_llm/src/experiments/iv_co_scientist.py
new file mode 100644
index 0000000..49d840c
--- /dev/null
+++ b/cais/iv_llm/src/experiments/iv_co_scientist.py
@@ -0,0 +1,173 @@
+import yaml
+import json
+import pandas as pd
+from ..causal_analysis.correlation import analyze_data
+from ..agents.human_proxy import HumanProxy
+from ..agents.causal_oracle import CausalOracle
+from ..agents.hypothesizer import Hypothesizer
+from ..agents.confounder_miner import ConfounderMiner
+from ..critics.exclusion_critic import ExclusionCritic
+from ..critics.independence_critic import IndependenceCritic
+from ..agents.grounder import Grounder
+from ..analysis.iv_estimator import IVEstimator
+from ..llm.client import LLMClient
+import os
+
+class IVCoScientist:
+ def __init__(self, config):
+ # Accept either config dict or config path
+ if isinstance(config, str):
+ with open(config, 'r') as f:
+ self.config = yaml.safe_load(f)
+ else:
+ self.config = config
+
+ self.dataset_path = self.config['dataset']['path']
+ self.llm_client = LLMClient(self.config['llm'])
+
+ # Initialize agents
+ self.human_proxy = HumanProxy(self.llm_client, max_pairs=self.config['agents']['human_proxy']['max_pairs'])
+ self.causal_oracle = CausalOracle(self.llm_client, self.dataset_path)
+ self.hypothesizer = Hypothesizer(self.llm_client, k=self.config['agents']['hypothesizer']['k_ivs'])
+ self.confounder_miner = ConfounderMiner(self.llm_client, j=self.config['agents']['confounder_miner']['j_confounders'])
+ self.exclusion_critic = ExclusionCritic(self.llm_client)
+ self.independence_critic = IndependenceCritic(self.llm_client)
+ self.grounder = Grounder(self.dataset_path, threshold=self.config['agents']['grounder']['threshold'])
+ self.estimator = IVEstimator(self.dataset_path, f_threshold=self.config['estimation']['f_threshold'])
+
+ def run_discovery(self):
+ """Run full IV discovery pipeline"""
+ print("๐ Starting IV Co-Scientist Discovery...")
+
+ # Stage 1: PreSelector - Filter correlations
+ print("\n๐ Stage 1: PreSelector - Analyzing correlations...")
+ correlation_pairs = self._preselector()
+ print(f"Found {len(correlation_pairs)} correlated pairs")
+
+ if not correlation_pairs:
+ return {"error": "No significant correlations found"}
+
+ # Stage 2: HumanProxy - Select meaningful pairs
+ print("\n๐ง Stage 2: HumanProxy - Selecting meaningful pairs...")
+ meaningful_pairs = self.human_proxy.select_meaningful_pairs(correlation_pairs)
+ print(f"Selected {len(meaningful_pairs)} meaningful pairs")
+
+ # Stage 3: CausalOracle - Infer direction
+ print("\n๐ฎ Stage 3: CausalOracle - Inferring causal direction...")
+ causal_pairs = []
+ for pair in meaningful_pairs:
+ direction = self.causal_oracle.infer_direction(pair['variable1'], pair['variable2'])
+ if direction:
+ causal_pairs.append({
+ 'treatment': direction[0],
+ 'outcome': direction[1],
+ 'original_pair': pair
+ })
+ print(f"Identified {len(causal_pairs)} directional pairs")
+
+ # Process each causal pair
+ results = []
+ for i, causal_pair in enumerate(causal_pairs):
+ print(f"\n๐ฏ Processing pair {i+1}/{len(causal_pairs)}: {causal_pair['treatment']} โ {causal_pair['outcome']}")
+ result = self._process_causal_pair(causal_pair)
+ if result:
+ results.append(result)
+
+ return {
+ 'total_correlations': len(correlation_pairs),
+ 'meaningful_pairs': len(meaningful_pairs),
+ 'causal_pairs': len(causal_pairs),
+ 'successful_discoveries': len(results),
+ 'results': results
+ }
+
+ def _preselector(self):
+ """Filter variable pairs by correlation and sample size"""
+ # Check if correlation results already exist
+ correlation_file = "gapminder_correlations.csv"
+
+ if not os.path.exists(correlation_file):
+ print("Running correlation analysis first...")
+ from ..analyze_gapminder import analyze_gapminder_folder
+ analyze_gapminder_folder(self.dataset_path)
+
+ # Load existing correlation results
+ import pandas as pd
+ df = pd.read_csv(correlation_file, sep=';')
+
+ # Convert to expected format and filter
+ correlation_pairs = []
+ for _, row in df.iterrows():
+ if self._passes_filter(row.to_dict()):
+ correlation_pairs.append(row.to_dict())
+
+ return correlation_pairs[:self.config['preselector'].get('max_pairs', 20)]
+
+ def _passes_filter(self, result):
+ """Check if correlation pair passes thresholds"""
+ return (abs(result['correlation']) > self.config['preselector']['correlation_threshold'] and
+ result['data_points'] > self.config['preselector']['min_data_points'])
+
+ def _process_causal_pair(self, causal_pair):
+ """Process single causal pair through full pipeline"""
+ treatment = causal_pair['treatment']
+ outcome = causal_pair['outcome']
+
+ # Stage 4: Generate IVs and confounders
+ print(f" ๐ Generating IVs for {treatment} โ {outcome}")
+ proposed_ivs = self.hypothesizer.propose_ivs(treatment, outcome)
+ confounders = self.confounder_miner.identify_confounders(treatment, outcome)
+
+ # Stage 5: Validate IVs
+ print(f" โ
Validating {len(proposed_ivs)} proposed IVs")
+ valid_ivs = self._validate_ivs(proposed_ivs, treatment, outcome, confounders)
+
+ if not valid_ivs:
+ print(f" โ No valid IVs found")
+ return None
+
+ # Stage 6: Ground IVs to dataset
+ print(f" ๐ฏ Grounding {len(valid_ivs)} valid IVs to dataset")
+ grounded_ivs = self.grounder.ground_ivs(valid_ivs)
+
+ if not grounded_ivs:
+ print(f" โ No IVs could be grounded to dataset")
+ return None
+
+ # Stage 7: Estimate causal effects
+ print(f" ๐ Estimating effects with {len(grounded_ivs)} grounded IVs")
+ instrument_vars = [pair[1] for pair in grounded_ivs] # Extract proxy variables
+ estimation_results = self.estimator.estimate_iv_effect(treatment, outcome, instrument_vars)
+
+ valid_estimates = [est for est in estimation_results if est and est['relevance_check']]
+
+ return {
+ 'treatment': treatment,
+ 'outcome': outcome,
+ 'proposed_ivs': proposed_ivs,
+ 'valid_ivs': valid_ivs,
+ 'grounded_ivs': grounded_ivs,
+ 'estimation_results': valid_estimates,
+ 'success': len(valid_estimates) > 0
+ }
+
+ def _validate_ivs(self, proposed_ivs, treatment, outcome, confounders):
+ """Validate IVs using critics (parallel evaluation)"""
+ exclusion_results = {}
+ independence_results = {}
+
+ # Exclusion critic
+ for iv in proposed_ivs:
+ exclusion_results[iv] = self.exclusion_critic.validate_exclusion(iv, treatment, outcome, confounders)
+
+ # Independence critic
+ for iv in proposed_ivs:
+ independence_results[iv] = self.independence_critic.validate_independence(iv, treatment, outcome, confounders)
+
+ # Keep only IVs that pass both tests
+ valid_ivs = []
+ for iv in proposed_ivs:
+ if exclusion_results[iv] and independence_results[iv]:
+ valid_ivs.append(iv)
+
+ return valid_ivs
\ No newline at end of file
diff --git a/cais/iv_llm/src/llm/__init__.py b/cais/iv_llm/src/llm/__init__.py
new file mode 100644
index 0000000..a845b2c
--- /dev/null
+++ b/cais/iv_llm/src/llm/__init__.py
@@ -0,0 +1 @@
+# llm package
\ No newline at end of file
diff --git a/cais/iv_llm/src/llm/client.py b/cais/iv_llm/src/llm/client.py
new file mode 100644
index 0000000..b551bd4
--- /dev/null
+++ b/cais/iv_llm/src/llm/client.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+from typing import Any
+from cais.config import get_llm_client
+from langchain_core.language_models import BaseChatModel
+
+
+def invoke_llm(llm: BaseChatModel, prompt: str) -> str:
+ """Call llm with prompt and return the text content"""
+ response = llm.invoke(prompt)
+ if hasattr(response, "content"):
+ return response.content
+ return str(response)
+
+
+class LLMClient:
+ """Thin adapter that wraps a BaseChatModel"""
+ def __init__(self, config: Any = None, use_cache: bool = False) -> None:
+ if isinstance(config, dict):
+ self._llm = get_llm_client(
+ provider=config.get("provider"),
+ model_name=config.get("model"),
+ )
+ else:
+ self._llm = get_llm_client()
+
+ def invoke(self, prompt: str):
+ """Delegate to the underlying LangChain model."""
+ return self._llm.invoke(prompt)
\ No newline at end of file
diff --git a/cais/iv_llm/src/prompts/__init__.py b/cais/iv_llm/src/prompts/__init__.py
new file mode 100644
index 0000000..03a8353
--- /dev/null
+++ b/cais/iv_llm/src/prompts/__init__.py
@@ -0,0 +1 @@
+# prompts package
\ No newline at end of file
diff --git a/cais/iv_llm/src/prompts/confounder_miner.txt b/cais/iv_llm/src/prompts/confounder_miner.txt
new file mode 100644
index 0000000..81d1147
--- /dev/null
+++ b/cais/iv_llm/src/prompts/confounder_miner.txt
@@ -0,0 +1,16 @@
+You are a causality expert proposing instrumental variables.
+
+For the causal relationship '{treatment}' โ '{outcome}', identify {j} potential confounders that could bias the causal estimate.
+
+A confounder affects both treatment and outcome, creating spurious correlation.
+
+CRITICAL CONSTRAINT:
+If the context includes a list of dataset columns (e.g. "Available columns: ..."), you MUST ONLY return confounders that are EXACTLY from that list.
+- Output must be EXACT column names from the dataset.
+- Do NOT invent new variables.
+- Do NOT use bolding/markdown, explanations, or prose.
+
+Return up to {j} candidate confounders from the available columns.
+If no confounders exist among the available columns, return an empty list.
+
+[col_name_1, col_name_2, ...]
\ No newline at end of file
diff --git a/cais/iv_llm/src/prompts/exclusion_critic.txt b/cais/iv_llm/src/prompts/exclusion_critic.txt
new file mode 100644
index 0000000..eda71a2
--- /dev/null
+++ b/cais/iv_llm/src/prompts/exclusion_critic.txt
@@ -0,0 +1,9 @@
+You are a causality expert proposing instrumental variables.
+
+Evaluate whether '{iv}' satisfies the exclusion restriction for '{treatment}' โ '{outcome}'.
+
+Be generous - accept instruments that are plausibly valid with standard econometric controls. Reject only if there are major, uncontrollable direct pathways.
+
+Is '{iv}' acceptable as an instrument with proper research design?
+
+Valid or Invalid
\ No newline at end of file
diff --git a/cais/iv_llm/src/prompts/hypothesizer.txt b/cais/iv_llm/src/prompts/hypothesizer.txt
new file mode 100644
index 0000000..4c10c03
--- /dev/null
+++ b/cais/iv_llm/src/prompts/hypothesizer.txt
@@ -0,0 +1,12 @@
+You are a causality expert proposing instrumental variables.
+
+CRITICAL CONSTRAINT:
+If the context includes a list of dataset columns (e.g. "Available columns: ..."), you MUST ONLY propose instruments that are EXACTLY from that list.
+- Output must be EXACT column names from the dataset.
+- Do NOT invent new variables.
+- Do NOT use bolding/markdown, explanations, or prose.
+
+For '{treatment}' โ '{outcome}', suggest up to {k} candidate instruments from the available columns.
+If no valid instruments can be formed from the available columns, return an empty list.
+
+[col_name_1, col_name_2, ...]
\ No newline at end of file
diff --git a/cais/iv_llm/src/prompts/independence_critic.txt b/cais/iv_llm/src/prompts/independence_critic.txt
new file mode 100644
index 0000000..add60fe
--- /dev/null
+++ b/cais/iv_llm/src/prompts/independence_critic.txt
@@ -0,0 +1,16 @@
+You are a causality expert proposing instrumental variables.
+
+Evaluate whether '{iv}' satisfies the independence assumption for '{treatment}' โ '{outcome}' with respect to the confounder '{confounder}'.
+
+INDEPENDENCE ASSUMPTION: Accept if the instrument can achieve conditional independence through standard controls.
+
+Be generous - accept if:
+- Correlations can be controlled with fixed effects, trends, or observables
+- The instrument has quasi-experimental variation
+- It's used in credible economics research
+
+Reject only if there are fundamental, uncontrollable correlations that would bias results even with extensive controls.
+
+Is '{iv}' acceptable with proper econometric controls?
+
+Valid or Invalid
\ No newline at end of file
diff --git a/cais/iv_llm/src/prompts/prompt_loader.py b/cais/iv_llm/src/prompts/prompt_loader.py
new file mode 100644
index 0000000..0ffbbf5
--- /dev/null
+++ b/cais/iv_llm/src/prompts/prompt_loader.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+class PromptLoader:
+ def __init__(self, prompts_dir: str | Path | None = None) -> None:
+ # Default to the prompts folder shipped with this package.
+ self.prompts_dir = Path(prompts_dir) if prompts_dir is not None else Path(__file__).resolve().parent
+
+ def load_prompt(self, prompt_name: str) -> str:
+ prompt_path = Path(self.prompts_dir) / f"{prompt_name}.txt"
+ with open(prompt_path, "r", encoding="utf-8") as f:
+ return f.read()
+
+ def format_hypothesizer_prompt(self, treatment: str, outcome: str, k: int = 5, context: str = "") -> str:
+ template = self.load_prompt("hypothesizer")
+ context_text = f"\n\nAdditional context: {context}" if context else ""
+ return template.format(treatment=treatment, outcome=outcome, k=k) + context_text
+
+ def format_confounder_prompt(self, treatment: str, outcome: str, j: int = 5, context: str = "") -> str:
+ template = self.load_prompt("confounder_miner")
+ context_text = f"\n\nAdditional context: {context}" if context else ""
+ return template.format(treatment=treatment, outcome=outcome, j=j) + context_text
+
+ def format_exclusion_prompt(
+ self,
+ iv: str,
+ treatment: str,
+ outcome: str,
+ confounders: list[str] | None = None,
+ context: str = "",
+ ) -> str:
+ template = self.load_prompt("exclusion_critic")
+ context_text = f"\n\nAdditional context: {context}" if context else ""
+ return template.format(iv=iv, treatment=treatment, outcome=outcome) + context_text
+
+ def format_independence_prompt(
+ self,
+ iv: str,
+ treatment: str,
+ outcome: str,
+ confounder: str,
+ context: str = "",
+ ) -> str:
+ template = self.load_prompt("independence_critic")
+ context_text = f"\n\nAdditional context: {context}" if context else ""
+ return template.format(iv=iv, treatment=treatment, outcome=outcome,
+ confounder=confounder) + context_text
+
+ def format_conceptual_equivalence_prompt(self, proposed_iv: str, gold_ivs: str | list[str]) -> str:
+ template = self.load_prompt("conceptual_equivalence")
+ return template.format(proposed_iv=proposed_iv, gold_ivs=gold_ivs)
+
+ def format_human_proxy_prompt(self, variable1: str, variable2: str) -> str:
+ template = self.load_prompt("human_proxy")
+ return template.format(variable1=variable1, variable2=variable2)
\ No newline at end of file
diff --git a/cais/iv_llm/src/variable_utils.py b/cais/iv_llm/src/variable_utils.py
new file mode 100644
index 0000000..f032dc6
--- /dev/null
+++ b/cais/iv_llm/src/variable_utils.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+import re
+from typing import Iterable, Optional
+
+
+def extract_available_columns(context: str) -> list[str]:
+ """Extracts available dataset column names from a context string.
+
+ Expected pattern in context (as used in tests):
+ "Available columns: col_a, col_b, col_c."
+ """
+
+ if not context:
+ return []
+
+ # Capture everything after "Available columns:" up to newline or end.
+ match = re.search(r"Available columns\s*:\s*(.+)", context, flags=re.IGNORECASE)
+ if not match:
+ return []
+
+ raw = match.group(1)
+ # Stop at newline; and strip trailing sentence punctuation.
+ raw = raw.splitlines()[0].strip().rstrip(". ")
+
+ parts = [p.strip() for p in re.split(r"[,;]", raw) if p.strip()]
+ # Strip quotes/backticks/markdown, preserve original spelling.
+ cols: list[str] = []
+ for part in parts:
+ cleaned = _strip_formatting(part)
+ if cleaned:
+ cols.append(cleaned)
+ return cols
+
+
+def _strip_formatting(value: str) -> str:
+ value = value.strip()
+ # remove markdown bold/italics/backticks and surrounding quotes
+ value = value.strip("`" )
+ value = value.strip().strip("\"'")
+ value = value.strip("*")
+ # remove trailing type/category suffixes like "col_name (binary)"
+ value = re.sub(r"\s*\([^)]*\)\s*$", "", value)
+ return value.strip()
+
+
+def normalize_name(value: str) -> str:
+ """Normalization for matching LLM outputs to dataset columns."""
+
+ value = _strip_formatting(value)
+ value = value.lower()
+ # allow space/underscore interchange and remove other punctuation
+ value = re.sub(r"[\s\-]+", "_", value)
+ value = re.sub(r"[^a-z0-9_]", "", value)
+ value = re.sub(r"_+", "_", value).strip("_")
+ return value
+
+
+def map_to_available(value: str, available: Iterable[str]) -> Optional[str]:
+ """Map a candidate name to an exact available column name (or None)."""
+
+ target = normalize_name(value)
+ if not target:
+ return None
+
+ available_list = list(available)
+ normalized_map = {normalize_name(c): c for c in available_list}
+ return normalized_map.get(target)
+
+
+def filter_to_available(values: Iterable[str], available: Iterable[str]) -> list[str]:
+ available_list = list(available)
+ seen: set[str] = set()
+ kept: list[str] = []
+ for v in values:
+ mapped = map_to_available(v, available_list)
+ if mapped and mapped not in seen:
+ seen.add(mapped)
+ kept.append(mapped)
+ return kept
+
+
+def fallback_candidates(available: Iterable[str], *, exclude: Iterable[str] = ()) -> list[str]:
+ available_list = list(available)
+ excluded_norm = {normalize_name(x) for x in exclude}
+ return [c for c in available_list if normalize_name(c) not in excluded_norm]
diff --git a/cais/models.py b/cais/models.py
index 7fd5cf7..d475803 100644
--- a/cais/models.py
+++ b/cais/models.py
@@ -147,6 +147,22 @@ class MethodSelectorInput(BaseModel):
original_query: Optional[str] = None
# Note: is_rct is expected inside inputs.variables
+class IVDiscoveryInput(BaseModel):
+ """Input structure for the IV discovery tool."""
+ variables: Variables
+ dataset_analysis: DatasetAnalysis
+ dataset_description: Optional[str] = None
+ original_query: Optional[str] = None
+
+class IVDiscoveryOutput(BaseModel):
+ """Structured output for the IV discovery tool."""
+ variables: Variables
+ dataset_analysis: DatasetAnalysis
+ dataset_description: Optional[str] = None
+ original_query: Optional[str] = None
+ iv_discovery_results: Dict[str, Any] = Field(default_factory=dict)
+ workflow_state: Dict[str, Any] = Field(default_factory=dict)
+
# --- Models for Method Validator Tool ---
class MethodInfo(BaseModel):
diff --git a/cais/tools/controls_selector_tool.py b/cais/tools/controls_selector_tool.py
index 301c50b..f538554 100644
--- a/cais/tools/controls_selector_tool.py
+++ b/cais/tools/controls_selector_tool.py
@@ -7,7 +7,7 @@
import logging
from typing import Dict, Any, Optional
-#from langchain_core.tools import tool
+from langchain.tools import tool
# Import component function and central LLM factory
from cais.components.controls_selector import select_controls
@@ -23,7 +23,7 @@
logger = logging.getLogger(__name__)
-#@tool(args_schema=ControlsSelectorInput)
+@tool(args_schema=ControlsSelectorInput)
def controls_selector_tool(
method_name: str,
variables: Variables,
diff --git a/cais/tools/iv_discovery_tool.py b/cais/tools/iv_discovery_tool.py
new file mode 100644
index 0000000..90ffa04
--- /dev/null
+++ b/cais/tools/iv_discovery_tool.py
@@ -0,0 +1,143 @@
+"""
+Tool for discovering instrumental variables using IV-LLM.
+
+This module provides a LangChain tool for discovering valid instrumental variables
+for given treatment and outcome variables using the IV-LLM pipeline.
+"""
+
+from typing import Dict, List, Any, Optional
+from langchain.tools import tool
+import logging
+
+from cais.components.iv_discovery import discover_instruments
+from cais.components.state_manager import create_workflow_state_update
+
+from cais.models import Variables, DatasetAnalysis, IVDiscoveryInput, IVDiscoveryOutput
+
+logger = logging.getLogger(__name__)
+
+@tool(args_schema=IVDiscoveryInput)
+def iv_discovery_tool(
+ variables: Variables,
+ dataset_analysis: DatasetAnalysis,
+ dataset_description: Optional[str] = None,
+ original_query: Optional[str] = None
+) -> IVDiscoveryOutput:
+ """
+ Discover valid instrumental variables for the identified treatment and outcome.
+
+ Uses the IV-LLM pipeline to hypothesize potential instruments and validate them
+ using exclusion and independence criteria. If valid instruments are found,
+ updates the variables with the instrument_variable.
+
+ Args:
+ variables: Pydantic model containing identified variables (treatment, outcome, etc.)
+ dataset_analysis: Pydantic model containing dataset analysis results
+ dataset_description: Optional textual description of the dataset
+ original_query: Optional original user query string
+
+ Returns:
+ Updated variables with instrument_variable if found, plus discovery results and workflow state
+ """
+ logger.info("Running iv_discovery_tool")
+
+ # Extract treatment and outcome
+ treatment = variables.treatment_variable
+ outcome = variables.outcome_variable
+
+ if not treatment or not outcome:
+ logger.warning("No treatment or outcome variable identified, skipping IV discovery")
+ workflow_update = create_workflow_state_update(
+ current_step="iv_discovery",
+ step_completed_flag=True, # Completed but no IVs found
+ next_tool="method_selector_tool",
+ next_step_reason="No treatment/outcome variables available for IV discovery"
+ )
+ return IVDiscoveryOutput(
+ variables=variables,
+ dataset_analysis=dataset_analysis,
+ dataset_description=dataset_description,
+ original_query=original_query,
+ iv_discovery_results={
+ 'proposed_ivs': [],
+ 'valid_ivs': [],
+ 'validation_results': []
+ },
+ workflow_state=workflow_update.get('workflow_state', {})
+ )
+
+ # Prepare context from dataset description and analysis
+ context_parts = []
+ if dataset_description:
+ context_parts.append(dataset_description)
+
+ # Add column information
+ columns = dataset_analysis.columns or []
+ column_categories = dataset_analysis.column_categories or {}
+ if columns:
+ column_info = []
+ for col in columns:
+ category = column_categories.get(col, 'unknown')
+ column_info.append(f"{col} ({category})")
+ context_parts.append("Available columns: " + ", ".join(column_info))
+
+ context = ". ".join(context_parts)
+
+ # Get confounders from variables if available
+ confounders = variables.covariates or []
+
+ try:
+ # Run IV discovery
+ discovery_results = discover_instruments(
+ treatment=treatment,
+ outcome=outcome,
+ context=context,
+ confounders=confounders
+ )
+
+ # Update variables if valid IVs found
+ updated_variables = variables.model_copy()
+ valid_ivs = discovery_results.get('valid_ivs', [])
+ if valid_ivs:
+ # Select the first valid IV (could be enhanced to select best one)
+ updated_variables.instrument_variable = valid_ivs[0]
+ logger.info(f"Found valid instrument: {valid_ivs[0]}")
+
+ # Create workflow state
+ workflow_update = create_workflow_state_update(
+ current_step="iv_discovery",
+ step_completed_flag=True,
+ next_tool="method_selector_tool",
+ next_step_reason="IV discovery completed, proceeding to method selection"
+ )
+
+ return IVDiscoveryOutput(
+ variables=updated_variables,
+ dataset_analysis=dataset_analysis,
+ dataset_description=dataset_description,
+ original_query=original_query,
+ iv_discovery_results=discovery_results,
+ workflow_state=workflow_update.get('workflow_state', {})
+ )
+
+ except Exception as e:
+ logger.error(f"Error during IV discovery: {e}", exc_info=True)
+ workflow_update = create_workflow_state_update(
+ current_step="iv_discovery",
+ step_completed_flag=False,
+ next_tool="method_selector_tool",
+ next_step_reason=f"IV discovery failed: {e}"
+ )
+ return IVDiscoveryOutput(
+ variables=variables,
+ dataset_analysis=dataset_analysis,
+ dataset_description=dataset_description,
+ original_query=original_query,
+ iv_discovery_results={
+ 'proposed_ivs': [],
+ 'valid_ivs': [],
+ 'validation_results': [],
+ 'error': str(e)
+ },
+ workflow_state=workflow_update.get('workflow_state', {})
+ )
\ No newline at end of file
diff --git a/cais/utils/agent.py b/cais/utils/agent.py
new file mode 100644
index 0000000..a9c7337
--- /dev/null
+++ b/cais/utils/agent.py
@@ -0,0 +1,346 @@
+"""
+LangChain agent for the cais module.
+
+This module configures a LangChain agent with specialized tools for causal inference,
+allowing for an interactive approach to analyzing datasets and applying appropriate
+causal inference methods.
+"""
+
+import logging
+from typing import Dict, List, Any, Optional
+from langchain.agents.react.agent import create_react_agent
+from langchain.agents import AgentExecutor, create_structured_chat_agent, create_tool_calling_agent
+from langchain.chains.conversation.memory import ConversationBufferMemory
+from langchain_core.messages import SystemMessage, HumanMessage
+from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
+from langchain.tools import tool
+
+from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
+from langchain.tools.render import render_text_description
+from langchain.agents.format_scratchpad.tools import format_to_tool_messages
+from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.language_models import BaseChatModel
+from langchain_anthropic.chat_models import convert_to_anthropic_tool
+from cais.tools.input_parser_tool import input_parser_tool
+from cais.tools.dataset_analyzer_tool import dataset_analyzer_tool
+from cais.tools.query_interpreter_tool import query_interpreter_tool
+from cais.tools.method_selector_tool import method_selector_tool
+from cais.tools.method_validator_tool import method_validator_tool
+from cais.tools.method_executor_tool import method_executor_tool
+from cais.tools.explanation_generator_tool import explanation_generator_tool
+from cais.tools.output_formatter_tool import output_formatter_tool
+from langchain_core.output_parsers import StrOutputParser
+from .config import get_llm_client
+from langchain_core.messages import AIMessage, AIMessageChunk
+import re
+import json
+from typing import Union
+from langchain_core.output_parsers import BaseOutputParser
+from langchain.schema import AgentAction, AgentFinish
+from langchain_anthropic.output_parsers import ToolsOutputParser
+from langchain.agents.react.output_parser import ReActOutputParser
+from langchain.agents import AgentOutputParser
+from langchain.agents.agent import AgentAction, AgentFinish, OutputParserException
+import re
+from typing import Union, List
+
+from langchain_core.agents import AgentAction, AgentFinish
+from langchain_core.exceptions import OutputParserException
+
+from langchain.agents.agent import AgentOutputParser
+from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
+
+FINAL_ANSWER_ACTION = "Final Answer:"
+MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
+ "Invalid Format: Missing 'Action:' after 'Thought:'"
+)
+MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE = (
+ "Invalid Format: Missing 'Action Input:' after 'Action:'"
+)
+FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
+ "Parsing LLM output produced both a final answer and parse-able actions"
+)
+
+
+class ReActMultiInputOutputParser(AgentOutputParser):
+ """Parses ReAct-style output that may contain multiple tool calls."""
+
+ def get_format_instructions(self) -> str:
+ # You can reuse the original FORMAT_INSTRUCTIONS,
+ # but let the model know it may emit multiple actions.
+ return FORMAT_INSTRUCTIONS + (
+ "\n\nIf you need to call more than one tool, simply repeat:\n"
+ "Action: \n"
+ "Action Input: \n"
+ "โฆfor each tool in sequence."
+ )
+
+ @property
+ def _type(self) -> str:
+ return "react-multi-input"
+
+ def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
+ includes_answer = FINAL_ANSWER_ACTION in text
+ print('-------------------')
+ print(text)
+ print('-------------------')
+ # Grab every Action / Action Input block
+ pattern = (
+ r"Action\s*\d*\s*:[\s]*(.*?)\s*"
+ r"Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*?)(?=(?:Action\s*\d*\s*:|$))"
+ )
+ matches = list(re.finditer(pattern, text, re.DOTALL))
+
+ # If we found tool callsโฆ
+ if matches:
+ if includes_answer:
+ # both a final answer *and* tool calls is ambiguous
+ raise OutputParserException(
+ f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
+ )
+
+ actions: List[AgentAction] = []
+ for m in matches:
+ tool_name = m.group(1).strip()
+ tool_input = m.group(2).strip().strip('"')
+ print('\n--------------------------')
+ print(tool_input)
+ print('--------------------------')
+ actions.append(AgentAction(tool_name, json.loads(tool_input), text))
+
+ return actions
+
+ # Otherwise, if there's a final answer, finish
+ if includes_answer:
+ answer = text.split(FINAL_ANSWER_ACTION, 1)[1].strip()
+ return AgentFinish({"output": answer}, text)
+
+ # No calls and no final answer โ figure out which error to throw
+ if not re.search(r"Action\s*\d*\s*:", text):
+ raise OutputParserException(
+ f"Could not parse LLM output: `{text}`",
+ observation=MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
+ llm_output=text,
+ send_to_llm=True,
+ )
+ if not re.search(r"Action\s*\d*\s*Input\s*\d*:", text):
+ raise OutputParserException(
+ f"Could not parse LLM output: `{text}`",
+ observation=MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
+ llm_output=text,
+ send_to_llm=True,
+ )
+
+ # Fallback
+ raise OutputParserException(f"Could not parse LLM output: `{text}`")
+
+
+# Set up basic logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
+
+def create_agent_prompt(tools: List[tool]) -> ChatPromptTemplate:
+ """Create the prompt template for the causal inference agent, emphasizing workflow and data handoff.
+ (This is the version required by the LCEL agent structure below)
+ """
+ # Get the tool descriptions
+ tool_description = render_text_description(tools)
+ tool_names = ", ".join([t.name for t in tools])
+
+ # Define the system prompt template string
+ system_template = """
+You are a causal inference expert helping users answer causal questions by following a strict workflow using specialized tools.
+
+TOOLS:
+------
+You have access to the following tools:
+
+{tools}
+
+To use a tool, please use the following format:
+
+```
+Thought: Do I need to use a tool? Yes
+Action: the action to take, should be one of [{tool_names}]
+Action Input: the input to the action, as a single, valid JSON object string. Check the tool definition for required arguments and structure.
+Observation: the result of the action, often containing structured data like 'variables', 'dataset_analysis', 'method_info', etc.
+```
+
+When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
+
+```
+Thought: Do I need to use a tool? No
+Final Answer: [your response here]
+```
+
+DO NOT UNDER ANY CIRCUMSTANCE CALL MORE THAN ONE TOOL IN A STEP
+
+**IMPORTANT TOOL USAGE:**
+1. **Action Input Format:** The value for 'Action Input' MUST be a single, valid JSON object string. Do NOT include any other text or formatting around the JSON string.
+2. **Argument Gathering:** You MUST gather ALL required arguments for the Action Input JSON from the initial Human input AND the 'Observation' outputs of PREVIOUS steps. Look carefully at the required arguments for the tool you are calling.
+3. **Data Handoff:** The 'Observation' from a previous step often contains structured data needed by the next tool. For example, the 'variables' output from `query_interpreter_tool` contains fields like `treatment_variable`, `outcome_variable`, `covariates`, `time_variable`, `instrument_variable`, `running_variable`, `cutoff_value`, and `is_rct`. When calling `method_selector_tool`, you MUST construct its required `variables` input argument by including **ALL** these relevant fields identified by the `query_interpreter_tool` in the previous Observation. Similarly, pass the full `dataset_analysis`, `dataset_description`, and `original_query` when required by the next tool.
+
+IMPORTANT WORKFLOW:
+-------------------
+You must follow this exact workflow, selecting the appropriate tool for each step:
+
+1. ALWAYS start with `input_parser_tool` to understand the query
+2. THEN use `dataset_analyzer_tool` to analyze the dataset
+3. THEN use `query_interpreter_tool` to identify variables (output includes `variables` and `dataset_analysis`)
+4. THEN use `method_selector_tool` (input requires `variables` and `dataset_analysis` from previous step)
+5. THEN use `method_validator_tool` (input requires `method_info` and `variables` from previous step)
+6. THEN use `method_executor_tool` (input requires `method`, `variables`, `dataset_path`)
+7. THEN use `explanation_generator_tool` (input requires results, method_info, variables, etc.)
+8. FINALLY use `output_formatter_tool` to return the results
+
+REASONING PROCESS:
+------------------
+EXPLICITLY REASON about:
+1. What step you're currently on (based on previous tool's Observation)
+2. Why you're selecting a particular tool (should follow the workflow)
+3. How the output of the previous tool (especially structured data like `variables`, `dataset_analysis`, `method_info`) informs the inputs required for the current tool.
+
+IMPORTANT RULES:
+1. Do not make more than one tool call in a single step.
+2. Do not include ``` in your output at all.
+3. Don't use action names like default_api.dataset_analyzer_tool, instead use tool names like dataset_analyzer_tool.
+4. Always start, action, and observation with a new line.
+5. Don't use '\\' before double quotes
+6. Don't include ```json for Action Input
+Begin!
+"""
+
+ # Create the prompt template
+ prompt = ChatPromptTemplate.from_messages([
+ ("system", system_template),
+ MessagesPlaceholder("chat_history", optional=True), # Use MessagesPlaceholder
+ # MessagesPlaceholder("agent_scratchpad"),
+
+ ("human", "{input}\n Thought:{agent_scratchpad}"),
+ # ("ai", "{agent_scratchpad}"),
+ # MessagesPlaceholder("agent_scratchpad" ), # Use MessagesPlaceholder
+ # "agent_scratchpad"
+ ])
+ return prompt
+
+def create_causal_agent(llm: BaseChatModel) -> AgentExecutor:
+ """
+ Create and configure the LangChain agent with causal inference tools.
+ (Using explicit LCEL construction, compatible with shared LLM client)
+ """
+ # Define tools available to the agent
+ agent_tools = [
+ input_parser_tool,
+ dataset_analyzer_tool,
+ query_interpreter_tool,
+ method_selector_tool,
+ method_validator_tool,
+ method_executor_tool,
+ explanation_generator_tool,
+ output_formatter_tool
+ ]
+ # anthropic_agent_tools = [ convert_to_anthropic_tool(anthropic_tool) for anthropic_tool in agent_tools]
+ # Create the prompt using the helper
+ prompt = create_agent_prompt(agent_tools)
+
+ # Bind tools to the LLM (using the passed shared instance)
+ llm_with_tools = llm.bind_tools(agent_tools)
+
+ # Create memory
+ # Consider if memory needs to be passed in or created here
+ memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+
+ # Manually construct the agent runnable using LCEL
+ from langchain_anthropic.output_parsers import ToolsOutputParser
+ from langchain.agents.output_parsers.json import JSONAgentOutputParser
+ # from langchain.agents.react.output_parser import MultiActionAgentOutputParsers ReActMultiInputOutputParser
+ agent = create_react_agent(llm_with_tools, agent_tools, prompt, output_parser=ReActMultiInputOutputParser())
+
+ # Create executor (should now work with the manually constructed agent)
+ executor = AgentExecutor(
+ agent=agent,
+ tools=agent_tools,
+ memory=memory, # Pass the memory object
+ verbose=True,
+ callbacks=[ConsoleCallbackHandler()], # Optional: for console debugging
+ handle_parsing_errors=True, # Let AE handle parsing errors
+ max_retries = 100
+ )
+
+ return executor
+
+def run_causal_analysis(query: str, dataset_path: str,
+ dataset_description: Optional[str] = None,
+ api_key: Optional[str] = None) -> Dict[str, Any]:
+ """
+ Run causal analysis on a dataset based on a user query.
+
+ Args:
+ query: User's causal question
+ dataset_path: Path to the dataset
+ dataset_description: Optional textual description of the dataset
+ api_key: Optional OpenAI API key (DEPRECATED - will be ignored)
+
+ Returns:
+ Dictionary containing the final formatted analysis results from the agent's last step.
+ """
+ # Log the start of the analysis
+ logger.info("Starting causal analysis run...")
+
+ try:
+ # --- Instantiate the shared LLM client ---
+ shared_llm = get_llm_client(temperature=0) # Or read provider/model from env
+
+
+ # --- Create agent using the shared LLM ---
+ agent_executor = create_causal_agent(shared_llm)
+
+ # Construct input, including description if available
+ # IMPORTANT: Agent now expects 'input' and potentially 'chat_history'
+ # The input needs to contain all initial info the first tool might need.
+ initial_input_dict = {
+ "query": query,
+ "dataset_path": dataset_path,
+ "dataset_description": dataset_description
+ }
+ # Maybe format this into a single input string if the prompt expects {input}
+ input_text = f"My question is: {query}\n"
+ input_text += f"The dataset is located at: {dataset_path}\n"
+ if dataset_description:
+ input_text += f"Dataset Description: {dataset_description}\n"
+ input_text += "Please perform the causal analysis following the workflow."
+
+ # Log the constructed input text
+ logger.info(f"Constructed input for agent: \n{input_text}")
+
+ result = agent_executor.invoke({
+ "input": input_text,
+})
+
+
+ # AgentExecutor returns dict. Extract the final output dictionary.
+ logger.info("Causal analysis run finished.")
+
+ # Ensure result is a dict and extract the 'output' part
+ if isinstance(result, dict):
+ final_output = result.get("output")
+ if isinstance(final_output, dict):
+ return final_output # Return only the dictionary from the final tool
+ else:
+ logger.error(f"Agent result['output'] was not a dictionary: {type(final_output)}. Returning error dict.")
+ return {"error": "Agent did not produce the expected dictionary output in the 'output' key.", "raw_agent_result": result}
+ else:
+ logger.error(f"Agent returned non-dict type: {type(result)}. Returning error dict.")
+ return {"error": "Agent did not return expected dictionary output.", "raw_output": str(result)}
+
+ except ValueError as e:
+ logger.error(f"Configuration Error: {e}")
+ # Return an error dictionary in case of exception too
+ return {"error": f"Error: Configuration issue - {e}"} # Ensure consistent error return type
+ except Exception as e:
+ logger.error(f"An unexpected error occurred during causal analysis: {e}", exc_info=True)
+ # Return an error dictionary in case of exception too
+ return {"error": f"An unexpected error occurred: {e}"}
\ No newline at end of file
diff --git a/cais/utils/llm_helpers.py b/cais/utils/llm_helpers.py
index 4b4fb41..2b3a66f 100644
--- a/cais/utils/llm_helpers.py
+++ b/cais/utils/llm_helpers.py
@@ -7,11 +7,30 @@
import pandas as pd
import logging
import json
-from langchain.chat_models.base import BaseChatModel
-from langchain_core.messages import AIMessage
+from langchain_core.language_models import BaseChatModel
logger = logging.getLogger(__name__)
+def invoke_llm(llm: BaseChatModel, prompt: str) -> str:
+ """
+ Call the provided LLM with a prompt and return the text content of the response.
+
+ Args:
+ llm: An instance of BaseChatModel.
+ prompt: The prompt string to send to the LLM.
+
+ Returns:
+ The string content of the LLM response.
+ """
+ if not llm:
+ logger.warning("LLM client not provided to invoke_llm. Cannot make LLM call.")
+ return ""
+
+ response = llm.invoke(prompt)
+ if hasattr(response, "content"):
+ return response.content
+ return str(response)
+
def call_llm_with_json_output(llm: Optional[BaseChatModel], prompt: str) -> Optional[Dict[str, Any]]:
"""
Calls the provided LLM with a prompt, expecting a JSON object in the response.
diff --git a/tests/cais/methods/test_diff_in_diff.py b/tests/cais/methods/test_diff_in_diff.py
new file mode 100644
index 0000000..090e47c
--- /dev/null
+++ b/tests/cais/methods/test_diff_in_diff.py
@@ -0,0 +1,81 @@
+import unittest
+import pandas as pd
+import numpy as np
+from unittest.mock import patch, MagicMock
+
+# Import the function to test
+from cais.methods.diff_in_diff import estimate_effect
+
+class TestDifferenceInDifferences(unittest.TestCase):
+
+ def setUp(self):
+ '''Set up dummy panel data for testing.'''
+ # Simple 2 groups, 2 periods example
+ self.df = pd.DataFrame({
+ 'unit': [1, 1, 2, 2, 3, 3, 4, 4], # 2 treated (1,2), 2 control (3,4)
+ 'time': [0, 1, 0, 1, 0, 1, 0, 1],
+ 'treatment_group': [1, 1, 1, 1, 0, 0, 0, 0], # Group indicator
+ 'outcome': [10, 12, 11, 14, 9, 9.5, 10, 10.5], # Treated increase more in period 1
+ 'covariate1': [1, 1, 2, 2, 1, 1, 2, 2]
+ })
+ self.treatment = 'treatment_group' # This identifies the group
+ self.outcome = 'outcome'
+ self.covariates = ['covariate1']
+ self.time_var = 'time'
+ self.group_var = 'unit'
+
+ # Mock all helper/validation functions within diff_in_diff.py
+ @patch('cais.methods.diff_in_diff.identify_time_variable')
+ @patch('cais.methods.diff_in_diff.identify_treatment_group')
+ @patch('cais.methods.diff_in_diff.determine_treatment_period')
+ @patch('cais.methods.diff_in_diff.validate_parallel_trends')
+ # Mock estimate_did_model to avoid actual regression, return mock results
+ @patch('cais.methods.diff_in_diff.estimate_did_model')
+ def test_estimate_effect_structure_and_types(self, mock_estimate_model, mock_validate_trends,
+ mock_determine_period, mock_identify_group, mock_identify_time):
+ '''Test the basic structure and types of the DiD estimate_effect output.'''
+ # Configure mocks
+ mock_identify_time.return_value = self.time_var
+ mock_identify_group.return_value = self.group_var
+ mock_determine_period.return_value = 1 # Assume treatment starts at time 1
+ mock_validate_trends.return_value = {"valid": True, "p_value": 0.9}
+
+ # Mock the statsmodels result object
+ mock_model_results = MagicMock()
+ # Define the interaction term based on how construct_did_formula names it
+ # Assuming treatment='treatment_group', post='post'
+ interaction_term = f"{self.treatment}_x_post"
+ mock_model_results.params = {interaction_term: 2.5, 'Intercept': 10.0}
+ mock_model_results.bse = {interaction_term: 0.5, 'Intercept': 0.2}
+ mock_model_results.pvalues = {interaction_term: 0.01, 'Intercept': 0.001}
+ # Mock the summary() method if format_did_results uses it
+ mock_model_results.summary.return_value = "Mocked Model Summary"
+ mock_estimate_model.return_value = mock_model_results
+
+ # Call the function (passing explicit vars to bypass internal identification mocks if desired)
+ result = estimate_effect(self.df, self.treatment, self.outcome, self.covariates,
+ time_var=self.time_var, group_var=self.group_var, query="Test query")
+
+ # Assertions
+ self.assertIsInstance(result, dict)
+ expected_keys = ["effect_estimate", "effect_se", "confidence_interval", "p_value",
+ "diagnostics", "method_details", "parameters", "model_summary"]
+ for key in expected_keys:
+ self.assertIn(key, result, f"Key '{key}' missing from result")
+
+ self.assertEqual(result["method_details"], "DiD.TWFE")
+ self.assertIsInstance(result["effect_estimate"], float)
+ self.assertIsInstance(result["effect_se"], float)
+ self.assertIsInstance(result["confidence_interval"], list)
+ self.assertEqual(len(result["confidence_interval"]), 2)
+ self.assertIsInstance(result["diagnostics"], dict)
+ self.assertIsInstance(result["parameters"], dict)
+ self.assertIn("time_var", result["parameters"])
+ self.assertIn("group_var", result["parameters"])
+ self.assertIn("interaction_term", result["parameters"])
+ self.assertEqual(result["parameters"]["interaction_term"], interaction_term)
+ self.assertIn("valid", result["diagnostics"])
+ self.assertIn("model_summary", result)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/cais/test_e2e_iv_new.py b/tests/cais/test_e2e_iv_new.py
new file mode 100644
index 0000000..eb573bc
--- /dev/null
+++ b/tests/cais/test_e2e_iv_new.py
@@ -0,0 +1,45 @@
+import unittest
+import json
+import os
+import sys
+
+ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if ROOT not in sys.path:
+ sys.path.insert(0, ROOT)
+
+from cais.agent import CausalAgent
+
+
+class TestE2EIVNewPipeline(unittest.TestCase):
+ def test_iv_llm_pipeline_app_engagement_push(self):
+ """Run the full CAIS pipeline end-to-end (real API calls, no mocks)."""
+
+ # --- Scenario ---
+ query = "What is the effect of education on earnings??"
+ dataset_path = "data/all_data/card_geographic.csv"
+ dataset_description = (
+ """The National Longitudinal Survey of Young Men (NLSYM) was conducted to collect data on demographics, education, and employment outcomes. Participants were tracked over time to study long-term patterns. The dataset used here comes from the 1976 wave of the survey. Variables include: lwage: log of wages educ: years of education exper: years of work experience black: 1 if the individual is Black, 0 otherwise south: 1 if the individual lives in a southern state, 0 otherwise married: 1 if married, 0 otherwise smsa: 1 if living in a metropolitan area, 0 otherwise nearc4: 1 if there is a four-year college in the county, 0 otherwise"""
+ )
+
+ print("--- Running E2E Test Output ---")
+ agent = CausalAgent(
+ dataset_path=dataset_path,
+ dataset_description=dataset_description,
+ )
+ output = agent.run_analysis(
+ query=query,
+ )
+ print(json.dumps(output, indent=2, default=str))
+ print("-----------------------------------------------------")
+
+ # --- Assertions ---
+ self.assertIsNotNone(output, "Agent returned None output.")
+ self.assertIsInstance(output, dict, "Agent output is not a dict.")
+ self.assertNotIn("error", output, f"Agent returned error: {output.get('error')}")
+ self.assertIn("results", output)
+ self.assertIn("explanation", output)
+ self.assertIn("instrument", json.dumps(output).lower())
+
+
+if __name__ == "__main__":
+ unittest.main()