diff --git a/align_system/algorithms/outlines_adm.py b/align_system/algorithms/outlines_adm.py index d7f71f69..3da8b2d0 100644 --- a/align_system/algorithms/outlines_adm.py +++ b/align_system/algorithms/outlines_adm.py @@ -16,6 +16,7 @@ CharacterTagEnum, KDMAValue ) +import ubelt as ub from align_system.utils import logging from align_system.utils import adm_utils @@ -381,6 +382,25 @@ def choose_action(self, scenario_state, available_actions, alignment_target, **k return action_to_take, choice_info def populate_action_parameters(self, scenario_state, action_to_take, dialog): + scenario_state_copy = copy.deepcopy(scenario_state) + # Don't consider the elapsed_time of the state when caching + scenario_state_copy.elapsed_time = 0 + depends = '\n'.join(( + repr(self.model.model), + repr(scenario_state_copy), + repr(action_to_take), + repr(dialog))) + + cacher = ub.Cacher('outlines_adm_populate_action_params', depends, verbose=0) + log.debug(f'cacher.fpath={cacher.fpath}') + + cached_output = cacher.tryload() + if cached_output is not None: + log.info("Cache hit for `populate_action_parameters` returning cached output") + return cached_output + else: + log.info("Cache miss for `populate_action_parameters` ..") + if action_to_take.action_type in {ActionTypeEnum.APPLY_TREATMENT, ActionTypeEnum.TAG_CHARACTER, ActionTypeEnum.CHECK_ALL_VITALS, @@ -469,7 +489,10 @@ def populate_action_parameters(self, scenario_state, action_to_take, dialog): selected_character_idx, dialog) - return action_to_take, dialog + outputs = (action_to_take, dialog) + cacher.save(outputs) + + return outputs def ensure_character_id_is_populated(self, scenario_state, diff --git a/align_system/algorithms/outlines_regression_adm_comparative.py b/align_system/algorithms/outlines_regression_adm_comparative.py index bef8e178..ac3778b9 100644 --- a/align_system/algorithms/outlines_regression_adm_comparative.py +++ b/align_system/algorithms/outlines_regression_adm_comparative.py @@ -3,11 +3,13 @@ import jinja2 import json import numpy as np +import copy import outlines from outlines.samplers import MultinomialSampler from rich.highlighter import JSONHighlighter from swagger_client.models import kdma_value +import ubelt as ub from align_system.utils import logging from align_system.utils import adm_utils @@ -130,6 +132,28 @@ def sample_relevance_predictions(self, ''' Samples prediction of the relevance of each response to each KDMA ''' + scenario_state_copy = copy.deepcopy(scenario_state) + # Don't consider the elapsed_time of the state when caching + scenario_state_copy.elapsed_time = 0 + depends = '\n'.join(( + repr(self.model.model), + repr(scenario_state_copy), + repr(scenario_description), + repr(choices), + repr([t['kdma'] for t in target_kdmas]), + repr(available_actions), + repr(incontext_settings))) + + cacher = ub.Cacher('comp_reg_relevance', depends, verbose=0) + log.debug(f'cacher.fpath={cacher.fpath}') + + cached_output = cacher.tryload() + if cached_output is not None: + log.info("Cache hit for `sample_relevance_predictions` returning cached output") + return cached_output + else: + log.info("Cache miss for `sample_relevance_predictions` ..") + use_icl = False if "number" in incontext_settings and incontext_settings["number"] > 0: use_icl = True @@ -216,7 +240,10 @@ def sample_relevance_predictions(self, else: predictions[choice][kdma_key] = 0 - return predictions, reasonings, icl_example_responses + outputs = (predictions, reasonings, icl_example_responses) + cacher.save(outputs) + + return outputs def sample_kdma_score_predictions(self, scenario_state, @@ -236,6 +263,29 @@ def sample_kdma_score_predictions(self, - predictions: {choice1:{kdma1:[score1(int), ...], ...}, ...} - reasonings: {choice1:{kdma1:[reasoning1(str), ...], ...}, ...} ''' + scenario_state_copy = copy.deepcopy(scenario_state) + # Don't consider the elapsed_time of the state when caching + scenario_state_copy.elapsed_time = 0 + depends = '\n'.join(( + repr(self.model.model), + repr(scenario_state_copy), + repr(choices), + repr(available_actions), + repr(outcome_predictions), + repr(kdma_score_examples), + repr(enum_scores), + repr(incontext_settings))) + + cacher = ub.Cacher('comp_reg_kdma_estimation', depends, verbose=0) + log.debug(f'cacher.fpath={cacher.fpath}') + + cached_output = cacher.tryload() + if cached_output is not None: + log.info("Cache hit for `sample_kdma_score_predictions` returning cached output") + return cached_output + else: + log.info("Cache miss for `sample_kdma_score_predictions` ..") + use_icl = False if "number" in incontext_settings and incontext_settings["number"] > 0: use_icl = True @@ -346,7 +396,10 @@ def sample_kdma_score_predictions(self, # Scale score to be between 0 and 1 to match targets predictions[choice][kdma_key].append(kdma_prediction[choice]['score'] / kdma_factor) - return predictions, reasonings, icl_example_responses + outputs = (predictions, reasonings, icl_example_responses) + cacher.save(outputs) + + return outputs # Returns the outcome prediction (if there was one) and score reasoning for the best sample of the selected choice def get_selected_choice_reasoning(self, selected_choice, best_sample_index, outcome_predictions, reasonings, relevance_reasonings=None): diff --git a/align_system/configs/experiment/multi_kdma_evaluation/comp_reg_adept_eval_live.yaml b/align_system/configs/experiment/multi_kdma_evaluation/comp_reg_adept_eval_live.yaml new file mode 100644 index 00000000..0d5d8586 --- /dev/null +++ b/align_system/configs/experiment/multi_kdma_evaluation/comp_reg_adept_eval_live.yaml @@ -0,0 +1,39 @@ +# @package _global_ +defaults: + - override /adm: outlines_regression_aligned_comparative/incontext_phase1 + - override /interface: ta3 + +interface: + api_endpoint: "https://darpaitm.caci.com" + session_type: adept + training_session: null + username: "ALIGN-ADM-ComparativeRegression-ADEPT" + +adm: + instance: + precision: half + model_name: mistralai/Mistral-7B-Instruct-v0.3 + sampler: + _target_: outlines.samplers.GreedySampler + inference_kwargs: + distribution_matching: average # no rel + predict_relevance: false # no rel + kdma_score_examples: true + num_samples: 1 + predict_outcomes: false + generator_batch_size: 5 + incontext: + method: prompt_bert_similarity + sort_actions: true + normalization: null + number: 5 + leave_one_out_strategy: null + most_similar_first: false + +force_determinism: true +align_to_target: true +save_last_unstructured_state_per_scenario: true + +hydra: + run: + dir: 'multi_experiment_live/ALIGN-ADM-ComparativeRegression-Mistral-7B-Instruct-v0.3-ADEPT/${now:%Y-%m-%d__%H-%M-%S}' diff --git a/align_system/configs/experiment/multi_kdma_evaluation/relevance_comp_reg_adept_eval_live.yaml b/align_system/configs/experiment/multi_kdma_evaluation/relevance_comp_reg_adept_eval_live.yaml new file mode 100644 index 00000000..8888d69c --- /dev/null +++ b/align_system/configs/experiment/multi_kdma_evaluation/relevance_comp_reg_adept_eval_live.yaml @@ -0,0 +1,39 @@ +# @package _global_ +defaults: + - override /adm: outlines_regression_aligned_comparative/incontext_phase1 + - override /interface: ta3 + +interface: + api_endpoint: "https://darpaitm.caci.com" + session_type: adept + training_session: null + username: "ALIGN-ADM-RelevanceComparativeRegression-ADEPT" + +adm: + instance: + precision: half + model_name: mistralai/Mistral-7B-Instruct-v0.3 + sampler: + _target_: outlines.samplers.GreedySampler + inference_kwargs: + distribution_matching: relevance_average # use rel + predict_relevance: true # use rel + kdma_score_examples: true + num_samples: 1 + predict_outcomes: false + generator_batch_size: 5 + incontext: + method: prompt_bert_similarity + sort_actions: true + normalization: null + number: 5 + leave_one_out_strategy: null + most_similar_first: false + +force_determinism: true +align_to_target: true +save_last_unstructured_state_per_scenario: true + +hydra: + run: + dir: 'multi_experiment_live/ALIGN-ADM-RelevanceComparativeRegression-Mistral-7B-Instruct-v0.3-ADEPT/${now:%Y-%m-%d__%H-%M-%S}' diff --git a/poetry.lock b/poetry.lock index 39ca7b06..bf0dc7be 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3768,6 +3768,29 @@ files = [ {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, ] +[[package]] +name = "ubelt" +version = "1.3.6" +description = "A Python utility belt containing simple tools, a stdlib like feel, and extra batteries" +optional = false +python-versions = ">=3.6" +files = [ + {file = "ubelt-1.3.6-py3-none-any.whl", hash = "sha256:2a38f260e7f3c25d3618f653d5c900230dc5af56bf0bc1ff85cfdbe9d0c88f0f"}, + {file = "ubelt-1.3.6.tar.gz", hash = "sha256:327a516a1fc95595096727ae3ae879379bc56fc11fb945857b971ef85a74f698"}, +] + +[package.extras] +all = ["Pygments (>=2.2.0)", "colorama (>=0.4.3)", "coverage (>=4.3.4)", "coverage (>=4.5)", "coverage (>=5.3.1)", "coverage (>=5.3.1)", "coverage (>=5.3.1)", "coverage (>=6.1.1)", "coverage (>=6.1.1)", "coverage (>=6.1.1)", "coverage (>=6.1.1)", "coverage (>=7.3.0)", "jaraco.windows (>=3.9.1)", "numpy (>=1.12.0,<2.0.0)", "numpy (>=1.14.5,<2.0.0)", "numpy (>=1.19.2)", "numpy (>=1.19.3)", "numpy (>=1.21.1)", "numpy (>=1.23.2)", "numpy (>=1.26.0)", "packaging (>=21.0)", "pydantic (<2.0)", "pytest (>=4.6.0)", "pytest (>=4.6.0)", "pytest (>=4.6.0,<=4.6.11)", "pytest (>=4.6.0,<=4.6.11)", "pytest (>=4.6.0,<=6.1.2)", "pytest (>=6.2.5)", "pytest (>=8.1.1)", "pytest (>=8.1.1)", "pytest (>=8.1.1)", "pytest-cov (>=2.8.1)", "pytest-cov (>=2.8.1)", "pytest-cov (>=2.9.0)", "pytest-cov (>=3.0.0)", "pytest-timeout (>=1.4.2)", "pytest-timeout (>=2.3.1)", "python-dateutil (>=2.8.1)", "requests (>=2.25.1)", "xdoctest (>=1.1.3)", "xxhash (>=1.3.0)", "xxhash (>=1.3.0)", "xxhash (>=1.4.3)", "xxhash (>=2.0.2)", "xxhash (>=3.0.0)", "xxhash (>=3.2.0)", "xxhash (>=3.4.1)"] +all-strict = ["Pygments (==2.2.0)", "colorama (==0.4.3)", "coverage (==4.3.4)", "coverage (==4.5)", "coverage (==5.3.1)", "coverage (==5.3.1)", "coverage (==5.3.1)", "coverage (==6.1.1)", "coverage (==6.1.1)", "coverage (==6.1.1)", "coverage (==6.1.1)", "coverage (==7.3.0)", "jaraco.windows (==3.9.1)", "numpy (==1.12.0)", "numpy (==1.14.5)", "numpy (==1.19.2)", "numpy (==1.19.3)", "numpy (==1.21.1)", "numpy (==1.23.2)", "numpy (==1.26.0)", "packaging (==21.0)", "pydantic (<2.0)", "pytest (==4.6.0)", "pytest (==4.6.0)", "pytest (==4.6.0)", "pytest (==4.6.0)", "pytest (==4.6.0)", "pytest (==6.2.5)", "pytest (==8.1.1)", "pytest (==8.1.1)", "pytest (==8.1.1)", "pytest-cov (==2.8.1)", "pytest-cov (==2.8.1)", "pytest-cov (==2.9.0)", "pytest-cov (==3.0.0)", "pytest-timeout (==1.4.2)", "pytest-timeout (==2.3.1)", "python-dateutil (==2.8.1)", "requests (==2.25.1)", "xdoctest (==1.1.3)", "xxhash (==1.3.0)", "xxhash (==1.3.0)", "xxhash (==1.4.3)", "xxhash (==2.0.2)", "xxhash (==3.0.0)", "xxhash (==3.2.0)", "xxhash (==3.4.1)"] +docs = ["Pygments (>=2.9.0)", "myst-parser (>=0.16.1)", "sphinx (>=4.3.2)", "sphinx-autoapi (>=1.8.4)", "sphinx-autobuild (>=2021.3.14)", "sphinx-reredirects (>=0.0.1)", "sphinx-rtd-theme (>=1.0.0)", "sphinxcontrib-napoleon (>=0.7)"] +docs-strict = ["Pygments (==2.9.0)", "myst-parser (==0.16.1)", "sphinx (==4.3.2)", "sphinx-autoapi (==1.8.4)", "sphinx-autobuild (==2021.3.14)", "sphinx-reredirects (==0.0.1)", "sphinx-rtd-theme (==1.0.0)", "sphinxcontrib-napoleon (==0.7)"] +optional = ["Pygments (>=2.2.0)", "colorama (>=0.4.3)", "jaraco.windows (>=3.9.1)", "numpy (>=1.12.0,<2.0.0)", "numpy (>=1.14.5,<2.0.0)", "numpy (>=1.19.2)", "numpy (>=1.19.3)", "numpy (>=1.21.1)", "numpy (>=1.23.2)", "numpy (>=1.26.0)", "packaging (>=21.0)", "pydantic (<2.0)", "python-dateutil (>=2.8.1)", "xxhash (>=1.3.0)", "xxhash (>=1.3.0)", "xxhash (>=1.4.3)", "xxhash (>=2.0.2)", "xxhash (>=3.0.0)", "xxhash (>=3.2.0)", "xxhash (>=3.4.1)"] +optional-strict = ["Pygments (==2.2.0)", "colorama (==0.4.3)", "jaraco.windows (==3.9.1)", "numpy (==1.12.0)", "numpy (==1.14.5)", "numpy (==1.19.2)", "numpy (==1.19.3)", "numpy (==1.21.1)", "numpy (==1.23.2)", "numpy (==1.26.0)", "packaging (==21.0)", "pydantic (<2.0)", "python-dateutil (==2.8.1)", "xxhash (==1.3.0)", "xxhash (==1.3.0)", "xxhash (==1.4.3)", "xxhash (==2.0.2)", "xxhash (==3.0.0)", "xxhash (==3.2.0)", "xxhash (==3.4.1)"] +tests = ["coverage (>=4.3.4)", "coverage (>=4.5)", "coverage (>=5.3.1)", "coverage (>=5.3.1)", "coverage (>=5.3.1)", "coverage (>=6.1.1)", "coverage (>=6.1.1)", "coverage (>=6.1.1)", "coverage (>=6.1.1)", "coverage (>=7.3.0)", "pytest (>=4.6.0)", "pytest (>=4.6.0)", "pytest (>=4.6.0,<=4.6.11)", "pytest (>=4.6.0,<=4.6.11)", "pytest (>=4.6.0,<=6.1.2)", "pytest (>=6.2.5)", "pytest (>=8.1.1)", "pytest (>=8.1.1)", "pytest (>=8.1.1)", "pytest-cov (>=2.8.1)", "pytest-cov (>=2.8.1)", "pytest-cov (>=2.9.0)", "pytest-cov (>=3.0.0)", "pytest-timeout (>=1.4.2)", "pytest-timeout (>=2.3.1)", "requests (>=2.25.1)", "xdoctest (>=1.1.3)"] +tests-strict = ["coverage (==4.3.4)", "coverage (==4.5)", "coverage (==5.3.1)", "coverage (==5.3.1)", "coverage (==5.3.1)", "coverage (==6.1.1)", "coverage (==6.1.1)", "coverage (==6.1.1)", "coverage (==6.1.1)", "coverage (==7.3.0)", "pytest (==4.6.0)", "pytest (==4.6.0)", "pytest (==4.6.0)", "pytest (==4.6.0)", "pytest (==4.6.0)", "pytest (==6.2.5)", "pytest (==8.1.1)", "pytest (==8.1.1)", "pytest (==8.1.1)", "pytest-cov (==2.8.1)", "pytest-cov (==2.8.1)", "pytest-cov (==2.9.0)", "pytest-cov (==3.0.0)", "pytest-timeout (==1.4.2)", "pytest-timeout (==2.3.1)", "requests (==2.25.1)", "xdoctest (==1.1.3)"] +types = ["autoflake (>=1.4)", "mypy", "yapf (>=0.32.0)"] +types-strict = ["autoflake (==1.4)", "mypy", "yapf (==0.32.0)"] + [[package]] name = "urllib3" version = "1.26.18" @@ -4022,4 +4045,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "4ab215174a245b35e1e6af8887fe842a994ce2ba0dfc8f11b185d8b8035d13ae" +content-hash = "5bdc23cc1eca2fa431dde1a047176b5527b2ca2aae5bc29a433b6e8a370f47c9" diff --git a/pyproject.toml b/pyproject.toml index 26748789..e40bdc33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ outlines = "^0.0.46" setuptools = "^70.1.1" sentencepiece = "^0.2.0" protobuf = "^5.28.3" +ubelt = "1.3.6" [tool.poetry.scripts] run_align_system = 'align_system.cli.run_align_system:main'