diff --git a/src/NERDA/preprocessing.py b/src/NERDA/preprocessing.py
index 4dd7dae..4d481ae 100644
--- a/src/NERDA/preprocessing.py
+++ b/src/NERDA/preprocessing.py
@@ -99,6 +99,9 @@ def __getitem__(self, item):
         # compute padding length
         if self.pad_sequences:
             padding_len = self.max_len - len(input_ids)
-            input_ids = input_ids + ([self.pad_token_id] * padding_len)
+            # Fall back to 0 when no pad token id is set, so input_ids is
+            # padded exactly once and never with None.
+            pad_id = 0 if self.pad_token_id is None else self.pad_token_id
+            input_ids = input_ids + ([pad_id] * padding_len)
             masks = masks + ([0] * padding_len)
             offsets = offsets + ([0] * padding_len)