diff --git a/datawarden/dataset.py b/datawarden/dataset.py index d195103..74c6e9c 100644 --- a/datawarden/dataset.py +++ b/datawarden/dataset.py @@ -18,6 +18,7 @@ def __init__(self, dataset: DatasetDict) -> None: self.dataset = dataset self.train_questions, self.train_answers = [], [] self.test_questions, self.test_answers = [], [] + self.dataset_type = None self._process_dataset(self.dataset['train']) self._process_dataset(self.dataset['test']) @@ -35,6 +36,7 @@ def _process_dataset(self, dataset_subset: Dataset) -> None: questions, answers = [], [] if 'instruction' in dataset_subset.features: + self.dataset_type = 'alpaca' # Extract 'instruction' and 'input' as questions, 'output' as answers for row in dataset_subset: question = f"{row['instruction']} {row['input']}" @@ -42,6 +44,7 @@ def _process_dataset(self, dataset_subset: Dataset) -> None: questions.append([question]) answers.append([answer]) elif 'conversations' in dataset_subset.features: + self.dataset_type = 'sharegpt' # Extract 'value' where 'from' is 'human' as questions, 'gpt' as answers for row in dataset_subset: conversation = row['conversations'] @@ -50,6 +53,7 @@ def _process_dataset(self, dataset_subset: Dataset) -> None: questions.append(human_messages) answers.append(gpt_messages) elif 'text' in dataset_subset.features: + self.dataset_type = 'raw' # Extract text after 'Human:' as questions, after 'Assistant:' as answers for row in dataset_subset: text_segments = row['text'].split('### ') @@ -95,4 +99,43 @@ def get_token_counts(self, tokenizer: PreTrainedTokenizer, min_tokens_question: return ( problematic_rows_train, clean_rows_train, problematic_indexes_train, clean_indexes_train, problematic_rows_test, clean_rows_test, problematic_indexes_test, clean_indexes_test - ) \ No newline at end of file + ) + + def remove_problematic_rows(self, tokenizer: PreTrainedTokenizer, min_tokens_question: int = 256, min_tokens_answer: int = 256): + """ + Remove problematic rows from the dataset. + + Args: + tokenizer (PreTrainedTokenizer): The tokenizer to use for encoding text. + min_tokens_question (int, optional): The minimum number of tokens required for questions. Defaults to 256. + min_tokens_answer (int, optional): The minimum number of tokens required for answers. Defaults to 256. + """ + # Get the problematic rows and clean rows for both train and test datasets + (problematic_rows_train, clean_rows_train, _, _, + problematic_rows_test, clean_rows_test, _, _) = self.get_token_counts(tokenizer, min_tokens_question, min_tokens_answer) + + # Check if clean_rows_train and clean_rows_test are not empty + if clean_rows_train: + # Remove problematic rows from the train dataset + self.train_data = clean_rows_train + + # Update the train questions and answers + self.train_questions, self.train_answers = zip(*clean_rows_train) + else: + # If clean_rows_train is empty, set train_data and related attributes to empty lists + self.train_data = [] + self.train_questions = [] + self.train_answers = [] + + # Check if clean_rows_test is not empty + if clean_rows_test: + # Remove problematic rows from the test dataset + self.test_data = clean_rows_test + + # Update the test questions and answers + self.test_questions, self.test_answers = zip(*clean_rows_test) + else: + # If clean_rows_test is empty, set test_data and related attributes to empty lists + self.test_data = [] + self.test_questions = [] + self.test_answers = [] \ No newline at end of file diff --git a/tests/test_dataset.py b/tests/test_dataset.py index ead6d4f..226bf59 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -77,8 +77,8 @@ def test_init(sample_dataset): def test_get_token_counts(sample_dataset_with_tokens): dataset, tokenizer = sample_dataset_with_tokens - min_tokens_question = 10 - min_tokens_answer = 10 + min_tokens_question = 5 + min_tokens_answer = 5 ( problematic_rows_train, @@ -90,15 +90,74 @@ def test_get_token_counts(sample_dataset_with_tokens): problematic_indexes_test, clean_indexes_test, ) = dataset.get_token_counts(tokenizer, min_tokens_question, min_tokens_answer) - - assert isinstance(problematic_rows_train, list) - assert isinstance(clean_rows_train, list) - assert isinstance(problematic_indexes_train, list) - assert isinstance(clean_indexes_train, list) - assert isinstance(problematic_rows_test, list) - assert isinstance(clean_rows_test, list) - assert isinstance(problematic_indexes_test, list) - assert isinstance(clean_indexes_test, list) + + if sample_dataset_with_tokens[0].dataset_type == 'alpaca': + assert problematic_rows_train == [(['instr1 input1'], ['output1']), (['instr2 input2'], ['output2'])] + assert clean_rows_train == [] + assert problematic_indexes_train == [0, 1] + assert clean_indexes_train == [] + assert problematic_rows_test == [] + assert clean_rows_test == [(['test_instr1 test_input1'], ['test_output1']), (['test_instr2 test_input2'], ['test_output2'])] + assert problematic_indexes_test == [] + assert clean_indexes_test == [0, 1] + elif sample_dataset_with_tokens[0].dataset_type == 'sharegpt': + assert problematic_rows_train == [(['Hello'], ['Hi'])] + assert clean_rows_train == [(['How are you?'], ['I am fine.'])] + assert problematic_indexes_train == [0] + assert clean_indexes_train == [1] + assert problematic_rows_test == [(['Good morning'], ['Good morning!'])] + assert clean_rows_test == [(['What is your name?'], ['I am a chatbot.'])] + assert problematic_indexes_test == [0] + assert clean_indexes_test == [1] + elif sample_dataset_with_tokens[0].dataset_type == 'raw': + assert problematic_rows_train == [] + assert clean_rows_train == [(['How are you?'], ['I am fine.']), (['What is your name?'], ['I am a chatbot.'])] + assert problematic_indexes_train == [] + assert clean_indexes_train == [0, 1] + assert problematic_rows_test == [(['Good morning'], ['Good morning!'])] + assert clean_rows_test == [(['How can I help you?'], ['You can ask me anything.'])] + assert problematic_indexes_test == [0] + assert clean_indexes_test == [1] + +def test_remove_problematic_rows(sample_dataset_with_tokens): + dataset, tokenizer = sample_dataset_with_tokens + min_tokens_question = 5 + min_tokens_answer = 5 + + # Calculate token counts before removing problematic rows + ( + problematic_rows_train_before, + clean_rows_train_before, + problematic_indexes_train_before, + clean_indexes_train_before, + problematic_rows_test_before, + clean_rows_test_before, + problematic_indexes_test_before, + clean_indexes_test_before, + ) = dataset.get_token_counts(tokenizer, min_tokens_question, min_tokens_answer) + + # Remove problematic rows + dataset.remove_problematic_rows(tokenizer, min_tokens_question, min_tokens_answer) + + # Calculate token counts after removing problematic rows + ( + problematic_rows_train_after, + clean_rows_train_after, + problematic_indexes_train_after, + clean_indexes_train_after, + problematic_rows_test_after, + clean_rows_test_after, + problematic_indexes_test_after, + clean_indexes_test_after, + ) = dataset.get_token_counts(tokenizer, min_tokens_question, min_tokens_answer) + + # Ensure that the number of clean rows has increased + assert dataset.train_data == clean_rows_train_after + assert dataset.test_data == clean_rows_test_after + + # Ensure that no problematic rows are left + assert len(problematic_rows_train_after) == 0 + assert len(problematic_rows_test_after) == 0 if __name__ == "__main__": pytest.main()