Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions mindsdb/integrations/libs/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_completed_prompts(base_template: str, df: pd.DataFrame, strict=True) ->
if strict:
raise AssertionError("No placeholders found in the prompt, please provide a valid prompt template.")
prompts = [base_template] * len(df)
return prompts, np.ndarray(0)
return prompts, np.array([], dtype=int)

first_span = matches[0].start()
last_span = matches[-1].end()
Expand All @@ -80,15 +80,15 @@ def get_completed_prompts(base_template: str, df: pd.DataFrame, strict=True) ->

empty_prompt_ids = np.where(df[columns].isna().all(axis=1).values)[0]

df["__mdb_prompt"] = ""
completed_prompts = pd.Series("", index=df.index)
for i in range(len(template)):
atom = template[i]
if i < len(columns):
col = df[columns[i]].replace(to_replace=[None], value="") # add empty quote if data is missing
df["__mdb_prompt"] = df["__mdb_prompt"].apply(lambda x: x + atom) + col.astype("string")
completed_prompts = completed_prompts.apply(lambda x: x + atom) + col.astype("string")
else:
df["__mdb_prompt"] = df["__mdb_prompt"].apply(lambda x: x + atom)
prompts = list(df["__mdb_prompt"])
completed_prompts = completed_prompts.apply(lambda x: x + atom)
prompts = list(completed_prompts)

return prompts, empty_prompt_ids

Expand Down
15 changes: 15 additions & 0 deletions tests/unit/various/test_llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,18 @@ def test_get_completed_prompts(self):
df = pd.DataFrame({"text": user_inputs})
with self.assertRaises(Exception):
get_completed_prompts(base_template, df)

def test_get_completed_prompts_does_not_mutate_dataframe(self):
    """Completing prompts must leave the caller's DataFrame unchanged.

    The input frame already contains a column named ``__mdb_prompt``; the
    test checks that both its values and the frame's column list are the
    same after the call as before it.
    """
    template = "Answer: {{text}}"
    frame_data = {
        "text": ["What is MindsDB?"],
        "__mdb_prompt": ["user data"],
    }
    frame = pd.DataFrame(frame_data)

    result = get_completed_prompts(template, frame)
    prompts, empties = result

    # The placeholder is filled from the "text" column.
    assert prompts == ["Answer: What is MindsDB?"]
    # No rows are reported as empty prompts.
    assert empties.shape == (0,)
    # Pre-existing user column keeps its original value ...
    assert frame["__mdb_prompt"].tolist() == ["user data"]
    # ... and no columns were added, dropped, or reordered.
    assert frame.columns.tolist() == ["text", "__mdb_prompt"]
Loading