diff --git a/roll/third_party/vllm/__init__.py b/roll/third_party/vllm/__init__.py index 77f67cbbb..a074a2d2c 100644 --- a/roll/third_party/vllm/__init__.py +++ b/roll/third_party/vllm/__init__.py @@ -59,7 +59,8 @@ async def create_async_llm(resource_placement_groups: List[Dict], **kwargs): os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "" # torch.cuda may already init, explicitly disable expandable_segments # here (only matters when VLLM_USE_RAY_SPMD_WORKER=0) - current_platform.memory._set_allocator_settings("expandable_segments:False") + if not current_platform.is_npu(): + current_platform.memory._set_allocator_settings("expandable_segments:False") os.environ["VLLM_CACHE_ROOT"] = os.path.join(get_default_cache_root(), "vllm", os.environ.get("WORKER_NAME", "")) diff --git a/roll/third_party/vllm/worker.py b/roll/third_party/vllm/worker.py index ea82ceb40..0e6d668b9 100644 --- a/roll/third_party/vllm/worker.py +++ b/roll/third_party/vllm/worker.py @@ -118,7 +118,7 @@ def broadcast_parameter(self, names, dtypes, shapes, group_name, is_lora=False): weights_and_handles = [] for name, dtype, shape in zip(names, dtypes, shapes): target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) - weight = torch.empty(shape, dtype=target_dtype, device=self.device) + weight = torch.empty(shape, dtype=target_dtype, device=current_platform.device_type) handle = collective.broadcast(tensor=weight, src_rank=0, group_name=group_name, async_op=True) weights_and_handles.append((name, weight, handle))