-
Notifications
You must be signed in to change notification settings - Fork 10
IV LLM refactoring #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: refactoring
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |||||||
| from cais.tools.input_parser_tool import input_parser_tool | ||||||||
| from cais.tools.dataset_analyzer_tool import dataset_analyzer_tool | ||||||||
| from cais.tools.query_interpreter_tool import query_interpreter_tool | ||||||||
| from cais.tools.iv_discovery_tool import iv_discovery_tool | ||||||||
| from cais.tools.method_selector_tool import method_selector_tool | ||||||||
| from cais.tools.controls_selector_tool import controls_selector_tool | ||||||||
| from cais.tools.method_validator_tool import method_validator_tool | ||||||||
|
|
@@ -49,6 +50,11 @@ | |||||||
|
|
||||||||
| # Set up basic logging | ||||||||
| os.makedirs('./logs/', exist_ok=True) | ||||||||
| logging.basicConfig( | ||||||||
| filename='./logs/agent_debug.log', | ||||||||
| level=logging.INFO, | ||||||||
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||||||||
| ) | ||||||||
| logger = logging.getLogger(__name__) | ||||||||
|
|
||||||||
|
|
||||||||
|
|
@@ -153,6 +159,23 @@ def select_method(self, query=None, llm_decision=True): | |||||||
| self.selected_method = self.method_info.selected_method | ||||||||
| return self.selected_method | ||||||||
|
|
||||||||
| def discover_instruments(self, query=None): | ||||||||
| query = self.checkq(query) | ||||||||
|
|
||||||||
| iv_discovery_output = iv_discovery_tool.func( | ||||||||
| variables=self.variables, | ||||||||
| dataset_analysis=self.dataset_analysis, | ||||||||
| dataset_description=self.dataset_description, | ||||||||
| original_query=query | ||||||||
|
||||||||
| original_query=query | |
| original_query=query, | |
| llm=self.llm, |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| import logging | ||
| from typing import Dict, List, Any, Optional | ||
|
|
||
| from cais.config import get_llm_client | ||
| from cais.iv_llm.src.agents.hypothesizer import Hypothesizer | ||
| from cais.iv_llm.src.agents.confounder_miner import ConfounderMiner | ||
| from cais.iv_llm.src.critics.exclusion_critic import ExclusionCritic | ||
| from cais.iv_llm.src.critics.independence_critic import IndependenceCritic | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| class IVDiscovery: | ||
| def __init__(self): | ||
| llm = get_llm_client() | ||
| self.hypothesizer = Hypothesizer(llm, k=5) | ||
| self.confounder_miner = ConfounderMiner(llm, j=5) | ||
| self.exclusion_critic = ExclusionCritic(llm) | ||
| self.independence_critic = IndependenceCritic(llm) | ||
|
Comment on lines
+13
to
+19
|
||
|
|
||
| def discover_instruments(self, treatment: str, outcome: str, context: str = "", confounders: Optional[List[str]] = None) -> Dict[str, Any]: | ||
| """ | ||
| Discover valid instrumental variables for the given treatment and outcome. | ||
|
|
||
| Args: | ||
| treatment: Name of the treatment variable | ||
| outcome: Name of the outcome variable | ||
| context: Additional context about the dataset/query | ||
| confounders: List of known confounders (optional) | ||
|
|
||
| Returns: | ||
| Dict containing proposed IVs, valid IVs, and validation results | ||
| """ | ||
| logger.info(f"Discovering instruments for treatment: {treatment}, outcome: {outcome}") | ||
|
|
||
| # Step 1: Hypothesize IVs | ||
| proposed_ivs = self.hypothesizer.propose_ivs(treatment, outcome, context=context) | ||
| logger.info(f"Proposed IVs: {proposed_ivs}") | ||
|
|
||
| if not proposed_ivs: | ||
| return { | ||
| 'proposed_ivs': [], | ||
| 'valid_ivs': [], | ||
| 'validation_results': [], | ||
| 'confounders': confounders or [] | ||
| } | ||
|
|
||
| # Step 2: Identify confounders if not provided | ||
| if confounders is None: | ||
| confounders = self.confounder_miner.identify_confounders(treatment, outcome, context=context) | ||
| logger.info(f"Identified confounders: {confounders}") | ||
|
|
||
| # Step 3: Validate IVs with critics | ||
| validation_results = [] | ||
| valid_ivs = [] | ||
|
|
||
| # Run exclusion and independence critics | ||
| exclusion_results = {} | ||
| independence_results = {} | ||
|
|
||
| # First pass: exclusion critic for all IVs | ||
| for iv in proposed_ivs: | ||
| exclusion_results[iv] = self.exclusion_critic.validate_exclusion(iv, treatment, outcome, confounders) | ||
|
|
||
| # Second pass: independence critic for all IVs | ||
| for iv in proposed_ivs: | ||
| independence_results[iv] = self.independence_critic.validate_independence(iv, treatment, outcome, confounders) | ||
|
|
||
| # Combine results | ||
| for iv in proposed_ivs: | ||
| exclusion_valid = exclusion_results[iv] | ||
| independence_valid = independence_results[iv] | ||
|
|
||
| validation_results.append({ | ||
| 'iv': iv, | ||
| 'exclusion_valid': exclusion_valid, | ||
| 'independence_valid': independence_valid, | ||
| 'overall_valid': exclusion_valid and independence_valid | ||
| }) | ||
|
|
||
| if exclusion_valid and independence_valid: | ||
| valid_ivs.append(iv) | ||
|
|
||
| logger.info(f"Valid IVs found: {valid_ivs}") | ||
|
|
||
| return { | ||
| 'proposed_ivs': proposed_ivs, | ||
| 'valid_ivs': valid_ivs, | ||
| 'validation_results': validation_results, | ||
| 'confounders': confounders | ||
| } | ||
|
|
||
| def discover_instruments(treatment: str, outcome: str, context: str = "", confounders: Optional[List[str]] = None, config_path: Optional[str] = None) -> Dict[str, Any]: | ||
| """ | ||
| Convenience function to discover instruments using default IVDiscovery instance. | ||
| """ | ||
| # `config_path` is currently unused; keep it for backwards compatibility. | ||
| discovery = IVDiscovery() | ||
| return discovery.discover_instruments(treatment, outcome, context, confounders) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| # IV-LLM package | ||
|
|
||
| import logging | ||
| import os | ||
| from pathlib import Path | ||
|
|
||
|
|
||
| def _find_project_root() -> Path: | ||
| """Walk up from this file to find the directory containing pyproject.toml.""" | ||
| current = Path(__file__).resolve().parent | ||
| for parent in [current, *current.parents]: | ||
| if (parent / "pyproject.toml").exists(): | ||
| return parent | ||
| return Path.cwd() | ||
|
|
||
|
|
||
| def _get_log_path() -> Path: | ||
| env_dir = os.getenv("IV_LLM_OUTPUT_DIR") | ||
| if env_dir: | ||
| return Path(env_dir) / "iv_llm.jsonl" | ||
| return _find_project_root() / "logs" / "iv_llm.jsonl" | ||
|
|
||
|
|
||
| # Configure a file handler on the "cais.iv_llm" logger so that every child | ||
| # logger (agents, critics, etc.) automatically writes to the IV-LLM log file. | ||
| _logger = logging.getLogger(__name__) # "cais.iv_llm" | ||
| if not _logger.handlers: | ||
| _logger.setLevel(logging.DEBUG) | ||
| _logger.propagate = True # still propagate to root for console output | ||
| try: | ||
| _log_path = _get_log_path() | ||
| _log_path.parent.mkdir(parents=True, exist_ok=True) | ||
| _handler = logging.FileHandler(str(_log_path), encoding="utf-8") | ||
| _handler.setLevel(logging.INFO) | ||
| _handler.setFormatter(logging.Formatter("%(message)s")) | ||
| _logger.addHandler(_handler) | ||
| except Exception: | ||
| pass |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| # IV-LLM src package |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| # Agents package |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| import json | ||
| import logging | ||
| from typing import List | ||
|
|
||
| from langchain_core.language_models import BaseChatModel | ||
|
|
||
| from cais.utils.llm_helpers import invoke_llm | ||
| from ..prompts.prompt_loader import PromptLoader | ||
| from ..variable_utils import extract_available_columns, filter_to_available | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| class ConfounderMiner: | ||
| def __init__(self, llm: BaseChatModel, j: int = 5) -> None: | ||
| self.llm = llm | ||
| self.j = j | ||
| self.prompt_loader = PromptLoader() | ||
|
|
||
| def identify_confounders(self, treatment: str, outcome: str, context: str = "") -> List[str]: | ||
| prompt = self.prompt_loader.format_confounder_prompt(treatment, outcome, self.j, context=context) | ||
| response = invoke_llm(self.llm, prompt) | ||
| confounders_raw = self._parse_confounders(response) | ||
|
|
||
| available_cols = extract_available_columns(context) | ||
| confounders = ( | ||
| filter_to_available(confounders_raw, available_cols) | ||
| if available_cols | ||
| else confounders_raw | ||
| ) | ||
|
|
||
| confounders = confounders[: self.j] | ||
|
|
||
| logger.info(json.dumps({ | ||
| 'name': 'confounder_miner', | ||
| 'inputs': {'treatment': treatment, 'outcome': outcome, 'j': self.j}, | ||
| 'outputs': {'confounders': confounders}, | ||
| 'raw_response': response, | ||
| }, default=str)) | ||
|
|
||
| return confounders | ||
|
|
||
| def _parse_confounders(self, response: str) -> List[str]: | ||
| import re | ||
|
|
||
| def _clean(name: str) -> str: | ||
| return name.strip().strip('"\'').strip('`').strip('*').strip() | ||
|
|
||
| # Try XML format first | ||
| match = re.search(r'<Answer>\[(.*?)\]</Answer>', response) | ||
| if match: | ||
| confounders_str = match.group(1) | ||
| confounders = [_clean(c) for c in confounders_str.split(',')] | ||
| return confounders[:self.j] | ||
|
|
||
| # Fallback: look for bracket format without XML | ||
| bracket_match = re.search(r'\[([^\]]+)\]', response) | ||
| if bracket_match: | ||
| confounders_str = bracket_match.group(1) | ||
| confounders = [_clean(c) for c in confounders_str.split(',')] | ||
| return confounders[:self.j] | ||
|
|
||
| print(f"WARNING: Could not parse confounders from response: {response[:200]}...") | ||
| return [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logging.basicConfig(...)at import time reconfigures global logging for any library consumer and for the whole test suite. Prefer leaving logging configuration to the application entrypoint/CLI; here, set a module logger and emit logs without callingbasicConfig(or gate it behindif __name__ == "__main__").