Skip to content

Commit 65459a4

Browse files
authored
Merge pull request #22 from bigdata-ustc/i2v
[FEATURE] add dynamic embedding model family of rnn
2 parents 4530911 + 7bb97f4 commit 65459a4

File tree

10 files changed

+151
-26
lines changed

10 files changed

+151
-26
lines changed

.travis.yml

Lines changed: 0 additions & 18 deletions
This file was deleted.

EduNLP/ModelZoo/rnn/rnn.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from torch import nn
66
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
7+
from baize.torch import load_net
78

89

910
class LM(nn.Module):
@@ -28,7 +29,7 @@ class LM(nn.Module):
2829
"""
2930

3031
def __init__(self, rnn_type: str, vocab_size: int, embedding_dim: int, hidden_size: int, num_layers=1,
31-
bidirectional=False, embedding=None, **kwargs):
32+
bidirectional=False, embedding=None, model_params=None, **kwargs):
3233
super(LM, self).__init__()
3334
rnn_type = rnn_type.upper()
3435
self.embedding = torch.nn.Embedding(vocab_size, embedding_dim) if embedding is None else embedding
@@ -61,12 +62,15 @@ def __init__(self, rnn_type: str, vocab_size: int, embedding_dim: int, hidden_si
6162
self.num_layers *= 2
6263
self.hidden_size = hidden_size
6364

65+
if model_params:
66+
load_net(model_params, self, allow_missing=True)
67+
6468
def forward(self, seq_idx, seq_len):
6569
seq = self.embedding(seq_idx)
6670
pack = pack_padded_sequence(seq, seq_len, batch_first=True)
67-
h0 = torch.randn(self.num_layers, seq.shape[0], self.hidden_size)
71+
h0 = torch.zeros(self.num_layers, seq.shape[0], self.hidden_size)
6872
if self.c is True:
69-
c0 = torch.randn(self.num_layers, seq.shape[0], self.hidden_size)
73+
c0 = torch.zeros(self.num_layers, seq.shape[0], self.hidden_size)
7074
output, (hn, _) = self.rnn(pack, (h0, c0))
7175
else:
7276
output, hn = self.rnn(pack, h0)

EduNLP/ModelZoo/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44
from .padder import PadSequence, pad_sequence
55
from .device import set_device
6+
from .masker import Masker

EduNLP/ModelZoo/utils/masker.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# coding: utf-8
2+
# 2021/8/3 @ tongshiwei
3+
4+
from copy import deepcopy
5+
import numpy as np
6+
7+
8+
class Masker(object):
9+
"""
10+
Examples
11+
-------
12+
>>> masker = Masker(per=0.5, seed=10)
13+
>>> items = [[1, 1, 3, 4, 6], [2], [5, 9, 1, 4]]
14+
>>> masked_seq, mask_label = masker(items)
15+
>>> masked_seq
16+
[[1, 1, 0, 0, 6], [2], [0, 9, 0, 4]]
17+
>>> mask_label
18+
[[0, 0, 1, 1, 0], [0], [1, 0, 1, 0]]
19+
>>> items = [[1, 2, 3], [1, 1, 0], [2, 0, 0]]
20+
>>> masked_seq, mask_label = masker(items, [3, 2, 1])
21+
>>> masked_seq
22+
[[1, 0, 3], [0, 1, 0], [2, 0, 0]]
23+
>>> mask_label
24+
[[0, 1, 0], [1, 0, 0], [0, 0, 0]]
25+
>>> masker = Masker(mask="[MASK]", per=0.5, seed=10)
26+
>>> items = [["a", "b", "c"], ["d", "[PAD]", "[PAD]"], ["hello", "world", "[PAD]"]]
27+
>>> masked_seq, mask_label = masker(items, length=[3, 1, 2])
28+
>>> masked_seq
29+
[['a', '[MASK]', 'c'], ['d', '[PAD]', '[PAD]'], ['hello', '[MASK]', '[PAD]']]
30+
>>> mask_label
31+
[[0, 1, 0], [0, 0, 0], [0, 1, 0]]
32+
"""
33+
34+
def __init__(self, mask: (int, str, ...) = 0, per=0.2, seed=None):
35+
"""
36+
37+
Parameters
38+
----------
39+
mask: int, str
40+
per
41+
seed
42+
"""
43+
self.seed = np.random.default_rng(seed)
44+
self.per = per
45+
self.mask = mask
46+
47+
def __call__(self, seqs, length=None, *args, **kwargs) -> tuple:
48+
seqs = deepcopy(seqs)
49+
masked_list = []
50+
if length is None:
51+
length = [len(seq) for seq in seqs]
52+
for seq, _length in zip(seqs, length):
53+
masked = self.seed.choice(len(seq) - 1, size=int(_length * self.per), replace=False)
54+
_masked_list = [0] * len(seq)
55+
for _masked in masked:
56+
seq[_masked] = self.mask
57+
_masked_list[_masked] = 1
58+
masked_list.append(_masked_list)
59+
return seqs, masked_list

EduNLP/Vector/embedding.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, w2v: (W2V, tuple, list, dict, None), freeze=True, device=None
1919
elif isinstance(w2v, W2V):
2020
self.w2v = w2v
2121
else:
22-
raise TypeError("w2v argument must be one of W2V, tuple, list, dict or None")
22+
raise TypeError("w2v argument must be one of W2V, tuple, list, dict or None, now is %s" % type(w2v))
2323

2424
if self.w2v is not None:
2525
self.vocab_size = len(self.w2v)
@@ -63,7 +63,10 @@ def indexing(self, items: List[List[str]], padding=False, indexing=True) -> tupl
6363
6464
Returns
6565
-------
66-
word_id: list of list of int
66+
token_idx: list of list of int
67+
the list of the tokens of each item
68+
token_len: list of int
69+
the list of the length of tokens of each item
6770
"""
6871
items_idx = [[self.key_to_index(word) for word in item] for item in items] if indexing else items
6972
item_len = [len(_idx) for _idx in items_idx]

EduNLP/Vector/meta.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,10 @@ def infer_tokens(self, items, *args, **kwargs) -> ...:
1111
@property
1212
def vector_size(self):
1313
raise NotImplementedError
14+
15+
@property
16+
def is_frozen(self): # pragma: no cover
17+
return True
18+
19+
def freeze(self, *args, **kwargs): # pragma: no cover
20+
pass

EduNLP/Vector/rnn/rnn.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ..embedding import Embedding
77
from ..meta import Vector
88
from EduNLP.ModelZoo import rnn, set_device
9+
from baize.torch import save_params
910

1011

1112
class RNNModel(Vector):
@@ -38,7 +39,8 @@ class RNNModel(Vector):
3839
torch.Size([2, 3, 2])
3940
"""
4041

41-
def __init__(self, rnn_type, w2v: (W2V, tuple, list, dict, None), hidden_size, freeze_pretrained=True, device=None,
42+
def __init__(self, rnn_type, w2v: (W2V, tuple, list, dict, None), hidden_size,
43+
freeze_pretrained=True, model_params=None, device=None,
4244
**kwargs):
4345
self.embedding = Embedding(w2v, freeze_pretrained, **kwargs)
4446
for key in ["vocab_size", "embedding_dim"]:
@@ -50,6 +52,7 @@ def __init__(self, rnn_type, w2v: (W2V, tuple, list, dict, None), hidden_size, f
5052
self.embedding.embedding_dim,
5153
hidden_size=hidden_size,
5254
embedding=self.embedding.embedding,
55+
model_params=model_params,
5356
**kwargs
5457
)
5558
self.bidirectional = self.rnn.rnn.bidirectional
@@ -86,3 +89,22 @@ def vector_size(self) -> int:
8689

8790
def set_device(self, device):
8891
self.rnn = set_device(self.rnn, device)
92+
93+
def save(self, filepath, save_embedding=False):
94+
save_params(filepath, self.rnn, select=None if save_embedding is True else '^(?!.*embedding)')
95+
return filepath
96+
97+
def freeze(self, *args, **kwargs):
98+
return self.eval()
99+
100+
@property
101+
def is_frozen(self):
102+
return not self.rnn.training
103+
104+
def eval(self):
105+
self.rnn.eval()
106+
return self
107+
108+
def train(self, mode=True):
109+
self.rnn.train(mode)
110+
return self

examples/pretrain/rnn/rnn.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# coding: utf-8
2+
# 2021/8/3 @ tongshiwei
3+
4+
from longling import load_jsonl
5+
from EduNLP.Tokenizer import get_tokenizer
6+
from EduNLP.Pretrain import train_vector
7+
from EduNLP.Vector import W2V, RNNModel
8+
9+
10+
def etl():
11+
tokenizer = get_tokenizer("text")
12+
return tokenizer([item["stem"] for item in load_jsonl("../../../data/OpenLUNA.json")])
13+
14+
15+
items = list(etl())
16+
model_path = train_vector(items, "./w2v", 10, "sg")
17+
18+
w2v = W2V(model_path, "sg")
19+
rnn = RNNModel("lstm", w2v, 5, device="cpu")
20+
saved_params = rnn.save("./lstm.params", save_embedding=True)
21+
22+
rnn1 = RNNModel("lstm", w2v, 5, model_params=saved_params)

setup.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
setup(
2323
name='EduNLP',
24-
version='0.0.3',
24+
version='0.0.4',
2525
extras_require={
2626
'test': test_deps,
2727
'tutor': tutor_deps,
@@ -35,11 +35,23 @@
3535
'jieba',
3636
'js2py',
3737
'torch',
38-
'EduData>=0.0.16'
38+
'EduData>=0.0.16',
39+
'PyBaize[torch]>=0.0.3'
3940
], # And any other dependencies foo needs
4041
entry_points={
4142
"console_scripts": [
4243
"edunlp = EduNLP.main:cli",
4344
],
4445
},
46+
classifiers=[
47+
'Programming Language :: Python :: 3.6',
48+
'Programming Language :: Python :: 3.7',
49+
'Programming Language :: Python :: 3.8',
50+
'Programming Language :: Python :: 3.9',
51+
"Environment :: Other Environment",
52+
"Intended Audience :: Developers",
53+
"License :: OSI Approved :: Apache License 2.0 (Apache 2.0)",
54+
"Operating System :: OS Independent",
55+
"Topic :: Software Development :: Libraries :: Python Modules",
56+
],
4557
)

tests/test_vec/test_vec.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# coding: utf-8
22
# 2021/5/30 @ tongshiwei
33

4+
import torch
45
import numpy as np
56
import pytest
67
from EduNLP.Pretrain import train_vector, GensimWordTokenizer
@@ -111,10 +112,22 @@ def test_rnn(stem_tokens, tmpdir):
111112
item = rnn.infer_vector(stem_tokens[:1])
112113
assert tokens.shape == (1, len(stem_tokens[0]), 20 * (2 if rnn.bidirectional else 1))
113114
assert item.shape == (1, rnn.vector_size)
115+
item_vec = rnn.infer_vector(stem_tokens[:1])
116+
assert torch.equal(item, item_vec)
114117

115118
t2v = T2V(rnn_type, w2v, 20)
116119
assert len(t2v(stem_tokens[:1])[0]) == t2v.vector_size
117120

121+
saved_params = rnn.save(str((tmpdir / method).join("stem_tf_rnn.params")), save_embedding=True)
122+
123+
rnn = RNNModel(rnn_type, w2v, 20, device="cpu", model_params=saved_params)
124+
rnn.train()
125+
assert rnn.is_frozen is False
126+
rnn.freeze()
127+
assert rnn.is_frozen is True
128+
item_vec1 = rnn.infer_vector(stem_tokens[:1])
129+
assert torch.equal(item, item_vec1)
130+
118131

119132
def test_d2v(stem_tokens, tmpdir, stem_data):
120133
method = "d2v"

0 commit comments

Comments
 (0)