28 changes: 17 additions & 11 deletions fast_llm/core/distributed.py
@@ -29,6 +29,15 @@
logger = logging.getLogger(__name__)


def _get_device(group: ProcessGroup) -> torch.device:
if torch.distributed.is_nccl_available() and isinstance(group, torch.distributed.ProcessGroupNCCL):
return torch.device(torch.cuda.current_device())
elif isinstance(group, torch.distributed.ProcessGroupGloo):
return torch.device("cpu")
else:
raise NotImplementedError(type(group))


@contextlib.contextmanager
def set_timeout(group: ProcessGroup | None, timeout: float | None = None):
if group is not None and timeout is not None:
@@ -72,12 +81,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name:
)


def safe_barrier(
group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None
) -> None:
def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None:
if group:
hashed = hash(value) % 2**32
out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device)
out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout)
if out != hashed * group.size():
raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})")

@@ -88,10 +95,9 @@ def allreduce_scalar(
group: torch.distributed.ProcessGroup | None = None,
op=ReduceOp.SUM,
timeout: float | None = None,
device: torch.device | None = None,
) -> float | int:
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device)
value = torch.full([1], value, dtype=dtype, device=_get_device(group))
with set_timeout(group, timeout):
torch.distributed.all_reduce(value, op=op, group=group)
return value.item()
@@ -106,7 +112,7 @@ def all_gather_scalar(
timeout: float | None = None,
):
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
value = torch.full([1], value, dtype=dtype, device=_get_device(group))
output_tensor = value.new_empty((group.size(),))
with set_timeout(group, timeout):
torch.distributed.all_gather_into_tensor(output_tensor, value, group=group)
@@ -116,15 +122,15 @@


def broadcast_scalar(
value: float | int,
value: float | int | None,
dtype: torch.dtype = torch.float64,
group: torch.distributed.ProcessGroup | None = None,
src: int = 0,
timeout: float | None = None,
) -> float | int:
if not group:
return value
tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device()))
tensor = torch.empty([1], dtype=dtype, device=torch.device(_get_device(group)))
if group.rank() == src:
tensor.fill_(value)
broadcast(tensor, src, group, timeout=timeout)
@@ -141,14 +147,14 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None
if group.rank() == src:
tensor = _object_to_tensor(input_object)
size = tensor.numel()
broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group))
broadcast_tensor.copy_(tensor)
broadcast_scalar(size, torch.int64, group, src)
broadcast(broadcast_tensor, src, group)
return input_object
else:
size = int(broadcast_scalar(None, torch.int64, group, src))
output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
output_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group))
broadcast(output_tensor, src, group)
return _tensor_to_object(output_tensor)
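
With `_get_device` in place, the scalar collectives pick their staging device from the process group's backend: CPU for Gloo, the current CUDA device for NCCL. That removes the explicit `device` argument from `safe_barrier` and `allreduce_scalar` and lets these helpers run on CPU-only hosts. A minimal sketch of the Gloo path, assuming a standard torchrun-style launch (MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE set) and that the group handed in is the backend `ProcessGroupGloo` object the `isinstance` check expects:

import torch
import torch.distributed

from fast_llm.core.distributed import allreduce_scalar, safe_barrier

# Gloo path: _get_device resolves to torch.device("cpu"), so no GPU is required.
torch.distributed.init_process_group(backend="gloo")
group = torch.distributed.group.WORLD

# Each rank contributes 1; the reduced value equals the world size, staged in a CPU tensor.
world_size = allreduce_scalar(1, dtype=torch.int64, group=group)

# The barrier hashes its value, all-reduces the hash on the same device,
# and raises on desync instead of hanging if the ranks disagree.
safe_barrier(group, "after-allreduce")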

1 change: 0 additions & 1 deletion fast_llm/data/dataset/config.py
@@ -302,7 +302,6 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp


REDIS_DATA_STREAM = "fast_llm_streaming"
REDIS_FIELD = "data"
REDIS_GROUP_NAME = "fast_llm_group"


128 changes: 96 additions & 32 deletions fast_llm/data/dataset/streaming.py
@@ -1,23 +1,108 @@
import functools
import json
import typing

import redis
import torch.utils.data

from fast_llm.config import Configurable
from fast_llm.config import Config, Configurable, Field, config_class
from fast_llm.data.dataset.abstract import SamplableIterableDataset
from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_FIELD, REDIS_GROUP_NAME, StreamingDatasetConfig
from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_GROUP_NAME, StreamingDatasetConfig
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.data.sample.range import RangeSample
from fast_llm.data.sample.token import TokenSample
from fast_llm.data.sample.token_data import TokenDataSample
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.utils import Assert


def dtype_from_string(name: str) -> torch.dtype:
try:
return getattr(torch, name)
except AttributeError:
raise ValueError(f"Unknown torch dtype: {name}")
@config_class()
class RedisDocument(Config):
"""
Schema for sending and receiving documents through redis, and the associated handling code.
"""

tokens: torch.Tensor = Field()
loss_masking_spans: list[tuple[int, int]] | None = Field(default=None)
chosen_span: tuple[int, int] | None = Field(default=None)
rejected_span: tuple[int, int] | None = Field(default=None)
advantage: float | None = Field(default=None)
old_log_probabilities: torch.Tensor | None = Field(default=None)

def _validate(self):
# Decode message
if isinstance(self.tokens, bytes):
self.tokens = torch.frombuffer(self.tokens, dtype=torch.int64)
elif isinstance(self.tokens, (list, tuple)):
self.tokens = torch.tensor(self.tokens, dtype=torch.int64)
if isinstance(self.loss_masking_spans, str):
self.loss_masking_spans = json.loads(self.loss_masking_spans)
if isinstance(self.chosen_span, str):
self.chosen_span = json.loads(self.chosen_span)
if isinstance(self.rejected_span, str):
self.rejected_span = json.loads(self.rejected_span)
if isinstance(self.old_log_probabilities, bytes):
self.old_log_probabilities = torch.frombuffer(self.old_log_probabilities, dtype=torch.float32)
elif isinstance(self.old_log_probabilities, (list, tuple)):
self.old_log_probabilities = torch.tensor(self.old_log_probabilities, dtype=torch.float32)
super()._validate()
if self.old_log_probabilities is not None:
Assert.eq(len(self.old_log_probabilities), self.num_tokens)

@functools.cached_property
def num_tokens(self) -> int:
return len(self.tokens)

@classmethod
def from_message(cls, message: dict[bytes, bytes]) -> typing.Self:
# Read
kwargs = {}
for key, value in message.items():
key = key.decode()
if key == "data":
kwargs.update(json.loads(value))
else:
kwargs[key] = value
return cls.from_dict(kwargs)

def to_message(self) -> dict[str, str | int | float | bytes]:
# Encode message
message: dict[str, str | int | float | bytes] = {"tokens": self.tokens.numpy().tobytes()}
if self.old_log_probabilities is not None:
message["old_log_probabilities"] = self.old_log_probabilities.numpy().tobytes()
data = {}
if self.loss_masking_spans is not None:
data["loss_masking_spans"] = self.loss_masking_spans
if self.chosen_span is not None:
data["chosen_span"] = self.chosen_span
if self.rejected_span is not None:
data["rejected_span"] = self.rejected_span
if self.advantage is not None:
data["advantage"] = self.advantage
if data:
message["data"] = json.dumps(data)
return message

def to_sample(self):
sample_size = len(self.tokens)
return LanguageModelSample(
tokens=TokenSample(self.tokens, [sample_size]),
loss_masking_spans=(
None
if self.loss_masking_spans is None
else RangeSample([(begin, end) for begin, end in self.loss_masking_spans], sample_size)
),
chosen_spans=None if self.chosen_span is None else RangeSample([self.chosen_span], sample_size),
rejected_spans=None if self.rejected_span is None else RangeSample([self.rejected_span], sample_size),
advantages=(
None
if self.advantage is None
else TokenDataSample(torch.full([sample_size], self.advantage, dtype=torch.float32))
),
old_log_probabilities=(
None if self.old_log_probabilities is None else TokenDataSample(self.old_log_probabilities)
),
)


class RedisStreamingDataset[ConfigType: StreamingDatasetConfig, SampleType: LanguageModelSample](
@@ -77,29 +162,8 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]:
noack=True,
)
if messages:
for stream_key, msgs in messages:
for stream_key, messages_ in messages:
assert stream_key == REDIS_DATA_STREAM.encode()
for msg_id, msg_data in msgs:
yield self._read_document(json.loads(msg_data[REDIS_FIELD.encode()]))

def _read_document(self, data: dict) -> LanguageModelSample:
tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"]))
sample_size = len(tokens)
if "loss_masking_spans" in data:
loss_masking_spans = RangeSample([(begin, end) for begin, end in data["loss_masking_spans"]], sample_size)
else:
loss_masking_spans = None
if "chosen_spans" in data:
chosen_spans = RangeSample([(begin, end) for begin, end in data["chosen_spans"]], sample_size)
else:
chosen_spans = None
if "rejected_spans" in data:
rejected_spans = RangeSample([(begin, end) for begin, end in data["rejected_spans"]], sample_size)
else:
rejected_spans = None
return LanguageModelSample(
TokenSample(tokens, [sample_size]),
loss_masking_spans,
chosen_spans,
rejected_spans,
)
for message_id, message in messages_:
yield RedisDocument.from_message(message).to_sample()
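
For context, a hedged producer-side sketch of a message that `RedisDocument.from_message` can decode: tokens travel as raw int64 bytes, while the optional fields are bundled into a single JSON `data` entry, mirroring `to_message`. The connection settings and sample values are placeholders, not part of this change:

import json

import redis
import torch

from fast_llm.data.dataset.config import REDIS_DATA_STREAM

client = redis.Redis(host="localhost", port=6379)  # placeholder connection settings
tokens = torch.tensor([5, 17, 42, 3], dtype=torch.int64)
message = {
    # Decoded on the consumer side with torch.frombuffer(..., dtype=torch.int64).
    "tokens": tokens.numpy().tobytes(),
    # Optional extras ride in a single JSON field, as to_message() produces them.
    "data": json.dumps({"loss_masking_spans": [[1, 3]], "advantage": 0.5}),
}
client.xadd(REDIS_DATA_STREAM, message)
# The dataset's consumer group ("fast_llm_group") reads this entry and turns it into a
# LanguageModelSample via RedisDocument.from_message(...).to_sample().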
2 changes: 2 additions & 0 deletions fast_llm/data/preprocessing/language_model.py
@@ -22,6 +22,8 @@ class LanguageModelPreprocessingConfig(PreprocessingConfig):
vocab_size: int | None = Field(default=None)
use_loss_masking_spans: bool = Field(default=False)
use_preference_spans: bool = Field(default=False)
use_advantages: bool = Field(default=False)
use_old_log_probabilities: bool = Field(default=False)

def _validate(self) -> None:
super()._validate()
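
The two new flags line up with the `advantage` and `old_log_probabilities` fields of the streamed documents above. A minimal sketch of turning them on through `from_dict` (the surrounding values are placeholders, and whether validation runs automatically here is not shown in this diff):

from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig

# Placeholder settings; only the last two flags are new in this change.
config = LanguageModelPreprocessingConfig.from_dict(
    {
        "use_loss_masking_spans": True,
        "use_advantages": True,
        "use_old_log_probabilities": True,
    }
)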