Skip to content

Commit 06461fd

Browse files
[fix] update some codes for quesnet
1 parent 92776a6 commit 06461fd

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

EduNLP/Vector/quesnet/quesnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
class QuesNetModel(Vector):
9-
def __init__(self, pretrained_dir, img_dir=None, device="cpu", **kwargs):
9+
def __init__(self, pretrained_dir, device="cpu", **kwargs):
1010
"""
1111
Parameters
1212
----------
@@ -18,7 +18,7 @@ def __init__(self, pretrained_dir, img_dir=None, device="cpu", **kwargs):
1818
image dir
1919
"""
2020
self.device = torch.device(device)
21-
self.model = QuesNet.from_pretrained(pretrained_dir, img_dir=img_dir).to(self.device)
21+
self.model = QuesNet.from_pretrained(pretrained_dir).to(self.device)
2222
self.model.eval()
2323

2424
def __call__(self, items: dict):

examples/pretrain/quesnet.ipynb

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,15 @@
138138
"}\n",
139139
"\n",
140140
"# 当前仅支持linux下训练\n",
141-
"# pretrain_quesnet(os.path.join(os.path.abspath(data_dir), 'quesnet_data.json'),\n",
142-
"# output_dir, tokenizer, True, train_params)"
141+
"pretrain_quesnet(\n",
142+
" path=os.path.join(os.path.abspath(data_dir),'quesnet_data.json'),\n",
143+
" output_dir=output_dir,\n",
144+
" tokenizer=tokenizer,\n",
145+
" img_dir=None,\n",
146+
" save_embs=True,\n",
147+
" load_embs=False,\n",
148+
" train_params=train_params\n",
149+
")"
143150
]
144151
},
145152
{

tests/test_pretrain/test_pretrained_quesnet.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,11 @@ def test_tokenizer(self, standard_luna_data, pretrained_tokenizer_dir):
6868
def test_train_quesnet(self, standard_luna_data, pretrained_model_dir):
6969
test_items = [
7070
{'ques_content': '有公式$\\FormFigureID{wrong1?}$和公式$\\FormFigureBase64{wrong2?}$,\
71-
如图$\\FigureID{000004d6-0479-11ec-829b-797d5eb43535}$,\
72-
若$x,y$满足约束条件$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$'},
71+
如图$\\FigureID{000004d6-0479-11ec-829b-797d5eb43535}$,若$x,y$满足约束条件$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$'},
7372
{'ques_content': '如图$\\FigureID{000004d6-0479-11ec-829b-797d5eb43535}$, \
74-
若$x,y$满足约束条件$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$'}
73+
若$x,y$满足约束条件$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$',
74+
"ques_figure_ids": ["000004d6-0479-11ec-829b-797d5eb43535"],
75+
"ques_figure_paths": ["../../static/test_data/quesnet_img/000004d6-0479-11ec-829b-797d5eb43535.png"]}
7576
]
7677

7778
ques_file = path_append(abs_current_dir(__file__),
@@ -139,7 +140,9 @@ def test_quesnet_t2v(self, pretrained_model_dir):
139140
def test_quesnet_i2v(self, pretrained_model_dir):
140141
items = [
141142
{'ques_content': '如图$\\FigureID{000004d6-0479-11ec-829b-797d5eb43535}$, \
142-
若$x,y$满足约束条件$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$'}
143+
若$x,y$满足约束条件$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$',
144+
"ques_figure_ids": ["000004d6-0479-11ec-829b-797d5eb43535"],
145+
"ques_figure_paths": ["../../static/test_data/quesnet_img/000004d6-0479-11ec-829b-797d5eb43535.png"]}
143146
]
144147
img_dir = path_append(abs_current_dir(__file__),
145148
"../../static/test_data/quesnet_img", to_str=True)

0 commit comments

Comments
 (0)