Skip to content
This repository was archived by the owner on May 19, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion datawarden/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, dataset: DatasetDict) -> None:
self.dataset = dataset
self.train_questions, self.train_answers = [], []
self.test_questions, self.test_answers = [], []
self.dataset_type = None

self._process_dataset(self.dataset['train'])
self._process_dataset(self.dataset['test'])
Expand All @@ -35,13 +36,15 @@ def _process_dataset(self, dataset_subset: Dataset) -> None:
questions, answers = [], []

if 'instruction' in dataset_subset.features:
self.dataset_type = 'alpaca'
# Extract 'instruction' and 'input' as questions, 'output' as answers
for row in dataset_subset:
question = f"{row['instruction']} {row['input']}"
answer = row['output']
questions.append([question])
answers.append([answer])
elif 'conversations' in dataset_subset.features:
self.dataset_type = 'sharegpt'
# Extract 'value' where 'from' is 'human' as questions, 'gpt' as answers
for row in dataset_subset:
conversation = row['conversations']
Expand All @@ -50,6 +53,7 @@ def _process_dataset(self, dataset_subset: Dataset) -> None:
questions.append(human_messages)
answers.append(gpt_messages)
elif 'text' in dataset_subset.features:
self.dataset_type = 'raw'
# Extract text after 'Human:' as questions, after 'Assistant:' as answers
for row in dataset_subset:
text_segments = row['text'].split('### ')
Expand Down Expand Up @@ -95,4 +99,43 @@ def get_token_counts(self, tokenizer: PreTrainedTokenizer, min_tokens_question:
return (
problematic_rows_train, clean_rows_train, problematic_indexes_train, clean_indexes_train,
problematic_rows_test, clean_rows_test, problematic_indexes_test, clean_indexes_test
)
)

def remove_problematic_rows(self, tokenizer: PreTrainedTokenizer, min_tokens_question: int = 256, min_tokens_answer: int = 256) -> None:
    """
    Remove problematic rows from the dataset, keeping only the clean rows.

    The clean rows reported by ``get_token_counts`` are stored in
    ``train_data`` / ``test_data``, and the matching question and answer
    attributes are rebuilt from them. Those attributes are always lists,
    whether or not any row survives, so callers see one consistent type.

    Args:
        tokenizer (PreTrainedTokenizer): The tokenizer to use for encoding text.
        min_tokens_question (int, optional): The minimum number of tokens required for questions. Defaults to 256.
        min_tokens_answer (int, optional): The minimum number of tokens required for answers. Defaults to 256.
    """
    # Only the clean rows matter here; the problematic rows and the index
    # lists returned by get_token_counts are discarded.
    (_, clean_rows_train, _, _,
     _, clean_rows_test, _, _) = self.get_token_counts(tokenizer, min_tokens_question, min_tokens_answer)

    # Each clean row is a (question, answer) pair. ``or []`` keeps the
    # "nothing survived" case as plain empty lists.
    self.train_data = list(clean_rows_train or [])
    # Comprehensions rather than zip(*rows): zip would yield tuples, while
    # the rest of the class stores questions and answers as lists.
    self.train_questions = [question for question, _ in self.train_data]
    self.train_answers = [answer for _, answer in self.train_data]

    self.test_data = list(clean_rows_test or [])
    self.test_questions = [question for question, _ in self.test_data]
    self.test_answers = [answer for _, answer in self.test_data]
81 changes: 70 additions & 11 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def test_init(sample_dataset):

def test_get_token_counts(sample_dataset_with_tokens):
dataset, tokenizer = sample_dataset_with_tokens
min_tokens_question = 10
min_tokens_answer = 10
min_tokens_question = 5
min_tokens_answer = 5

(
problematic_rows_train,
Expand All @@ -90,15 +90,74 @@ def test_get_token_counts(sample_dataset_with_tokens):
problematic_indexes_test,
clean_indexes_test,
) = dataset.get_token_counts(tokenizer, min_tokens_question, min_tokens_answer)

assert isinstance(problematic_rows_train, list)
assert isinstance(clean_rows_train, list)
assert isinstance(problematic_indexes_train, list)
assert isinstance(clean_indexes_train, list)
assert isinstance(problematic_rows_test, list)
assert isinstance(clean_rows_test, list)
assert isinstance(problematic_indexes_test, list)
assert isinstance(clean_indexes_test, list)

if sample_dataset_with_tokens[0].dataset_type == 'alpaca':
assert problematic_rows_train == [(['instr1 input1'], ['output1']), (['instr2 input2'], ['output2'])]
assert clean_rows_train == []
assert problematic_indexes_train == [0, 1]
assert clean_indexes_train == []
assert problematic_rows_test == []
assert clean_rows_test == [(['test_instr1 test_input1'], ['test_output1']), (['test_instr2 test_input2'], ['test_output2'])]
assert problematic_indexes_test == []
assert clean_indexes_test == [0, 1]
elif sample_dataset_with_tokens[0].dataset_type == 'sharegpt':
assert problematic_rows_train == [(['Hello'], ['Hi'])]
assert clean_rows_train == [(['How are you?'], ['I am fine.'])]
assert problematic_indexes_train == [0]
assert clean_indexes_train == [1]
assert problematic_rows_test == [(['Good morning'], ['Good morning!'])]
assert clean_rows_test == [(['What is your name?'], ['I am a chatbot.'])]
assert problematic_indexes_test == [0]
assert clean_indexes_test == [1]
elif sample_dataset_with_tokens[0].dataset_type == 'raw':
assert problematic_rows_train == []
assert clean_rows_train == [(['How are you?'], ['I am fine.']), (['What is your name?'], ['I am a chatbot.'])]
assert problematic_indexes_train == []
assert clean_indexes_train == [0, 1]
assert problematic_rows_test == [(['Good morning'], ['Good morning!'])]
assert clean_rows_test == [(['How can I help you?'], ['You can ask me anything.'])]
assert problematic_indexes_test == [0]
assert clean_indexes_test == [1]

def test_remove_problematic_rows(sample_dataset_with_tokens):
    """Check that remove_problematic_rows leaves only clean rows behind."""
    dataset, tokenizer = sample_dataset_with_tokens
    min_tokens_question = 5
    min_tokens_answer = 5
    thresholds = (min_tokens_question, min_tokens_answer)

    # Baseline classification before filtering; the result is not inspected.
    dataset.get_token_counts(tokenizer, *thresholds)

    # Filter the dataset in place.
    dataset.remove_problematic_rows(tokenizer, *thresholds)

    # Re-classify after filtering; only the row lists are needed below.
    (
        leftover_train,
        kept_train,
        _,
        _,
        leftover_test,
        kept_test,
        _,
        _,
    ) = dataset.get_token_counts(tokenizer, *thresholds)

    # The stored data must match the rows classified as clean after filtering.
    assert dataset.train_data == kept_train
    assert dataset.test_data == kept_test

    # No problematic rows may remain in either split.
    assert len(leftover_train) == 0
    assert len(leftover_test) == 0

# Allow running this test module directly as a script by invoking pytest.
if __name__ == "__main__":
pytest.main()