Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ packages = [
dev = [
"huggingface-hub[cli]>=0.29.3",
"ipykernel>=6.29.5",
"pytest>=9.0.2",
]
5 changes: 5 additions & 0 deletions src/core/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

class BaseDatasetConfig(PydraConfig):
    """Base configuration shared by all dataset implementations."""

    # Filesystem or hub path the dataset is loaded from.
    path: str
    # Stable identifier for this dataset, exposed via BaseDataset.dataset_id.
    dataset_id: str


class BaseDataset[C: BaseDatasetConfig](ABC):
Expand All @@ -21,3 +22,7 @@ def user_prompt(self, row: dict) -> str: ...

@abstractmethod
def row_id(self, row: dict) -> str: ...

@property
def dataset_id(self) -> str:
    """Return the dataset identifier declared in the config (config.dataset_id)."""
    return self.config.dataset_id
47 changes: 47 additions & 0 deletions src/core/datasets/mmlu/mmlu_cot_response_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import override

from transformers import PreTrainedTokenizer

from core.datasets.mmlu.mmlu_single_token_response_dataset import MMLUSingleTokenResponseDataset
from core.datasets.qa_dataset import QADatasetConfig


class MMLUCoTResponseDataset(MMLUSingleTokenResponseDataset):
    """MMLU dataset variant that prompts for chain-of-thought reasoning.

    The model is asked to think step-by-step and emit its final option letter
    between ``[[`` and ``]]`` markers; verification extracts the marked answer
    and compares it (case-insensitively) against the row's gold answer.

    Evaluation-only: ``assistant_response`` deliberately raises, so this
    dataset cannot be used to build training targets.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, config: QADatasetConfig):
        super().__init__(tokenizer, config)

        # (open, close) delimiters the model must wrap its final answer in.
        self.answer_marker = ("[[", "]]")

    @override
    def system_prompt(self, row: dict) -> str:
        """Build a CoT system prompt for the row's subject (``base_cluster``)."""
        subject = row["base_cluster"]
        return f"The following are multiple choice questions about {subject}. Explain your thinking process step-by-step. At the end, choose a correct option letter by strictly following this format: {self.answer_marker[0]}correct_option{self.answer_marker[1]}."

    @override
    def assistant_response(self, row: dict) -> str:
        """Always raises: this dataset is evaluation-only (no training targets)."""
        raise NotImplementedError(
            "MMLUCoTResponseDataset does not implement assistant_response since it is not used for training. Use MMLUReasoningResponseDataset for evaluation instead."
        )

    @override
    def verify_assistant_response(self, row: dict, assistant_response: str) -> tuple[str, bool]:
        """Extract the ``[[...]]``-marked answer and check it against ``row["answer"]``.

        Returns ``(extracted_answer, is_correct)``; ``("", False)`` when no
        well-formed marker pair is present.
        """
        start = assistant_response.find(self.answer_marker[0])
        end = assistant_response.find(self.answer_marker[1])
        # Reject missing or out-of-order markers (e.g. "]]...[[").
        if start == -1 or end == -1 or end < start:
            return "", False

        extracted_answer = assistant_response[start + len(self.answer_marker[0]) : end].strip().lower()

        # Plain string comparison cannot raise, so no exception handling is
        # needed here (the original bare `except:` was dead code that would
        # also have swallowed KeyboardInterrupt/SystemExit).
        correct_answer = str(row["answer"]).strip().lower()
        return extracted_answer, correct_answer == extracted_answer
Empty file added src/core/evaluation/__init__.py
Empty file.
Loading