diff --git a/mindsdb/integrations/libs/llm/utils.py b/mindsdb/integrations/libs/llm/utils.py index da01454142e..0b70109fa5a 100644 --- a/mindsdb/integrations/libs/llm/utils.py +++ b/mindsdb/integrations/libs/llm/utils.py @@ -62,7 +62,7 @@ def get_completed_prompts(base_template: str, df: pd.DataFrame, strict=True) -> if strict: raise AssertionError("No placeholders found in the prompt, please provide a valid prompt template.") prompts = [base_template] * len(df) - return prompts, np.ndarray(0) + return prompts, np.array([], dtype=int) first_span = matches[0].start() last_span = matches[-1].end() @@ -80,15 +80,15 @@ def get_completed_prompts(base_template: str, df: pd.DataFrame, strict=True) -> empty_prompt_ids = np.where(df[columns].isna().all(axis=1).values)[0] - df["__mdb_prompt"] = "" + completed_prompts = pd.Series("", index=df.index) for i in range(len(template)): atom = template[i] if i < len(columns): col = df[columns[i]].replace(to_replace=[None], value="") # add empty quote if data is missing - df["__mdb_prompt"] = df["__mdb_prompt"].apply(lambda x: x + atom) + col.astype("string") + completed_prompts = completed_prompts.apply(lambda x: x + atom) + col.astype("string") else: - df["__mdb_prompt"] = df["__mdb_prompt"].apply(lambda x: x + atom) - prompts = list(df["__mdb_prompt"]) + completed_prompts = completed_prompts.apply(lambda x: x + atom) + prompts = list(completed_prompts) return prompts, empty_prompt_ids diff --git a/tests/unit/various/test_llm_utils.py b/tests/unit/various/test_llm_utils.py index d5df7a77e42..6a65f7e6a74 100644 --- a/tests/unit/various/test_llm_utils.py +++ b/tests/unit/various/test_llm_utils.py @@ -33,3 +33,18 @@ def test_get_completed_prompts(self): df = pd.DataFrame({"text": user_inputs}) with self.assertRaises(Exception): get_completed_prompts(base_template, df) + + def test_get_completed_prompts_does_not_mutate_dataframe(self): + df = pd.DataFrame( + { + "text": ["What is MindsDB?"], + "__mdb_prompt": ["user data"], + } + ) + + prompts, empties = get_completed_prompts("Answer: {{text}}", df) + + assert prompts == ["Answer: What is MindsDB?"] + assert empties.shape == (0,) + assert df["__mdb_prompt"].tolist() == ["user data"] + assert df.columns.tolist() == ["text", "__mdb_prompt"]