-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathvisualize_attn.py
More file actions
67 lines (52 loc) · 1.98 KB
/
visualize_attn.py
File metadata and controls
67 lines (52 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
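# -*- coding: utf-8 -*-
"""Visualize 2D attention.

Usage: python visualize_attn.py /path/to/config

Loads the trained classifier described by the JSON config, reads one text
from standard input, and draws the output of the classifier's attention
model for that text as a heatmap (written to ``attn_visualize.png``).
"""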
import os
import sys
import json
from pprint import pprint
import numpy as np
import seaborn as sns
import matplotlib.pylab as plt
from langml.tokenizer import WPTokenizer
from dataloader import DataLoader
from model import AGNClassifier
# Command-line handling: the script takes exactly one argument, the path to a
# JSON config file.
if len(sys.argv) != 2:
    # sys.exit() with a string prints it to stderr and exits with status 1,
    # so a usage error is no longer reported as success (the bare `exit()`
    # helper exits with status 0 and is only injected by the `site` module).
    sys.exit("usage: python visualize_attn.py /path/to/config")

config_file = sys.argv[1]

# JSON text is UTF-8; decode explicitly instead of relying on the locale.
with open(config_file, "r", encoding="utf-8") as reader:
    config = json.load(reader)

# Echo the effective configuration so the run is self-describing.
print("config:")
pprint(config)
# Tokenizer: built from the pretrained model's vocab.txt, lowercasing input
# and truncating encodings to the configured maximum length.
vocab_path = os.path.join(config['pretrained_model_dir'], 'vocab.txt')
tokenizer = WPTokenizer(vocab_path, lowercase=True)
max_len = config['max_len']
tokenizer.enable_truncation(max_length=max_len)

# Data loader: one sample per batch, VAE enabled; its vocabulary and
# autoencoder weights are restored from the save directory.
dataloader = DataLoader(
    tokenizer,
    max_len,
    use_vae=True,
    batch_size=1,
    ae_epochs=config['ae_epochs'],
)
save_dir = config['save_dir']
dataloader.load_vocab(os.path.join(save_dir, 'vocab.pickle'))
dataloader.load_autoencoder(os.path.join(save_dir, 'autoencoder.weights'))

# Classifier: the output size comes from the data loader's label set, then
# the trained weights are loaded from the save directory.
config['output_size'] = dataloader.label_size
classifier = AGNClassifier(config)
clf_weights_path = os.path.join(config['save_dir'], 'clf_model.weights')
classifier.model.load_weights(clf_weights_path)
# Read one text from stdin and strip commas and periods before tokenizing.
text = input('input a text: ').translate(str.maketrans('', '', ',.'))
tokenized = tokenizer.encode(text)

# Cut the ids to max_len, then right-pad with 0 up to max_len. Every
# position gets segment id 0.
seq_len = config['max_len']
token_ids = tokenized.ids[:seq_len]
token_ids = token_ids + [0] * (seq_len - len(token_ids))
segment_ids = [0] * len(token_ids)

# parse_tcol_ids works on a list of samples and supplies the 'tcol_ids'
# entry read below; wrap the single sample in a list.
data = dataloader.parse_tcol_ids([
    {'token_ids': token_ids, 'segment_ids': segment_ids},
])
sample = data[0]

# Wrap each field in a list so the arrays carry a leading batch axis of 1.
token_ids, segment_ids, tcol_ids = (
    np.array([sample[key]])
    for key in ('token_ids', 'segment_ids', 'tcol_ids')
)

# Despite the name, `logits` holds the output of the attention model. Take
# the only batch element and keep the rows for real (non-padding) tokens.
batch_output = classifier.attn_model.predict([token_ids, segment_ids, tcol_ids])
logits = batch_output[0][:len(tokenized.tokens)]
# visualize
# plt.subplots returns (figure, axes), in that order.
fig, ax = plt.subplots(figsize=[20, 8])

# Drop the first and last positions from both the scores and the labels
# (presumably the tokenizer's special start/end tokens -- TODO confirm).
inner_tokens = tokenized.tokens[1:-1]

# Pass the labels to seaborn directly: with its default "auto" tick labels it
# may thin the ticks, so setting the full token list afterwards can mislabel
# rows or fail on a tick/label count mismatch. `linewidths` is the documented
# heatmap parameter for the cell separator width.
ax = sns.heatmap(logits[1:-1], linewidths=1, yticklabels=inner_tokens, ax=ax)
# Keep token labels horizontal so they stay readable.
ax.tick_params(axis='y', labelrotation=0)

# Save before showing: plt.show() blocks until the window is closed, after
# which the figure is gone and a later savefig would write an empty image.
fig.savefig('attn_visualize.png')
plt.show()