diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index da443c4f6..0c9b372f0 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -29,6 +29,15 @@ logger = logging.getLogger(__name__) +def _get_device(group: ProcessGroup) -> torch.device: + if torch.distributed.is_nccl_available() and isinstance(group, torch.distributed.ProcessGroupNCCL): + return torch.device(torch.cuda.current_device()) + elif isinstance(group, torch.distributed.ProcessGroupGloo): + return torch.device("cpu") + else: + raise NotImplementedError(type(group)) + + @contextlib.contextmanager def set_timeout(group: ProcessGroup | None, timeout: float | None = None): if group is not None and timeout is not None: @@ -72,12 +81,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: ) -def safe_barrier( - group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None -) -> None: +def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None: if group: hashed = hash(value) % 2**32 - out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device) + out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout) if out != hashed * group.size(): raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})") @@ -88,10 +95,9 @@ def allreduce_scalar( group: torch.distributed.ProcessGroup | None = None, op=ReduceOp.SUM, timeout: float | None = None, - device: torch.device | None = None, ) -> float | int: if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device) + value = torch.full([1], value, dtype=dtype, device=_get_device(group)) with set_timeout(group, timeout): torch.distributed.all_reduce(value, op=op, group=group) return value.item() @@ -106,7 +112,7 @@ def all_gather_scalar( timeout: float | None = None, ): if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device()) + value = torch.full([1], value, dtype=dtype, device=_get_device(group)) output_tensor = value.new_empty((group.size(),)) with set_timeout(group, timeout): torch.distributed.all_gather_into_tensor(output_tensor, value, group=group) @@ -116,7 +122,7 @@ def all_gather_scalar( def broadcast_scalar( - value: float | int, + value: float | int | None, dtype: torch.dtype = torch.float64, group: torch.distributed.ProcessGroup | None = None, src: int = 0, @@ -124,7 +130,7 @@ def broadcast_scalar( ) -> float | int: if not group: return value - tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device())) + tensor = torch.empty([1], dtype=dtype, device=torch.device(_get_device(group))) if group.rank() == src: tensor.fill_(value) broadcast(tensor, src, group, timeout=timeout) @@ -141,14 +147,14 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None if group.rank() == src: tensor = _object_to_tensor(input_object) size = tensor.numel() - broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group)) broadcast_tensor.copy_(tensor) broadcast_scalar(size, torch.int64, group, src) broadcast(broadcast_tensor, src, group) return input_object else: size = int(broadcast_scalar(None, torch.int64, group, src)) - output_tensor = torch.empty(size, dtype=torch.uint8, 
device=torch.cuda.current_device()) + output_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group)) broadcast(output_tensor, src, group) return _tensor_to_object(output_tensor) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index ede450dfa..8b18b59ba 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -302,7 +302,6 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp REDIS_DATA_STREAM = "fast_llm_streaming" -REDIS_FIELD = "data" REDIS_GROUP_NAME = "fast_llm_group" diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index c261e383e..48275988f 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -1,23 +1,108 @@ +import functools import json import typing import redis import torch.utils.data -from fast_llm.config import Configurable +from fast_llm.config import Config, Configurable, Field, config_class from fast_llm.data.dataset.abstract import SamplableIterableDataset -from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_FIELD, REDIS_GROUP_NAME, StreamingDatasetConfig +from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_GROUP_NAME, StreamingDatasetConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample +from fast_llm.data.sample.token_data import TokenDataSample from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.utils import Assert -def dtype_from_string(name: str) -> torch.dtype: - try: - return getattr(torch, name) - except AttributeError: - raise ValueError(f"Unknown torch dtype: {name}") +@config_class() +class RedisDocument(Config): + """ + Schema for sending and receiving documents through redis, and the associated handling code. 
+ """ + + tokens: torch.Tensor = Field() + loss_masking_spans: list[tuple[int, int]] | None = Field(default=None) + chosen_span: tuple[int, int] | None = Field(default=None) + rejected_span: tuple[int, int] | None = Field(default=None) + advantage: float | None = Field(default=None) + old_log_probabilities: torch.Tensor | None = Field(default=None) + + def _validate(self): + # Decode message + if isinstance(self.tokens, bytes): + self.tokens = torch.frombuffer(self.tokens, dtype=torch.int64) + elif isinstance(self.tokens, (list, tuple)): + self.tokens = torch.tensor(self.tokens, dtype=torch.int64) + if isinstance(self.loss_masking_spans, str): + self.loss_masking_spans = json.loads(self.loss_masking_spans) + if isinstance(self.chosen_span, str): + self.chosen_span = json.loads(self.chosen_span) + if isinstance(self.rejected_span, str): + self.rejected_span = json.loads(self.rejected_span) + if isinstance(self.old_log_probabilities, bytes): + self.old_log_probabilities = torch.frombuffer(self.old_log_probabilities, dtype=torch.float32) + elif isinstance(self.old_log_probabilities, (list, tuple)): + self.old_log_probabilities = torch.tensor(self.old_log_probabilities, dtype=torch.float32) + super()._validate() + if self.old_log_probabilities is not None: + Assert.eq(len(self.old_log_probabilities), self.num_tokens) + + @functools.cached_property + def num_tokens(self) -> int: + return len(self.tokens) + + @classmethod + def from_message(cls, message: dict[bytes, bytes]) -> typing.Self: + # Read + kwargs = {} + for key, value in message.items(): + key = key.decode() + if key == "data": + kwargs.update(json.loads(value)) + else: + kwargs[key] = value + return cls.from_dict(kwargs) + + def to_message(self) -> dict[str, str | int | float | bytes]: + # Encode message + message: dict[str, str | int | float | bytes] = {"tokens": self.tokens.numpy().tobytes()} + if self.old_log_probabilities is not None: + message["old_log_probabilities"] = self.old_log_probabilities.numpy().tobytes() + data = {} + if self.loss_masking_spans is not None: + data["loss_masking_spans"] = self.loss_masking_spans + if self.chosen_span is not None: + data["chosen_span"] = self.chosen_span + if self.rejected_span is not None: + data["rejected_span"] = self.rejected_span + if self.advantage is not None: + data["advantage"] = self.advantage + if data: + message["data"] = json.dumps(data) + return message + + def to_sample(self): + sample_size = len(self.tokens) + return LanguageModelSample( + tokens=TokenSample(self.tokens, [sample_size]), + loss_masking_spans=( + None + if self.loss_masking_spans is None + else RangeSample([(begin, end) for begin, end in self.loss_masking_spans], sample_size) + ), + chosen_spans=None if self.chosen_span is None else RangeSample([self.chosen_span], sample_size), + rejected_spans=None if self.rejected_span is None else RangeSample([self.rejected_span], sample_size), + advantages=( + None + if self.advantage is None + else TokenDataSample(torch.full([sample_size], self.advantage, dtype=torch.float32)) + ), + old_log_probabilities=( + None if self.old_log_probabilities is None else TokenDataSample(self.old_log_probabilities) + ), + ) class RedisStreamingDataset[ConfigType: StreamingDatasetConfig, SampleType: LanguageModelSample]( @@ -77,29 +162,8 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]: noack=True, ) if messages: - for stream_key, msgs in messages: + for stream_key, messages_ in messages: assert stream_key == REDIS_DATA_STREAM.encode() - for msg_id, msg_data in msgs: - 
yield self._read_document(json.loads(msg_data[REDIS_FIELD.encode()]))
-
-    def _read_document(self, data: dict) -> LanguageModelSample:
-        tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"]))
-        sample_size = len(tokens)
-        if "loss_masking_spans" in data:
-            loss_masking_spans = RangeSample([(begin, end) for begin, end in data["loss_masking_spans"]], sample_size)
-        else:
-            loss_masking_spans = None
-        if "chosen_spans" in data:
-            chosen_spans = RangeSample([(begin, end) for begin, end in data["chosen_spans"]], sample_size)
-        else:
-            chosen_spans = None
-        if "rejected_spans" in data:
-            rejected_spans = RangeSample([(begin, end) for begin, end in data["rejected_spans"]], sample_size)
-        else:
-            rejected_spans = None
-        return LanguageModelSample(
-            TokenSample(tokens, [sample_size]),
-            loss_masking_spans,
-            chosen_spans,
-            rejected_spans,
-        )
+                for message_id, message in messages_:
+                    yield RedisDocument.from_message(message).to_sample()
diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py
index d54776eec..87d176663 100644
--- a/fast_llm/data/preprocessing/language_model.py
+++ b/fast_llm/data/preprocessing/language_model.py
@@ -22,6 +22,8 @@ class LanguageModelPreprocessingConfig(PreprocessingConfig):
     vocab_size: int | None = Field(default=None)
     use_loss_masking_spans: bool = Field(default=False)
     use_preference_spans: bool = Field(default=False)
+    use_advantages: bool = Field(default=False)
+    use_old_log_probabilities: bool = Field(default=False)
 
     def _validate(self) -> None:
         super()._validate()
diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py
index 22b89acf1..e3dab9bc2 100644
--- a/fast_llm/data/sample/language_model.py
+++ b/fast_llm/data/sample/language_model.py
@@ -39,6 +39,7 @@
     RangeWriter,
 )
 from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, TokenWriter
+from fast_llm.data.sample.token_data import TokenDataBatch, TokenDataReader, TokenDataReaderConfig, TokenDataSample, TokenDataWriter
 from fast_llm.engine.config_utils.data_type import DataType
 from fast_llm.utils import Assert
 
@@ -53,12 +54,16 @@ def __init__(
         self,
         chosen_spans: RangeSample | None = None,
         rejected_spans: RangeSample | None = None,
         image_patches: PatchSample | None = None,
+        advantages: TokenDataSample | None = None,
+        old_log_probabilities: TokenDataSample | None = None,
     ):
         self.tokens = tokens
         self.loss_masking_spans = loss_masking_spans
         self.chosen_spans = chosen_spans
         self.rejected_spans = rejected_spans
         self.image_patches = image_patches
+        self.advantages = advantages
+        self.old_log_probabilities = old_log_probabilities
 
     @classmethod
     def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
@@ -68,6 +73,10 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
             _merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]),
             _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]),
             _merge_optional(PatchSample.from_documents, [document.image_patches for document in documents]),
+            _merge_optional(TokenDataSample.from_documents, [document.advantages for document in documents]),
+            _merge_optional(
+                TokenDataSample.from_documents, [document.old_log_probabilities for document in documents]
+            ),
         )
 
     def crop(self, begin: int, end: int) -> typing.Self:
@@ -77,18 +86,22 @@ def crop(self, begin: int, end: int) -> typing.Self:
             _crop_optional(self.chosen_spans,
begin, end), _crop_optional(self.rejected_spans, begin, end), _crop_optional(self.image_patches, begin, end), + _crop_optional(self.advantages, begin, end), + _crop_optional(self.old_log_probabilities, begin, end), ) def __len__(self) -> int: return len(self.tokens) def get_padding(self, size: int) -> typing.Self: - return LanguageModelSample( + return self.__class__( self.tokens.get_padding(size), None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size), None if self.chosen_spans is None else self.chosen_spans.get_padding(size), None if self.rejected_spans is None else self.rejected_spans.get_padding(size), None if self.image_patches is None else self.image_patches.get_padding(size), + None if self.advantages is None else self.advantages.get_padding(size), + None if self.old_log_probabilities is None else self.old_log_probabilities.get_padding(size), ) @@ -100,12 +113,16 @@ def __init__( chosen_spans: RangeBatch | None = None, rejected_spans: RangeBatch | None = None, image_patches: PatchBatch | None = None, + advantages: TokenDataBatch | None = None, + old_log_probabilities: TokenDataBatch | None = None, ): self.tokens = tokens self.loss_masking_spans = loss_masking_spans self.chosen_spans = chosen_spans self.rejected_spans = rejected_spans self.image_patches = image_patches + self.advantages = advantages + self.old_log_probabilities = old_log_probabilities @classmethod def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: @@ -115,6 +132,8 @@ def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.S _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), _merge_optional(PatchBatch.from_samples, [sample.image_patches for sample in samples]), + _merge_optional(TokenDataBatch.from_samples, [sample.advantages for sample in samples]), + _merge_optional(TokenDataBatch.from_samples, [sample.old_log_probabilities for sample in samples]), ) def crop(self, begin: int, end: int) -> typing.Self: @@ -124,6 +143,8 @@ def crop(self, begin: int, end: int) -> typing.Self: _crop_optional(self.chosen_spans, begin, end), _crop_optional(self.rejected_spans, begin, end), _crop_optional(self.image_patches, begin, end), + _crop_optional(self.advantages, begin, end), + _crop_optional(self.old_log_probabilities, begin, end), ) def to_device_(self, device: "torch.device | str"): @@ -136,6 +157,10 @@ def to_device_(self, device: "torch.device | str"): self.rejected_spans.to_device_(device) if self.image_patches is not None: self.image_patches.to_device_(device) + if self.advantages is not None: + self.advantages.to_device_(device) + if self.old_log_probabilities is not None: + self.old_log_probabilities.to_device_(device) def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.Iterable) -> T | None: @@ -157,6 +182,8 @@ class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): chosen_spans: MemmapReaderBaseConfig = Field() rejected_spans: MemmapReaderBaseConfig = Field() image_patches: MemmapReaderBaseConfig = Field() + advantages: MemmapReaderBaseConfig = Field() + old_log_probabilities: MemmapReaderBaseConfig = Field() def _validate(self) -> None: super()._validate() @@ -192,6 +219,16 @@ def _validate(self) -> None: self.rejected_spans, RangeReaderConfig if self.preprocessing.use_preference_spans else NullReaderConfig, ) + Assert.custom( + isinstance, + self.advantages, + 
TokenDataReaderConfig if self.preprocessing.use_advantages else NullReaderConfig, + ) + Assert.custom( + isinstance, + self.old_log_probabilities, + TokenDataReaderConfig if self.preprocessing.use_old_log_probabilities else NullReaderConfig, + ) if self.preprocessing.use_image_patches: Assert.custom(isinstance, self.image_patches, PatchReaderConfig) Assert.eq(self.image_patches.patch_shape, self.preprocessing.image_patches.patch_shape) @@ -222,6 +259,8 @@ def _expected_buffer_size(self) -> int: + self.chosen_spans.expected_buffer_size + self.rejected_spans.expected_buffer_size + self.image_patches.expected_buffer_size + + self.advantages.expected_buffer_size + + self.old_log_probabilities.expected_buffer_size ) def get_metadata(self) -> dict[str, typing.Any]: @@ -235,6 +274,10 @@ def get_metadata(self) -> dict[str, typing.Any]: out["rejected_spans"] = self.rejected_spans.get_metadata() if not isinstance(self.image_patches, NullReaderConfig): out["image_patches"] = self.image_patches.get_metadata() + if not isinstance(self.advantages, NullReaderConfig): + out["advantages"] = self.advantages.get_metadata() + if not isinstance(self.old_log_probabilities, NullReaderConfig): + out["old_log_probabilities"] = self.old_log_probabilities.get_metadata() return out @classmethod @@ -257,6 +300,12 @@ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typi out["image_patches"] = PatchReaderConfig.blend_metadata( [metadata_["image_patches"] for metadata_ in metadata] ) + if "advantages" in metadata[0]: + out["advantages"] = RangeReaderConfig.blend_metadata([metadata_["advantages"] for metadata_ in metadata]) + if "old_log_probabilities" in metadata[0]: + out["old_log_probabilities"] = RangeReaderConfig.blend_metadata( + [metadata_["old_log_probabilities"] for metadata_ in metadata] + ) return out @@ -290,6 +339,10 @@ def __init__( self._chosen_spans = self._config.chosen_spans.get_reader(buffer) self._rejected_spans = self._config.rejected_spans.get_reader(buffer) + if self._model_preprocessing.use_advantages: + self._advantages = self._config.advantages.get_reader(buffer) + self._old_log_probabilities = self._config.old_log_probabilities.get_reader(buffer) + if self._model_preprocessing.use_image_patches: model_image_preprocessing: ImagePatchConfig = self._model_preprocessing.image_patches if isinstance(self._config.image_patches, NullReaderConfig): @@ -334,6 +387,12 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: else None ), image_patches, + (self._advantages.get_document(index, begin, end) if self._model_preprocessing.use_advantages else None), + ( + self._old_log_probabilities.get_document(index, begin, end) + if self._model_preprocessing.use_old_log_probabilities + else None + ), ) def get_document_sizes(self) -> torch.Tensor: @@ -356,6 +415,10 @@ def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dic metadata["rejected_spans"] = self._rejected_spans.get_split(begin_index, end_index) if hasattr(self, "_image_patches") and isinstance(self._image_patches, PatchReader): metadata["image_patches"] = self._image_patches.get_split(begin_index, end_index) + if hasattr(self, "_advantages") and isinstance(self._advantages, TokenDataReader): + metadata["advantages"] = self._advantages.get_split(begin_index, end_index) + if hasattr(self, "_old_log_probabilities") and isinstance(self._old_log_probabilities, TokenDataReader): + metadata["old_log_probabilities"] = self._old_log_probabilities.get_split(begin_index, end_index) return 
begin_index, end_index, metadata
@@ -379,6 +442,10 @@ def __enter__(self):
             self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__()
         if self._preprocessing_config.use_image_patches:
             self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__()
+        if self._preprocessing_config.use_advantages:
+            self._advantages_writer = TokenDataWriter(self._path.joinpath("advantages")).__enter__()
+        if self._preprocessing_config.use_old_log_probabilities:
+            self._old_log_probabilities_writer = TokenDataWriter(self._path.joinpath("old_log_probabilities")).__enter__()
         return self
 
     def write(self, document: LanguageModelSample):
@@ -403,6 +470,14 @@ def write(self, document: LanguageModelSample):
             assert document.image_patches is not None
             self._image_patches_writer.write(document.image_patches)
 
+        if self._preprocessing_config.use_advantages:
+            assert document.advantages is not None
+            self._advantages_writer.write(document.advantages)
+
+        if self._preprocessing_config.use_old_log_probabilities:
+            assert document.old_log_probabilities is not None
+            self._old_log_probabilities_writer.write(document.old_log_probabilities)
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         self._token_writer.__exit__(exc_type, exc_val, exc_tb)
         if self._preprocessing_config.use_loss_masking_spans:
@@ -412,6 +487,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb)
         if self._preprocessing_config.use_image_patches:
             self._image_patches_writer.__exit__(exc_type, exc_val, exc_tb)
+        if self._preprocessing_config.use_advantages:
+            self._advantages_writer.__exit__(exc_type, exc_val, exc_tb)
+        if self._preprocessing_config.use_old_log_probabilities:
+            self._old_log_probabilities_writer.__exit__(exc_type, exc_val, exc_tb)
         if exc_type is None:
             # A dummy config so we can verify the begin and end offsets.
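For reference, a rough usage sketch of the new per-token fields (not taken from this patch: the keyword constructor follows `RedisDocument.to_sample` above, and the pre-existing `TokenSample` and `crop` behavior is assumed):

import torch

from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.data.sample.token import TokenSample
from fast_llm.data.sample.token_data import TokenDataSample

tokens = torch.arange(8, dtype=torch.int64)
sample = LanguageModelSample(
    tokens=TokenSample(tokens, [len(tokens)]),
    advantages=TokenDataSample(torch.full([len(tokens)], 0.5, dtype=torch.float32)),
    old_log_probabilities=TokenDataSample(torch.zeros(len(tokens), dtype=torch.float32)),
)
# Cropping keeps the per-token fields aligned with the tokens.
cropped = sample.crop(2, 6)
assert len(cropped.tokens) == len(cropped.advantages) == len(cropped.old_log_probabilities) == 4

`get_padding` keeps the same structure, filling the extra positions with zero-valued advantages and log-probabilities.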
@@ -475,6 +554,16 @@ def _get_config(self, begin: int, end: int | None): offset = image_patches.end else: image_patches = NullReaderConfig() + if self._preprocessing_config.use_advantages: + advantages = self._advantages_writer.get_config(offset) + offset = advantages.end + else: + advantages = NullReaderConfig() + if self._preprocessing_config.use_old_log_probabilities: + old_log_probabilities = self._old_log_probabilities_writer.get_config(offset) + offset = old_log_probabilities.end + else: + old_log_probabilities = NullReaderConfig() if end is None: end = offset + len(LanguageModelReaderConfig.footer) @@ -488,6 +577,8 @@ def _get_config(self, begin: int, end: int | None): rejected_spans=rejected_spans, image_patches=image_patches, preprocessing=self._preprocessing_config, + advantages=advantages, + old_log_probabilities=old_log_probabilities, ) diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py index 7ae537104..32ea60cb8 100644 --- a/fast_llm/data/sample/patch.py +++ b/fast_llm/data/sample/patch.py @@ -85,7 +85,7 @@ def __len__(self) -> int: return self.sample_size def get_padding(self, size: int) -> typing.Self: - return PatchSample( + return self.__class__( self.patches.new_empty((0, *self.patches.shape[1:])), self.token_map.new_empty(0), self.positions.new_empty([0, self.patches.ndim - 2]), diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 53683342a..f57ee04d9 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -52,7 +52,7 @@ def __len__(self) -> int: return self.sample_size def get_padding(self, size: int) -> typing.Self: - return RangeSample([], size) + return self.__class__([], size) class RangeBatch(Batch): diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index cd4d7fa02..6ab55dbba 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -58,7 +58,7 @@ def __len__(self) -> int: return len(self.tokens) def get_padding(self, size: int) -> typing.Self: - return TokenSample(torch.full([size], -100, dtype=self.tokens.dtype), [size]) + return self.__class__(torch.full([size], -100, dtype=self.tokens.dtype), [size]) class TokenBatch(Batch): diff --git a/fast_llm/data/sample/token_data.py b/fast_llm/data/sample/token_data.py new file mode 100644 index 000000000..6d2a6f9d1 --- /dev/null +++ b/fast_llm/data/sample/token_data.py @@ -0,0 +1,190 @@ +import functools +import math +import typing + +import numpy as np +import torch + +from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.data.sample.abstract import ( + Batch, + MemmapIndexedDatasetReader, + MemmapReaderBase, + MemmapReaderBaseConfig, + MemmapReaderConfig, + MemmapWriter, + Sample, +) +from fast_llm.data.sample.patch import PatchReaderBaseConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert, get_unique + + +class TokenDataSample(Sample): + """ + A reusable component holding tensor-valued data of fixed dtype and shape for each token. + TODO: Use as base class for `TokenSample` and `PatchSample`? 
+ """ + + def __init__(self, data: torch.Tensor): + self.data = data + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls(torch.cat([document.data for document in documents])) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__(self.data[begin:end]) + + def __len__(self) -> int: + return len(self.data) + + def get_padding(self, size: int) -> typing.Self: + return self.__class__(torch.full([size], 0, dtype=self.data.dtype)) + + +class TokenDataBatch(Batch): + def __init__(self, data: torch.Tensor) -> None: + self.data = data + + @classmethod + def from_samples(cls, samples: typing.Iterable[TokenDataSample]) -> typing.Self: + return cls(torch.stack([sample.data for sample in samples])) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__(self.data[:, begin:end]) + + def to_device_(self, device: "torch.device | str"): + self.data = self.data.to(device, non_blocking=True) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "token_data"}) +class TokenDataReaderConfig(MemmapReaderConfig): + _abstract = False + header: typing.ClassVar[bytes] = b"token data begin" + footer: typing.ClassVar[bytes] = b"token data end" + num_documents: int = Field() + num_tokens: int = Field() + shape: tuple[int, ...] = Field() + data_type: DataType = Field() + + def __len__(self) -> int: + return self.num_documents + + @functools.cached_property + def size(self) -> int: + return math.prod(self.shape) + + @property + def reader_class(self) -> "type[TokenDataReader]": + return TokenDataReader + + @property + def writer_class(self) -> "type[TokenDataWriter]": + return TokenDataWriter + + @property + def _expected_buffer_size(self) -> int: + return ( + self.num_tokens * self.data_type.torch.itemsize * self.size + + (self.num_documents + 1) * torch.int64.itemsize + ) + + def get_metadata(self) -> dict[str, typing.Any]: + return { + "num_tokens": self.num_tokens, + "num_documents": self.num_documents, + "data_type": str(self.data_type), + "shape": self.shape, + } + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return { + "num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata), + "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), + "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata), + "shape": get_unique(metadata_["shape"] for metadata_ in metadata), + } + + +class TokenDataReader[ConfigType: TokenDataReaderConfig](MemmapIndexedDatasetReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) + self._data = torch.frombuffer( + self._buffer, + dtype=self._config.data_type.torch, + count=self._config.num_tokens * self._config.size, + ).view(-1, *self._config.shape) + self._size_cumsums = torch.frombuffer( + self._buffer, dtype=torch.int64, count=self._config.num_documents + 1, offset=self._data.nbytes + ) + + def get_document(self, index: int, begin: int, end: int) -> Sample: + begin_ = self._size_cumsums[index].item() + return TokenDataSample(self._data[begin_ + begin : begin_ + end]) + + def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: + Assert.custom(lambda x: x == sorted(x), [0, begin_ratio, end_ratio, 1]) + begin_index = _get_nearest_split(self._size_cumsums[1:], begin_ratio * self.num_tokens) + 
end_index = _get_nearest_split(self._size_cumsums[1:], end_ratio * self.num_tokens) + + return ( + begin_index, + end_index, + { + "num_tokens": self._size_cumsums[end_index].item() - self._size_cumsums[begin_index].item(), + "num_documents": end_index - begin_index, + "data_type": str(self._config.data_type), + }, + ) + + +class EmptyPatchReader[ConfigType: PatchReaderBaseConfig](MemmapReaderBase[ConfigType]): + def get_document(self, index: int, begin: int, end: int) -> Sample: + # TODO: Does this make sense? + return TokenDataSample(torch.zeros(end - begin, *self._config.shape, dtype=self._config.data_type.torch)) + + +def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int: + left = torch.searchsorted(cumsum, value, side="right") + if left == len(cumsum): + return left.item() + return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() + + +class TokenDataWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._size_cumsum = [0] + self._data_type = None + return self + + def write(self, document: TokenDataSample): + super().write(document) + if self._data_type is None: + self._data_type = document.data.dtype + else: + Assert.eq(self._data_type, document.data.dtype) + self._stream.write(document.data.numpy().tobytes()) + self._size_cumsum.append(self._size_cumsum[-1] + len(document.data)) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self._stream.write(np.array(self._size_cumsum, dtype=np.int64).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[TokenDataReaderConfig]: + return TokenDataReaderConfig + + def _get_config(self, begin: int, end: int): + return TokenDataReaderConfig( + begin=begin, + end=end, + num_documents=len(self._size_cumsum) - 1, + num_tokens=self._size_cumsum[-1], + data_type=DataType.from_torch(self._data_type), + preprocessing=self._preprocessing_config, + ) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index d8953488f..05bd05285 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -12,18 +12,17 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData -from fast_llm.data.dataset.config import RedisConfig, SamplingParameters, StreamingDatasetConfig -from fast_llm.data.dataset.streaming import RedisStreamingDataset +from fast_llm.data.dataset.config import REDIS_DATA_STREAM, RedisConfig, SamplingParameters, StreamingDatasetConfig +from fast_llm.data.dataset.streaming import RedisDocument, RedisStreamingDataset from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert from tests.conftest import WorkerResources -from tests.utils.redis import make_sampling, push_msg, redis_batch_producer +from tests.utils.redis import make_sampling, redis_batch_producer from tests.utils.subtest import DistributedTestContext -from tests.utils.utils import requires_cuda logger = logging.getLogger(__name__) @@ -40,31 +39,63 @@ def 
fake_redis(monkeypatch): @pytest.mark.parametrize( - "messages", + "documents", [ (range(3),), - (range(3), range(3, 7)), - (range(3), range(5), [], [9, 4]), + (range(3), range(3, 6)), + (range(3), range(5), [9, 4]), + ( + {"tokens": list(range(3)), "advantage": 0.33, "old_log_probabilities": [0.25, -0.52, 0.99]}, + {"tokens": list(range(5)), "loss_masking_spans": [(0, 1), (2, 3)]}, + {"tokens": list(range(8)), "chosen_span": (0, 2), "rejected_span": (3, 5)}, + ), ], ) def test_streaming_dataset( fake_redis: fakeredis.FakeRedis, - messages: tuple[list[int], ...], + documents: tuple[list[int] | dict[str, typing.Any], ...], worker_resources: WorkerResources, ): """StreamingDataset should read a message and convert it into LanguageModelSample.""" stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port) dataset_iterator = iter(RedisStreamingDataset(stream_config, DistributedConfig())) - for message in messages: - push_msg(fake_redis, list(message)) - for message in messages: + documents = [document if isinstance(document, dict) else {"tokens": list(document)} for document in documents] + for document in documents: + fake_redis.xadd(REDIS_DATA_STREAM, RedisDocument.from_dict(document).to_message()) + for document in documents: sample = next(dataset_iterator) assert isinstance(sample, LanguageModelSample) - Assert.eq(sample.tokens.tokens.tolist(), list(message)) - Assert.eq(sample.tokens.lengths, [len(message)]) - assert sample.loss_masking_spans is None - assert sample.chosen_spans is None - assert sample.rejected_spans is None + Assert.eq(sample.tokens.tokens.tolist(), document["tokens"]) + Assert.eq(sample.tokens.lengths, [len(document["tokens"])]) + + if "loss_masking_spans" in document: + Assert.eq(sample.loss_masking_spans.ranges, document["loss_masking_spans"]) + else: + assert sample.loss_masking_spans is None + + if "chosen_span" in document: + Assert.eq(sample.chosen_spans.ranges, [document["chosen_span"]]) + else: + assert sample.chosen_spans is None + + if "rejected_span" in document: + Assert.eq(sample.rejected_spans.ranges, [document["rejected_span"]]) + else: + assert sample.rejected_spans is None + + assert sample.image_patches is None + + if "advantage" in document: + Assert.rms_close( + sample.advantages.data, torch.full([len(document["tokens"])], document["advantage"]), 1e-8 + ) + else: + assert sample.advantages is None + + if "old_log_probabilities" in document: + Assert.rms_close(sample.old_log_probabilities.data, torch.tensor(document["old_log_probabilities"]), 1e-8) + else: + assert sample.old_log_probabilities is None @pytest.mark.parametrize( @@ -95,12 +126,12 @@ def test_streaming_sampled_dataset( ): """StreamingDataset should read a message and convert it into LanguageModelSample.""" stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port) - distributed = Distributed(DistributedConfig(), use_cpu=True) + distributed = Distributed(DistributedConfig(use_cuda=False)) dataset_iterator = iter( RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(5, 1, distributed)) ) for message in messages: - push_msg(fake_redis, list(message)) + fake_redis.xadd(REDIS_DATA_STREAM, RedisDocument.from_dict({"tokens": list(message)}).to_message()) for expected_sample, expected_lengths_ in zip(expected_samples, expected_lengths, strict=True): sample = next(dataset_iterator) assert isinstance(sample, LanguageModelSample) @@ -118,7 +149,13 @@ def _get_distributed_and_batch_config( distributed_config_dict: dict[str, typing.Any], 
world_size: int = 1 ) -> tuple[DistributedConfig, GPTBatchConfig]: distributed_config = DistributedConfig.from_dict( - distributed_config_dict, {"world_size": world_size, "local_world_size": world_size} + distributed_config_dict, + { + "world_size": world_size, + "local_world_size": world_size, + "use_cuda": False, + "backend": DistributedBackend.gloo, + }, ) with NoAutoValidate(): batch_config = GPTBatchConfig(micro_batch_size=2, sequence_length=10) @@ -189,14 +226,15 @@ def _run_test_data_streaming_distributed( # Import all dynamic classes. TODO: needed? import fast_llm.cli # noqa + print(_DISTRIBUTED_TESTING_CONFIGS) for name, num_gpus, distributed_config_dict in _DISTRIBUTED_TESTING_CONFIGS: with test_context.subtest(base_path, name, num_gpus) as subtest: + print(name, subtest.do_run) if subtest.do_run: distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus) _run_test_data_streaming(base_path / name, distributed_config, batch_config, port) -@requires_cuda def test_data_streaming(result_path, worker_resources): distributed_config, batch_config = _get_distributed_and_batch_config({}) path = result_path / "data_streaming/single_gpu" @@ -218,24 +256,22 @@ def test_data_streaming(result_path, worker_resources): ] -@requires_cuda @pytest.mark.slow @pytest.mark.depends_on(on=["test_data_streaming"]) def test_run_data_streaming_distributed(run_parallel_script, result_path, worker_resources): - if torch.cuda.device_count() < 2: - pytest.skip(f"Not enough GPUs") run_parallel_script( _run_test_data_streaming_distributed, (result_path / "data_streaming", worker_resources.torchrun_port), - world_size=torch.cuda.device_count(), + world_size=4, + backend=DistributedBackend.gloo, + use_cuda=False, # Disable device count check. 
) -@requires_cuda @pytest.mark.slow @pytest.mark.depends_on(on=["test_data_streaming"]) @pytest.mark.parametrize(("name", "num_gpus", "distributed_config_dict"), _DISTRIBUTED_TESTING_CONFIGS) def test_data_streaming_distributed(result_path, name, num_gpus, distributed_config_dict, report_subtest): - report_subtest(path := result_path / f"data_streaming/{name}", num_gpus) + report_subtest(path := result_path / f"data_streaming/{name}", num_gpus, use_cuda=False) distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus) check_data_streaming_results(path, distributed_config, batch_config) diff --git a/tests/functional/test_entropy_loss.py b/tests/functional/test_entropy_loss.py index 9c06c1919..db53fe0d9 100644 --- a/tests/functional/test_entropy_loss.py +++ b/tests/functional/test_entropy_loss.py @@ -176,4 +176,6 @@ def test_run_entropy_loss_distributed(run_parallel_script, result_path): def test_entropy_loss_distributed(result_path, report_subtest, target_format, entropy_loss_type, loss_masking): if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: pytest.skip(reason="Not implemented") - report_subtest(result_path / f"test_entropy_loss/{entropy_loss_type}_{target_format}_{loss_masking}", 2) + report_subtest( + result_path / f"test_entropy_loss/{entropy_loss_type}_{target_format}_{loss_masking}", 2, use_cuda=False + ) diff --git a/tests/utils/redis.py b/tests/utils/redis.py index 591ee74e6..198c6df78 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -1,6 +1,5 @@ import contextlib import itertools -import json import pathlib import socket import threading @@ -10,7 +9,6 @@ from fast_llm.data.dataset.config import ( REDIS_DATA_STREAM, - REDIS_FIELD, REDIS_GROUP_NAME, RedisConfig, SamplingConfig, @@ -18,6 +16,7 @@ SamplingParameters, StreamingDatasetConfig, ) +from fast_llm.data.dataset.streaming import RedisDocument from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.models.gpt.config import GPTBatchConfig @@ -29,11 +28,6 @@ def find_free_port(): return s.getsockname()[1] -def push_msg(redis_client, tokens): - """Push a message into FakeRedis stream.""" - redis_client.xadd(REDIS_DATA_STREAM, {REDIS_FIELD: json.dumps({"tokens": tokens, "tokens_dtype": "int64"})}) - - def wait_until_stream_empty( redis_client, stream_key, @@ -76,7 +70,10 @@ def producer_loop(): for sample_index in itertools.count(): if stop_event.is_set(): break - push_msg(client, [sample_index] * batch_config.sequence_length) + client.xadd( + REDIS_DATA_STREAM, + RedisDocument.from_dict({"tokens": [sample_index] * batch_config.sequence_length}).to_message(), + ) if sample_index % 5 == 0: wait_until_stream_empty(client, REDIS_DATA_STREAM, REDIS_GROUP_NAME, stop_event) diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index 3ca84499e..e5c87f9f5 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -51,12 +51,12 @@ def __enter__(self): self._configure_logging() self._group = self._pool.get_process_group(range(self._world_size), self._rank) # TODO: Barriers needed? - safe_barrier(self._group, "start", device=self._pool.device) + safe_barrier(self._group, "start") return self def __exit__(self, exc_type, exc_val, exc_tb): # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(self._group, "testing end", device=self._pool.device) + safe_barrier(self._group, "testing end") # Let pytest know how things went. 
# These should already be reported above, we repeat for convenience. if self._failures: @@ -138,13 +138,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): if (group := self._test_context._group) is not None: # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. - safe_barrier(group, self._name, device=self._test_context._pool.device) - self._success = ( - allreduce_scalar( - self._success, dtype=torch.int64, group=group, device=self._test_context._pool.device - ) - == group.size() - ) + safe_barrier(group, self._name) + self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size() if self._do_capture and torch.cuda.is_available(): # Free resources to limit memory usage. @@ -201,8 +196,8 @@ def report_subtest(request: pytest.FixtureRequest): verbose = request.config.getoption("verbose") do_capture = request.config.getoption("distributed_capture") - def do_report_subtest(path: pathlib.Path, world_size: int) -> None: - if torch.cuda.device_count() < world_size: + def do_report_subtest(path: pathlib.Path, world_size: int, use_cuda: bool = True) -> None: + if use_cuda and torch.cuda.device_count() < world_size: pytest.skip(f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {world_size}") success = check_subtest_success(path) if not do_capture:
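With the device now derived from the process group (`_get_device`), the scalar collectives above run on gloo/CPU groups without a `device` argument, which is what the reworked tests rely on. A rough single-process sketch, not part of the patch; it builds a `ProcessGroupGloo` directly (similar to the test process-group pool), and the store path is arbitrary:

import torch
import torch.distributed as dist

from fast_llm.core.distributed import allreduce_scalar, safe_barrier

# World-size-1 gloo group; `_get_device` resolves it to the CPU, so no CUDA device is touched.
store = dist.FileStore("/tmp/fast_llm_gloo_example_store", 1)  # arbitrary path
group = dist.ProcessGroupGloo(store, 0, 1)

total = allreduce_scalar(3, dtype=torch.int64, group=group)  # == 3 with a single rank
safe_barrier(group, "example")  # hash-checked barrier, no explicit device argument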