It would be nice if we didn't need to write code for example inputs.
For example, we currently need something like:
```python
import torch
from transformers import DynamicCache


def get_example_inputs(self):
    # Simulate a single decoding step after 511 cached tokens.
    past_seq_len = 511
    cur_seq_len = 1

    input_ids = torch.tensor([[812]]).to(torch.long)
    attention_mask = torch.ones(1, past_seq_len + cur_seq_len)
    position_ids = torch.tensor([[past_seq_len]]).to(torch.long)

    # Shape of each layer's cached key/value tensor:
    # (batch, num_heads, past_seq_len, head_dim).
    # Note: a model using grouped-query attention would need
    # self.config.num_key_value_heads here instead of num_attention_heads.
    kv_shape = [1, self.config.num_attention_heads, past_seq_len, self.config.head_dim]

    past_key_values = DynamicCache()
    for layer_id in range(self.config.num_hidden_layers):
        # Fill layer `layer_id` of the cache with random keys and values.
        past_key_values.update(torch.randn(kv_shape), torch.randn(kv_shape), layer_id)

    return (
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
    )
```
Writing working example inputs by hand requires understanding a lot about the target model.
So I tried capturing the example inputs from user-level inputs instead.
I succeeded in writing a working version and tested it with Maykeye/TinyLLama-v0.
I confirmed the results with the following transformers versions:
- 4.49.0 ❌ (DynamicCache is not pytree-flattenable)
- 4.51.3 ⭕
- 4.52.4 ⭕
The good news is that PyTorch can flatten transformers' DynamicCache as a pytree since transformers 4.50.0 (there seems to be a bug on macOS, but all versions after 4.50.1 should work).
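The version observations above can be encoded as a small guard before relying on DynamicCache being pytree-flattenable. This is a minimal sketch: `parse_version`, `dynamic_cache_is_flattenable`, and `MIN_TRANSFORMERS` are hypothetical names (not part of transformers or TICO), and treating 4.50.1 itself as the first safe version is an assumption.

```python
# Hypothetical helpers (not part of transformers or TICO) that encode the
# version observations above.

def parse_version(version: str) -> tuple:
    """Turn '4.51.3' or '4.52.0.dev0' into a comparable tuple such as (4, 51, 3)."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes like 'dev0' or 'rc1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (3 - len(parts))  # pad '4.50' to (4, 50, 0)
    return tuple(parts)


# Assumption: treat 4.50.1 as the first version that is safe on every platform
# (4.50.0 added pytree support, but macOS seemed to have a problem with it).
MIN_TRANSFORMERS = (4, 50, 1)


def dynamic_cache_is_flattenable(version: str) -> bool:
    """Return True if this transformers version should flatten DynamicCache as a pytree."""
    return parse_version(version) >= MIN_TRANSFORMERS


# In practice one would pass transformers.__version__.
for v in ["4.49.0", "4.51.3", "4.52.4"]:
    print(v, dynamic_cache_is_flattenable(v))
```

For the three versions tested above, this prints `False`, `True`, `True`, matching the ❌/⭕ results.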
I hope, and expect, that it will also work for other models by changing the model name and the user inputs.
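One possible way to capture example inputs from a user-level call is a forward pre-hook that records the arguments of a chosen `forward()` call. The sketch below assumes PyTorch >= 2.0 (for `register_forward_pre_hook(..., with_kwargs=True)`); it is not necessarily how the working version above does it, and `CaptureForwardInputs` and `ToyDecoder` are hypothetical names. The captured arguments are deep-copied because a KV cache is updated in place as generation continues.

```python
import copy

import torch
from torch import nn


class CaptureForwardInputs:
    """Hypothetical helper: records the args/kwargs of the `call_index`-th forward() call."""

    def __init__(self, module: nn.Module, call_index: int = 0):
        self.module = module
        self.call_index = call_index
        self.args = None
        self.kwargs = None
        self._calls = 0
        self._handle = None

    def _hook(self, module, args, kwargs):
        if self._calls == self.call_index:
            # Deep-copy so later in-place updates (e.g. a growing cache)
            # do not change what we captured.
            self.args = copy.deepcopy(args)
            self.kwargs = copy.deepcopy(kwargs)
        self._calls += 1
        return None  # leave the real inputs untouched

    def __enter__(self):
        # with_kwargs=True makes PyTorch pass keyword arguments to the hook too.
        self._handle = self.module.register_forward_pre_hook(self._hook, with_kwargs=True)
        return self

    def __exit__(self, *exc_info):
        self._handle.remove()
        return False


class ToyDecoder(nn.Module):
    """Stand-in for a real model so the sketch runs without downloads."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(1000, 8)

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        return self.embed(input_ids)


model = ToyDecoder().eval()
with torch.no_grad(), CaptureForwardInputs(model, call_index=1) as cap:
    model(torch.tensor([[1, 2, 3]]), attention_mask=torch.ones(1, 3))  # call 0: "prefill"
    model(torch.tensor([[812]]), attention_mask=torch.ones(1, 4))  # call 1: "decode"

print(cap.args[0].shape, cap.kwargs["attention_mask"].shape)
# torch.Size([1, 1]) torch.Size([1, 4])
```

With a Hugging Face model, one could, for example, wrap `model.generate(**tokenizer(prompt, return_tensors="pt"), max_new_tokens=2)` with `call_index=1` so that the captured kwargs describe a single decode step with a filled `past_key_values`. This assumes that `generate()` calls the model's `forward` once per step and that the cache object can be deep-copied; the exact kwargs passed differ between transformers versions.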
(The `get_example_inputs` snippet at the top of this issue is taken from TICO/test/modules/model/LlamaWithKVCache/model.py, lines 36 to 70 at commit a3ed23e.)