
Commit c7e3efc

Merge pull request #12 from LLukas22/feat/streaming
Add streaming support
2 parents 19a4788 + 2735dc4, commit c7e3efc

10 files changed: 411 additions & 39 deletions

Cargo.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [package]
 name = "llm-rs"
-version = "0.2.5"
+version = "0.2.6"
 edition = "2021"

 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

README.md

Lines changed: 15 additions & 1 deletion

@@ -23,12 +23,26 @@ model = AutoModel.from_pretrained("path/to/model.bin",model_type=KnownModels.Lla
 #generate
 print(model.generate("The meaning of life is"))
 ```
+
+### Streaming Text
+Text can be yielded from a generator via the `stream` function:
+```python
+from llm_rs import AutoModel, KnownModels
+
+#load the model
+model = AutoModel.from_pretrained("path/to/model.bin",model_type=KnownModels.Llama)
+
+#generate
+for token in model.stream("The meaning of life is"):
+    print(token)
+```
+
 ### Running GGML models from the Hugging Face Hub
 GGML converted models can be directly downloaded and run from the hub.
 ```python
 from llm_rs import AutoModel

-model = AutoModel.from_pretrained("LLukas22/mpt-7b-ggml",model_file="mpt-7b-q4_0-ggjt.bin")
+model = AutoModel.from_pretrained("rustformers/mpt-7b-ggml",model_file="mpt-7b-q4_0-ggjt.bin")
 ```
 If there are multiple models in a repo the `model_file` has to be specified.
 If you want to load repositories which were not created through this library, you have to specify the `model_type` parameter, as the metadata files needed to infer the architecture are missing.
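That last README sentence is the subtle case: a repo created outside this library has no metadata files, so the architecture cannot be inferred and must be passed explicitly. A minimal sketch of such a load; the repo id and file name are hypothetical placeholders:

```python
from llm_rs import AutoModel, KnownModels

model = AutoModel.from_pretrained(
    "someuser/some-ggml-repo",     # hypothetical repo without llm-rs metadata
    model_file="model-q4_0.bin",   # hypothetical file name within that repo
    model_type=KnownModels.Llama,  # required here, since it cannot be inferred
)
```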

llm_rs/base_model.py

Lines changed: 9 additions & 1 deletion

@@ -1,4 +1,4 @@
-from typing import Optional, Callable, List, Union
+from typing import Optional, Callable, List, Union, Generator
 from abc import ABC
 import os

@@ -36,6 +36,14 @@ def generate(self,prompt:str,
         Generates text from a prompt.
         """
         ...
+
+    def stream(self,prompt:str,
+               generation_config:Optional[GenerationConfig]=None,
+               ) -> Generator[str,None,None]:
+        """
+        Streams text from a prompt.
+        """
+        ...

     def tokenize(self,text:str) -> List[int]:
         """

llm_rs/langchain/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .langchain import RustformerLLM

llm_rs/langchain/langchain.py

Lines changed: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
+try:
+    from langchain.llms.base import LLM
+except ImportError:
+    raise ImportError(
+        'To use the llm_rs.langchain module, please install llm-rs with the additional "langchain" dependencies via: pip install llm-rs[langchain]')
+
+from typing import Any, Dict, Optional, Sequence, Union, List
+import os
+
+from pydantic import root_validator
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+
+from ..auto import AutoModel, KnownModels
+from ..config import GenerationConfig, SessionConfig
+from ..base_model import Model
+
+class RustformerLLM(LLM):
+    """
+    Langchain-Wrapper around a Rustformers model.
+    """
+
+    model: Optional[Model] = None  #: :meta private:
+
+    model_path_or_repo_id: Union[str,os.PathLike]
+    """The path to the model file or directory or the name of a Hugging Face Hub
+    model repo."""
+
+    model_type: Optional[KnownModels] = None
+    """The model type."""
+
+    model_file: Optional[str] = None
+    """The name of the model file in repo or directory."""
+
+    # session_config:SessionConfig=SessionConfig()
+    # """Session config for the model."""
+
+    # generation_config:GenerationConfig=GenerationConfig()
+    # """Generation config for the model."""
+
+    lora_paths:Optional[List[Union[str,os.PathLike]]]=None
+    """Paths to the lora files."""
+
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            'model_path_or_repo_id': self.model_path_or_repo_id,
+            'model_type': self.model_type,
+            'model_file': self.model_file,
+            # 'session_config': self.session_config,
+            # 'generation_config': self.generation_config,
+            'lora_paths': self.lora_paths,
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return 'rustformer'
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate and load model from a local file or remote repo."""
+        values['model'] = AutoModel.from_pretrained(
+            model_path_or_repo_id=values['model_path_or_repo_id'],
+            model_type=values['model_type'],
+            model_file=values['model_file'],
+            # session_config=values['session_config'],
+            lora_paths=values['lora_paths'],
+        )
+        return values
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[Sequence[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        """Generate text from a prompt.
+
+        Args:
+            prompt: The prompt to generate text from.
+            stop: A list of sequences to stop generation when encountered.
+
+        Returns:
+            The generated text.
+        """
+        text = []
+        generation_config = GenerationConfig()
+
+        if stop:
+            generation_config.stop_words = list(stop)
+        for chunk in self.model.stream(prompt, generation_config=generation_config):
+            text.append(chunk)
+            if run_manager:
+                run_manager.on_llm_new_token(chunk, verbose=self.verbose)
+        return ''.join(text)
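`_call` shows the intended pattern: LangChain's `stop` sequences are copied into `GenerationConfig.stop_words`, and every streamed chunk is forwarded to the callback manager before the chunks are joined into the final string. A hedged usage sketch; `LLMChain` and `PromptTemplate` are assumed from the classic (2023-era) langchain API, and the repo id and model file are taken from the README example above:

```python
from langchain import LLMChain, PromptTemplate  # assumes the classic langchain import paths
from llm_rs.langchain import RustformerLLM

llm = RustformerLLM(
    model_path_or_repo_id="rustformers/mpt-7b-ggml",
    model_file="mpt-7b-q4_0-ggjt.bin",
)

template = PromptTemplate(input_variables=["question"], template="Q: {question}\nA:")
chain = LLMChain(llm=llm, prompt=template)
print(chain.run("What is the meaning of life?"))
```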

pyproject.toml

Lines changed: 4 additions & 0 deletions

@@ -37,5 +37,9 @@ convert=[
     "einops >= 0.6.1"
 ]

+langchain=[
+    "langchain"
+]
+
 [tool.maturin]
 features = ["pyo3/extension-module"]

src/configs.rs

Lines changed: 16 additions & 0 deletions

@@ -1,7 +1,9 @@
+use crate::stopwords::StopWordHandler;
 use llm::{InferenceParameters, InferenceSessionConfig, ModelKVMemoryType, TokenBias};
 use pyo3::prelude::*;

 #[pyclass]
+#[derive(Clone)]
 pub struct GenerationConfig {
     #[pyo3(get, set)]
     pub top_k: usize,
@@ -19,6 +21,7 @@ pub struct GenerationConfig {
     pub max_new_tokens: Option<usize>,
     #[pyo3(get, set)]
     pub stop_words: Option<Vec<String>>,
+    pub stop_word_handler: Option<StopWordHandler>,
 }

 impl Default for GenerationConfig {
@@ -32,6 +35,18 @@ impl Default for GenerationConfig {
             seed: 42,
             max_new_tokens: None,
             stop_words: None,
+            stop_word_handler: None,
+        }
+    }
+}
+
+impl GenerationConfig {
+    pub fn init_stop_words(&mut self, model: &dyn llm::Model) {
+        if self.stop_words.is_some() {
+            let stopwords = self.stop_words.clone().unwrap();
+            self.stop_word_handler = Some(StopWordHandler::new(model, &stopwords));
+        } else {
+            self.stop_word_handler = None;
         }
     }
 }
@@ -59,6 +74,7 @@ impl GenerationConfig {
             seed: seed.unwrap_or(42),
             max_new_tokens,
             stop_words,
+            stop_word_handler: None,
         }
     }
 }
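On the Rust side, `init_stop_words` compiles the user-supplied `stop_words` into a `StopWordHandler` bound to the loaded model (presumably tokenizing them against its vocabulary), while only the plain `stop_words` field stays visible from Python. A sketch of setting it directly; the path and stop strings are illustrative, and it is assumed that `generate` accepts a `generation_config` keyword the same way `stream` does:

```python
from llm_rs import AutoModel, KnownModels
from llm_rs.config import GenerationConfig

model = AutoModel.from_pretrained("path/to/model.bin", model_type=KnownModels.Llama)

config = GenerationConfig()
config.stop_words = ["\n\n", "User:"]  # generation halts when either sequence appears

print(model.generate("User: Hi!\nAssistant:", generation_config=config))
```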

src/lib.rs

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ mod model_base;
 mod models;
 mod quantize;
 mod results;
+mod stopwords;

 #[pymodule]
 fn llm_rs(_py: Python, m: &PyModule) -> PyResult<()> {
