diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 01483b9a7..bc2a52885 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -231,8 +231,15 @@ def generate_client_class(class_name: str, """Generate client wrapper class code.""" def build_imports() -> Tuple[List[str], str]: - # Include both method signatures and __init__ signature for import detection - signatures = [sig for _, sig in methods] + # Include both method signatures and __init__ signature for import detection. + # Use the timeout-injected signature for ``encode`` so the ``Optional`` + # import introduced by ``_inject_timeout`` is detected. + signatures = [] + for name, sig in methods: + if name == 'encode' and class_name in ('Dataset', 'LazyDataset'): + signatures.append(_inject_timeout(sig)) + else: + signatures.append(sig) if init_signature: signatures.append(init_signature) @@ -261,15 +268,31 @@ def build_imports() -> Tuple[List[str], str]: lines.append('') return lines, inheritance + def _inject_timeout(signature: str) -> str: + """Insert `timeout: Optional[int] = 600` before any **kwargs in the signature.""" + if 'timeout' in signature: + return signature + if ', **' in signature: + pre, post = signature.rsplit(', **', 1) + return f'{pre}, timeout: Optional[int] = 600, **{post}' + if signature.startswith('**'): + return f'timeout: Optional[int] = 600, {signature}' + if signature: + return f'{signature}, timeout: Optional[int] = 600' + return 'timeout: Optional[int] = 600' + def build_method(name: str, signature: str) -> str: param_names = parse_params_from_signature(signature) kwargs_dict = '{' + ', '.join(f"'{p}': {p}" for p in param_names) + '}' if param_names else '{}' - sig_part = f', {signature}' if signature else '' + wants_timeout = name == 'encode' and class_name in ('Dataset', 'LazyDataset') + effective_sig = _inject_timeout(signature) if wants_timeout else signature + sig_part = f', {effective_sig}' if effective_sig else '' if 'kwargs' in sig_part: extra_args = '\n **kwargs' else: extra_args = '' ret = 'self' if name == '__iter__' else 'response.json()["result"]' + timeout_kwarg = ',\n timeout=timeout' if wants_timeout else '' code = f''' def {name}(self{sig_part}): @@ -279,7 +302,7 @@ def {name}(self{sig_part}): 'processor_id': self.processor_id, 'function': '{name}', **{kwargs_dict},{extra_args} - }} + }}{timeout_kwarg} ) response.raise_for_status() return {ret} diff --git a/docs/source_en/Components/Dataset/Dataset.md b/docs/source_en/Components/Dataset/Dataset.md index 3fb86119f..85c0f5470 100644 --- a/docs/source_en/Components/Dataset/Dataset.md +++ b/docs/source_en/Components/Dataset/Dataset.md @@ -148,6 +148,7 @@ dataset.encode() > 1. Dataset's `map`, `encode`, `filter`, and other methods all use the `map` method of `datasets`, so you can use the corresponding parameters in the kwargs of the corresponding methods > 2. The `load_from_cache_file` parameter defaults to False, because when this parameter is set to True, it can cause headaches when the dataset changes but training still uses the cache. If your dataset is large and updated infrequently, you can directly set it to True > 3. encode does not need to specify `DatasetMeta` because after preprocessing, all datasets have the same format +> 4. `encode` tokenizes with a single process by default. For large datasets, enable multi-process parallelism via `num_proc`, e.g. `dataset.encode(num_proc=8)` 6. Getting data diff --git a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md index af53f5ab1..5a118a617 100644 --- a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md @@ -109,6 +109,11 @@ dataset.map('SelfCognitionProcessor', # Encode dataset into tokens usable by the model dataset.encode(batched=True) +# For large datasets, use num_proc to enable multi-process parallelism: +# dataset.encode(batched=True, num_proc=8) +# When using twinkle_client.dataset, encode calls the remote server over HTTP +# with a default 600s timeout; raise it via the timeout argument if needed: +# dataset.encode(batched=True, num_proc=8, timeout=3600) # Create DataLoader dataloader = DataLoader(dataset=dataset, batch_size=4) diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" index 967479b9b..4c3fa09e4 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" @@ -109,6 +109,11 @@ dataset.map('SelfCognitionProcessor', # 编码数据集为模型可用的 token dataset.encode(batched=True) +# 数据量大时可用 num_proc 多进程加速: +# dataset.encode(batched=True, num_proc=8) +# 使用 twinkle_client.dataset 时,encode 是通过 HTTP 调用远端服务, +# 默认 600 秒超时,可用 timeout 参数按需调大: +# dataset.encode(batched=True, num_proc=8, timeout=3600) # 创建 DataLoader dataloader = DataLoader(dataset=dataset, batch_size=4) diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/Dataset.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/Dataset.md" index 812a7e7fd..5c18f82ed 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/Dataset.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/Dataset.md" @@ -148,6 +148,7 @@ dataset.encode() > 1. Dataset 的 `map`、`encode`、`filter` 等方法均使用 `datasets` 的 `map` 方式进行,因此在对应方法的 kwargs 中均可以使用对应的参数 > 2. `load_from_cache_file` 参数默认为 False,因为该参数设置为 True 时会引发一些数据集改变但训练仍然使用缓存的头疼问题。如果你的数据集较大而且更新不频繁,可以直接置为 True > 3. encode 不需要指定 `DatasetMeta`,因为预处理过后所有数据集格式都是相同的 +> 4. `encode` 默认使用单进程分词。数据量较大时可通过 `num_proc` 开启多进程并行加速,例如 `dataset.encode(num_proc=8)` 6. 获取数据 diff --git a/src/twinkle_client/dataset/base.py b/src/twinkle_client/dataset/base.py index 21169637a..845b11cb8 100644 --- a/src/twinkle_client/dataset/base.py +++ b/src/twinkle_client/dataset/base.py @@ -9,7 +9,7 @@ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from typing import Any, Callable, Dict, Type, Union +from typing import Any, Callable, Dict, Optional, Type, Union from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta @@ -50,7 +50,7 @@ def set_template(self, template_func: Union[Template, Type[Template], str], **kw return response.json()["result"] - def encode(self, add_generation_prompt: bool = False, **kwargs): + def encode(self, add_generation_prompt: bool = False, timeout: Optional[int] = 600, **kwargs): response = http_post( url=f'{self.server_url}/call', json_data={ @@ -58,7 +58,8 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): 'function': 'encode', **{'add_generation_prompt': add_generation_prompt}, **kwargs - } + }, + timeout=timeout ) response.raise_for_status() return response.json()["result"] @@ -146,6 +147,33 @@ def mix_dataset(self, interleave = True): return response.json()["result"] + def save_as(self, output_path: str, format: Optional[str] = None, batch_size: int = 1000, mode: str = 'immediate', **kwargs): + response = http_post( + url=f'{self.server_url}/call', + json_data={ + 'processor_id': self.processor_id, + 'function': 'save_as', + **{'output_path': output_path, 'format': format, 'batch_size': batch_size, 'mode': mode}, + **kwargs + } + ) + response.raise_for_status() + return response.json()["result"] + + + def flush_save(self): + response = http_post( + url=f'{self.server_url}/call', + json_data={ + 'processor_id': self.processor_id, + 'function': 'flush_save', + **{}, + } + ) + response.raise_for_status() + return response.json()["result"] + + def __getitem__(self, idx): response = http_post( url=f'{self.server_url}/call', diff --git a/src/twinkle_client/dataset/lazy_dataset.py b/src/twinkle_client/dataset/lazy_dataset.py index 54a9abc59..c3b70c85f 100644 --- a/src/twinkle_client/dataset/lazy_dataset.py +++ b/src/twinkle_client/dataset/lazy_dataset.py @@ -9,7 +9,7 @@ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from typing import Any, Callable, Dict, Type, Union +from typing import Any, Callable, Dict, Optional, Type, Union from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta @@ -91,7 +91,7 @@ def mix_dataset(self, interleave = True): return response.json()["result"] - def encode(self, add_generation_prompt: bool = False, **kwargs): + def encode(self, add_generation_prompt: bool = False, timeout: Optional[int] = 600, **kwargs): response = http_post( url=f'{self.server_url}/call', json_data={ @@ -99,7 +99,8 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): 'function': 'encode', **{'add_generation_prompt': add_generation_prompt}, **kwargs - } + }, + timeout=timeout ) response.raise_for_status() return response.json()["result"] diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index 57e9be929..6aac84a1a 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -120,7 +120,7 @@ def http_post( json_data: Optional[Dict[str, Any]] = {}, data: Optional[Any] = {}, additional_headers: Optional[Dict[str, str]] = {}, - timeout: int = 600, + timeout: Optional[int] = 600, ) -> requests.Response: """ Send HTTP POST request with required headers. @@ -130,7 +130,7 @@ def http_post( json_data: JSON data to send in request body data: Form data or raw data to send in request body additional_headers: Additional headers to include - timeout: Request timeout in seconds + timeout: Request timeout in seconds; None disables the timeout. Returns: requests.Response object