Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions align_system/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from outlines.backends.outlines_core import OutlinesCoreBackend
from outlines.models.transformers import TransformerTokenizer
from outlines_core import Vocabulary


# monkey patch to fix https://github.com/dottxt-ai/outlines/pull/1831
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we know what version of outlines should have this fix (probably one higher version than the latest one assuming they haven't cut a release yet with this fix merged) we should note it here to make it easier on our future selves.

# fix was applied to outlines main, but we will probably be blocked from updating due to vllm dependency
# assuming this will be in official release >1.1.12
@staticmethod
def deterministic_create_vocab(vocab, eos_token_id, eos_token, token_to_str):
    """Build an ``outlines_core.Vocabulary`` with deterministic token-id order.

    Groups the tokenizer vocabulary by decoded string form, appending token
    ids in the tokenizer's own iteration order so repeated runs produce the
    same grouping, then drops the EOS token before constructing the
    ``Vocabulary`` (raises ``KeyError`` if ``eos_token`` is absent, matching
    the upstream behavior).
    """
    str_to_ids = {}
    for raw_token, token_id in vocab.items():
        key = token_to_str(raw_token)
        if key in str_to_ids:
            str_to_ids[key].append(token_id)
        else:
            str_to_ids[key] = [token_id]
    # The EOS token must not appear in the structured-generation vocabulary.
    del str_to_ids[eos_token]
    return Vocabulary(eos_token_id, str_to_ids)


# Install the deterministic vocabulary builder over the backend's default
# (see monkey-patch note above referencing dottxt-ai/outlines PR #1831).
OutlinesCoreBackend.create_outlines_core_vocabulary = deterministic_create_vocab


# monkey patch to fix https://github.com/dottxt-ai/outlines/pull/1817
# newer version of outlines fixes this issue (1.2.10), but we are blocked with the vllm dependency
def convert_token_to_string(self, token: str) -> str:
    """Decode a single token to text, restoring a leading space when needed.

    SentencePiece tokenizers mark a word-initial space with SPIECE_UNDERLINE
    (or the byte token ``<0x20>``); ``convert_tokens_to_string`` drops it, so
    it is re-prefixed here.
    """
    # Imported lazily so the patch does not require transformers at import time
    # of this module beyond what outlines already pulls in.
    from transformers.file_utils import SPIECE_UNDERLINE

    decoded = self.tokenizer.convert_tokens_to_string([token])

    if token == "<0x20>" or token.startswith(SPIECE_UNDERLINE):
        return " " + decoded
    return decoded


# Replace the tokenizer wrapper's decoder with the space-preserving version
# above (see monkey-patch note referencing dottxt-ai/outlines PR #1817).
TransformerTokenizer.convert_token_to_string = convert_token_to_string
5 changes: 3 additions & 2 deletions align_system/algorithms/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ def choose_action(self,

class StructuredInferenceEngine(ABC):
@abstractmethod
def dialog_to_prompt(dialog: list[dict]) -> str:
def dialog_to_prompt(self, dialog: list[dict]) -> str:
pass

@abstractmethod
def run_inference(prompts: Union[str, list[str]],
def run_inference(self,
prompts: Union[str, list[str]],
schema: str) -> Union[dict, list[dict]]:
pass

Expand Down
9 changes: 4 additions & 5 deletions align_system/algorithms/lib/kaleido.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file looks like it's been hit by some auto-linting (and similar story with one or two of the other files). This particular file was provided by a sub, and I would rather than change it at all in case they deliver some updates to us so that the diff is much cleaner. Could either copy the version of this file from main and overwrite this version or revert the particular commit that made these changes.

In general, I would prefer not to auto-lint existing files since it tends to blow up the diffs with superficial changes.

Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def get_gens(self, actions, n_gens=100, sample=False, batch_size=None):
n_batches = math.ceil(len(actions) / batch_size)
for i in self.tqdm(range(n_batches), desc='Generation'):
batch_actions = actions[i*batch_size:(i+1)*batch_size]
encoded_batch = self.tokenizer.batch_encode_plus(
encoded_batch = self.tokenizer(
[self.gen_template(action) for action in batch_actions],
return_tensors='pt',
padding=True,
Expand All @@ -411,7 +411,7 @@ def get_gens(self, actions, n_gens=100, sample=False, batch_size=None):
with torch.no_grad():
gens = self.model.generate(encoded_batch, num_beams=n_gens, num_return_sequences=n_gens, max_new_tokens=30)
# decode
gens = self.tokenizer.batch_decode(gens, skip_special_tokens=True)
gens = self.tokenizer(gens, skip_special_tokens=True)
# add to list
batch_gens.extend(gens)
# reshape to (n_actions, n_gens)
Expand Down Expand Up @@ -447,7 +447,7 @@ def get_explanation(self, actions, vrds, texts, batch_size=None, explanation_dec
batch_vrds = vrds[i*batch_size:(i+1)*batch_size]
batch_texts = texts[i*batch_size:(i+1)*batch_size]
# get explanations
encoded_batch = self.tokenizer.batch_encode_plus(
encoded_batch = self.tokenizer(
[self.explanation_template(action, vrd, text) for action, vrd, text in zip(batch_actions, batch_vrds, batch_texts)],
return_tensors='pt',
padding=True,
Expand All @@ -464,7 +464,6 @@ def get_explanation(self, actions, vrds, texts, batch_size=None, explanation_dec
batch_exps = batch_exps[0]
return batch_exps


def get_dummy(self, encoded_batch):
# get dummy labels (0,0) * batch size
dummy_labels = torch.tensor([[0, 0]] * encoded_batch.shape[0]).to(self.device)
Expand All @@ -485,7 +484,7 @@ def get_probs(self, inputs, batch_size=None):
for i in self.tqdm(range(n_batches+1), desc='Inference'):
# inds = list(range(i*batch_size, (i+1)*batch_size))
inds = list(range(i*batch_size, min((i+1)*batch_size, len(inputs))))
encoded_batch = self.tokenizer.batch_encode_plus(
encoded_batch = self.tokenizer(
inputs[inds].tolist(),
return_tensors='pt', padding=True, truncation=False, max_length=128,
).to(self.device)
Expand Down
105 changes: 61 additions & 44 deletions align_system/algorithms/outlines_adm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import partial

import outlines
from outlines.samplers import MultinomialSampler
from outlines.types import JsonSchema
import jinja2
from rich.highlighter import JSONHighlighter
from align_system.data_models.compat.ta3_ph1_client_models import (
Expand All @@ -16,6 +16,7 @@
CharacterTagEnum,
KDMAValue
)
import transformers

from align_system.utils import logging
from align_system.utils import adm_utils
Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(self,
model_name,
device='auto',
baseline=False,
sampler=MultinomialSampler(),
generation_kwargs=None,
scenario_description_template=scenario_state_description_1,
action_selection_prompt_template=action_selection_prompt,
baseline_system_prompt=baseline_system_prompt,
Expand All @@ -86,19 +87,21 @@ def __init__(self,
f"Unexpected value for 'precision' ({kwargs['precision']})"
", expecting either 'half' or 'full'")

model_kwargs['torch_dtype'] = torch_dtype
model_kwargs['dtype'] = torch_dtype

self.model = outlines.models.transformers(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=kwargs.get('tokenizer_kwargs', {}))
# NOTE: In cases where we want multiple samples, we're passing
# in a list of prompts (this allows us to shuffle answers in
# each prompt), rather than setting the number of samples in
# the sampler itself (which defaults to 1); setting the number
# of samples in the sampler may result in unexpected behavior
self.sampler = sampler
self.model = outlines.from_transformers(
transformers.AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs, device_map=device),
transformers.AutoTokenizer.from_pretrained(model_name, **kwargs.get('tokenizer_kwargs', {})),
device_dtype=torch_dtype)

if generation_kwargs is None:
generation_kwargs = {'temperature': 0.7}
self.generation_kwargs = generation_kwargs

# Sometimes the internal default for outlines/transformers is 20,
# leading to very short (and often invalid JSON) outputs. Setting a
# somewhat generous default.
self.generation_kwargs.setdefault('max_new_tokens', 8192)

self.outlines_seed = outlines_seed
if self.outlines_seed is None:
Expand Down Expand Up @@ -240,15 +243,11 @@ def batched(cls, iterable, n):
yield batch

@classmethod
def run_in_batches(cls, inference_function, inputs, batch_size, rng=None):
def run_in_batches(cls, inference_function, inputs, batch_size, **generation_kwargs):
''' Batch inference to avoid out of memory error'''
outputs = []
for batch in cls.batched(inputs, batch_size):
if rng is None:
output = inference_function(list(batch))
else:
output = inference_function(list(batch), rng=rng)

output = inference_function(list(batch), **generation_kwargs)
if not isinstance(output, list):
output = [output]
outputs.extend(output)
Expand Down Expand Up @@ -432,12 +431,14 @@ def top_level_choose_action(self,
# Need to set the whitespace_pattern to prevent the state
# machine from looping indefinitely in some cases, see:
# https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934
generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
action_choice_json_schema(json.dumps(choices), reasoning_max_length),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
json_schema)

if max_generator_tokens >= 0:
generator = partial(generator, max_tokens=max_generator_tokens)

Expand All @@ -454,7 +455,13 @@ def top_level_choose_action(self,
extra={"markup": True})
log.info(dialog_texts[0])

responses = self.run_in_batches(generator, dialog_texts, generator_batch_size, rng=self.outlines_rng)
responses = self.run_in_batches(generator.batch,
dialog_texts,
generator_batch_size,
rng=self.outlines_rng,
**self.generation_kwargs)
responses = [json.loads(r) for r in responses]

positive_responses_choices =\
[r['action_choice'] for r in
responses[0:num_positive_samples]]
Expand Down Expand Up @@ -657,17 +664,19 @@ def ensure_character_id_is_populated(self,

character_names = [c.name for c in characters]

generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
character_choice_json_schema(json.dumps(character_names)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
json_schema)

log.info("[bold]*DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(dialog_text)

selected_character = generator(dialog_text)
selected_character = json.loads(generator(dialog_text, **self.generation_kwargs))
selected_character_idx = character_names.index(selected_character['character_choice'])

log.info("[bold]*STRUCTURED RESPONSE*[/bold]",
Expand Down Expand Up @@ -727,19 +736,21 @@ def populate_treatment_parameters(self,

dialog_text = self.dialog_to_prompt(dialog)

generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
treatment_choice_json_schema(
json.dumps([s.type for s in available_supplies]),
json.dumps(valid_treatment_locations)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
json_schema)

log.info("[bold]*DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(dialog_text)

selected_treatment = generator(dialog_text)
selected_treatment = json.loads(generator(dialog_text, **self.generation_kwargs))

log.info("[bold]*STRUCTURED RESPONSE*[/bold]",
extra={"markup": True})
Expand Down Expand Up @@ -799,14 +810,16 @@ def select_treatment_parameters(self,
extra={"markup": True})
log.info(dialog_text)

generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
treatment_choice_from_list_json_schema(
json.dumps(possible_treatments)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

selected_treatment = generator(dialog_text)
generator = outlines.Generator(
self.model,
json_schema)

selected_treatment = json.loads(generator(dialog_text, **self.generation_kwargs))
log.info("[bold]*STRUCTURED RESPONSE*[/bold]",
extra={"markup": True})
log.info(selected_treatment, extra={"highlighter": JSON_HIGHLIGHTER})
Expand Down Expand Up @@ -843,18 +856,20 @@ def populate_tagging_parameters(self,

dialog_text = self.dialog_to_prompt(dialog)

generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
tag_choice_json_schema(
json.dumps(valid_tags)),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
json_schema)

log.info("[bold]*DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(dialog_text)

selected_tag = generator(dialog_text)
selected_tag = json.loads(generator(dialog_text, **self.generation_kwargs))

log.info("[bold]*STRUCTURED RESPONSE*[/bold]",
extra={"markup": True})
Expand Down Expand Up @@ -906,18 +921,20 @@ def populate_aid_parameters(self,

dialog_text = self.dialog_to_prompt(dialog)

generator = outlines.generate.json(
self.model,
json_schema = JsonSchema(
aid_choice_json_schema(
json.dumps([aid.id for aid in available_aids])),
sampler=self.sampler,
whitespace_pattern=r"[ ]?")

generator = outlines.Generator(
self.model,
json_schema)

log.info("[bold]*DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(dialog_text)

selected_aid = generator(dialog_text)
selected_aid = json.loads(generator(dialog_text, **self.generation_kwargs))

log.info("[bold]*STRUCTURED RESPONSE*[/bold]",
extra={"markup": True})
Expand Down
Loading