forked from ayaka14732/TransCan
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path3_predict.py
More file actions
68 lines (55 loc) · 2.98 KB
/
3_predict.py
File metadata and controls
68 lines (55 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import jax; jax.config.update('jax_platforms', 'cpu'); jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
import jax.numpy as np
import regex as re
import sys
from transformers import BartConfig, BartTokenizer, BertTokenizer
from tqdm import tqdm
from typing import Any
from lib.dataset.load_cantonese import load_cantonese
from lib.Generator import Generator
from lib.param_utils.load_params import load_params
from lib.en_kfw_nmt.fwd_transformer_encoder_part import fwd_transformer_encoder_part
def chunks(lst: list[Any], chunk_size: int) -> list[list[Any]]:
    '''Split ``lst`` into consecutive chunks of at most ``chunk_size`` items.

    Every chunk except possibly the last has exactly ``chunk_size`` items.
    The chunks are returned as an eager list, not yielded lazily.

    >>> chunks([1, 2, 3, 4, 5], 2)
    [[1, 2], [3, 4], [5]]

    Raises:
        ValueError: if ``chunk_size`` is not a positive integer.
    '''
    if chunk_size <= 0:
        # Without this check, a negative step silently produces no chunks
        # (dropping every item), and 0 raises an unhelpful range() error.
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
def remove_tokenisation_space(s: str) -> str:
'''
>>> remove_space('阿 爸 好 忙 , 成 日 出 差')
'阿爸好忙,成日出差'
>>> remove_space('摸 A B 12至 3')
'摸A B 12至3'
>>> remove_space('噉你哋要唔要呢 ?')
'噉你哋要唔要呢?'
>>> remove_space('3 . 1')
'3.1'
'''
s = re.sub(r'(?<=[\p{Unified_Ideograph}\u3006\u3007。,、!:?()《》「」]) (?=[\p{Unified_Ideograph}\u3006\u3007。,、!:?()《》「」])', r'', s)
s = re.sub(r'(?<=[\p{Unified_Ideograph}\u3006\u3007。,、!:?()《》「」]) (?=[\da-zA-Z])', r'', s)
s = re.sub(r'(?<=[\da-zA-Z]) (?=[\p{Unified_Ideograph}\u3006\u3007。,、!:?()《》「」])', r'', s)
s = re.sub(r'(?<=[\da-zA-Z]) (?=[.,])', r'', s)
s = re.sub(r'(?<=[.,]) (?=[\da-zA-Z])', r'', s)
return s
# Inference: translate the English side of the test split into Cantonese
# and write one post-processed prediction per line to results-bart.txt.
sentences = load_cantonese(split='test')
# Each item unpacks as a 2-tuple; only the first (English) element is used.
sentences_en = [en for en, _ in sentences]

# The checkpoint file may be given as the first command-line argument.
param_file = sys.argv[1] if len(sys.argv) >= 2 else 'atomic-thunder-15-7.dat'
params = load_params(param_file)
# Convert every leaf to a jax array. `jax.tree_map` is deprecated (and removed
# in recent JAX releases); `jax.tree_util.tree_map` is the stable equivalent.
params = jax.tree_util.tree_map(np.asarray, params)

tokenizer_en = BartTokenizer.from_pretrained('facebook/bart-base')
tokenizer_yue = BertTokenizer.from_pretrained('Ayaka/bart-base-cantonese')
config = BartConfig.from_pretrained('Ayaka/bart-base-cantonese')
# Expose the decoder embedding under the key 'embedding' as well; presumably
# this is the key Generator looks up (TODO confirm in lib.Generator).
generator = Generator({'embedding': params['decoder_embedding'], **params}, config=config)

predictions = []
for chunk in tqdm(chunks(sentences_en, chunk_size=32)):
    inputs = tokenizer_en(chunk, return_tensors='jax', padding=True)
    # NOTE(review): uint16 assumes every token id is < 65536; confirm for the
    # source tokenizer's vocabulary.
    src = inputs.input_ids.astype(np.uint16)
    mask_enc_1d = inputs.attention_mask.astype(np.bool_)
    # Pairwise mask: (i, j) is True only when both positions are real tokens.
    # [:, None] inserts a broadcast axis (presumably for attention heads).
    mask_enc = np.einsum('bi,bj->bij', mask_enc_1d, mask_enc_1d)[:, None]
    encoder_last_hidden_output = fwd_transformer_encoder_part(params, src, mask_enc)
    generated_ids = generator.generate(encoder_last_hidden_output, mask_enc_1d, num_beams=5, max_length=128)
    decoded_sentences = tokenizer_yue.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    predictions.extend(decoded_sentences)

# Tokenisation spaces are stripped here, at write time; one line per input sentence.
with open('results-bart.txt', 'w', encoding='utf-8') as f:
    for prediction in predictions:
        print(remove_tokenisation_space(prediction), file=f)