Skip to content

Commit f68214c

Browse files
kinghuin authored and nepeplwu committed
restore the removed code (#235)
* restore the removed code * modify cv reader
1 parent 9ef5320 commit f68214c

File tree

15 files changed

+48
-20
lines changed

15 files changed

+48
-20
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
export FLAGS_eager_delete_tensor_gb=0.0
2+
export CUDA_VISIBLE_DEVICES=0
3+
24
python -u img_classifier.py $@
+2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
export FLAGS_eager_delete_tensor_gb=0.0
2+
export CUDA_VISIBLE_DEVICES=0
3+
24
python -u predict.py $@

demo/qa_classification/run_classifier.sh

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
export FLAGS_eager_delete_tensor_gb=0.0
22
export CUDA_VISIBLE_DEVICES=0
33

4-
54
CKPT_DIR="./ckpt_qa"
65
# Recommended hyperparameters for different tasks
76
# ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5

demo/reading-comprehension/reading_comprehension.py

-2
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,7 @@
8989

9090
# Set up the running config for the PaddleHub Finetune API
9191
config = hub.RunConfig(
92-
log_interval=10,
9392
eval_interval=300,
94-
save_ckpt_interval=10000,
9593
use_pyreader=args.use_pyreader,
9694
use_data_parallel=args.use_data_parallel,
9795
use_cuda=args.use_gpu,

demo/reading-comprehension/run_finetune.sh

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
export FLAGS_eager_delete_tensor_gb=0.0
2+
export CUDA_VISIBLE_DEVICES=0
23

34
# Recommended hyperparameters for different tasks
45
# squad: batch_size=8, weight_decay=0, num_epoch=3, max_seq_len=512, lr=5e-5

demo/reading-comprehension/run_predict.sh

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
export FLAGS_eager_delete_tensor_gb=0.0
2+
export CUDA_VISIBLE_DEVICES=0
23

34
CKPT_DIR="./ckpt_cmrc2018"
45
dataset=cmrc2018

demo/regression/run_predict.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export FLAGS_eager_delete_tensor_gb=0.0
2-
# export CUDA_VISIBLE_DEVICES=0
2+
export CUDA_VISIBLE_DEVICES=0
33

44
# Users can select chnsenticorp, nlpcc_dbqa, lcqmc, and so on for different tasks
55
DATASET="STS-B"

demo/sequence-labeling/predict.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,17 @@
4141

4242
if __name__ == '__main__':
4343
# loading Paddlehub ERNIE pretrained model
44-
module = hub.Module(name="ernie")
44+
module = hub.Module(name="ernie_tiny")
4545
inputs, outputs, program = module.context(max_seq_len=args.max_seq_len)
4646

4747
# Sentence labeling dataset reader
4848
dataset = hub.dataset.MSRA_NER()
4949
reader = hub.reader.SequenceLabelReader(
5050
dataset=dataset,
5151
vocab_path=module.get_vocab_path(),
52-
max_seq_len=args.max_seq_len)
52+
max_seq_len=args.max_seq_len,
53+
sp_model_path=module.get_spm_path(),
54+
word_dict_path=module.get_word_dict_path())
5355
inv_label_map = {val: key for key, val in reader.label_map.items()}
5456

5557
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()

demo/sequence-labeling/run_sequence_label.sh

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
export FLAGS_eager_delete_tensor_gb=0.0
2+
export CUDA_VISIBLE_DEVICES=0
23

34
CKPT_DIR="./ckpt_sequence_label"
45
python -u sequence_label.py \

demo/sequence-labeling/sequence_label.py

-3
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,6 @@
7171

7272
# Set up the running config for the PaddleHub Finetune API
7373
config = hub.RunConfig(
74-
log_interval=10,
75-
eval_interval=300,
76-
save_ckpt_interval=10000,
7774
use_data_parallel=args.use_data_parallel,
7875
use_pyreader=args.use_pyreader,
7976
use_cuda=args.use_gpu,

demo/text-classification/predict.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,35 @@
4545
# Download dataset and use ClassifyReader to read dataset
4646
if args.dataset.lower() == "chnsenticorp":
4747
dataset = hub.dataset.ChnSentiCorp()
48-
module = hub.Module(name="ernie")
48+
module = hub.Module(name="ernie_tiny")
49+
metrics_choices = ["acc"]
50+
elif args.dataset.lower() == "tnews":
51+
dataset = hub.dataset.TNews()
52+
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
4953
metrics_choices = ["acc"]
5054
elif args.dataset.lower() == "nlpcc_dbqa":
5155
dataset = hub.dataset.NLPCC_DBQA()
52-
module = hub.Module(name="ernie")
56+
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
5357
metrics_choices = ["acc"]
5458
elif args.dataset.lower() == "lcqmc":
5559
dataset = hub.dataset.LCQMC()
56-
module = hub.Module(name="ernie")
60+
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
61+
metrics_choices = ["acc"]
62+
elif args.dataset.lower() == 'inews':
63+
dataset = hub.dataset.INews()
64+
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
65+
metrics_choices = ["acc"]
66+
elif args.dataset.lower() == 'bq':
67+
dataset = hub.dataset.BQ()
68+
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
69+
metrics_choices = ["acc"]
70+
elif args.dataset.lower() == 'thucnews':
71+
dataset = hub.dataset.THUCNEWS()
72+
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
73+
metrics_choices = ["acc"]
74+
elif args.dataset.lower() == 'iflytek':
75+
dataset = hub.dataset.IFLYTEK()
76+
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
5777
metrics_choices = ["acc"]
5878
elif args.dataset.lower() == "mrpc":
5979
dataset = hub.dataset.GLUE("MRPC")
@@ -90,7 +110,7 @@
90110
metrics_choices = ["acc"]
91111
elif args.dataset.lower().startswith("xnli"):
92112
dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
93-
module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
113+
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
94114
metrics_choices = ["acc"]
95115
else:
96116
raise ValueError("%s dataset is not defined" % args.dataset)

demo/text-classification/run_classifier.sh

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
export FLAGS_eager_delete_tensor_gb=0.0
2+
export CUDA_VISIBLE_DEVICES=0
23

34
# Users can select chnsenticorp, nlpcc_dbqa, lcqmc, and so on for different tasks
45
DATASET="chnsenticorp"

demo/text-classification/run_predict.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ python -u predict.py --checkpoint_dir=$CKPT_DIR \
1717
--max_seq_len=128 \
1818
--use_gpu=True \
1919
--dataset=${DATASET} \
20-
--batch_size=150 \
20+
--batch_size=32 \

demo/text-classification/text_classifier.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
elif args.dataset.lower() == "tnews":
4848
dataset = hub.dataset.TNews()
4949
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
50-
metrics_choices = ["acc", "f1"]
50+
metrics_choices = ["acc"]
5151
elif args.dataset.lower() == "nlpcc_dbqa":
5252
dataset = hub.dataset.NLPCC_DBQA()
5353
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
@@ -59,19 +59,19 @@
5959
elif args.dataset.lower() == 'inews':
6060
dataset = hub.dataset.INews()
6161
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
62-
metrics_choices = ["acc", "f1"]
62+
metrics_choices = ["acc"]
6363
elif args.dataset.lower() == 'bq':
6464
dataset = hub.dataset.BQ()
6565
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
66-
metrics_choices = ["acc", "f1"]
66+
metrics_choices = ["acc"]
6767
elif args.dataset.lower() == 'thucnews':
6868
dataset = hub.dataset.THUCNEWS()
6969
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
70-
metrics_choices = ["acc", "f1"]
70+
metrics_choices = ["acc"]
7171
elif args.dataset.lower() == 'iflytek':
7272
dataset = hub.dataset.IFLYTEK()
7373
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
74-
metrics_choices = ["acc", "f1"]
74+
metrics_choices = ["acc"]
7575
elif args.dataset.lower() == "mrpc":
7676
dataset = hub.dataset.GLUE("MRPC")
7777
module = hub.Module(name="ernie_v2_eng_base")
@@ -97,7 +97,7 @@
9797
dataset = hub.dataset.GLUE("RTE")
9898
module = hub.Module(name="ernie_v2_eng_base")
9999
metrics_choices = ["acc"]
100-
elif args.dataset.lower() == "mnli" or args.dataset.lower() == "mnli":
100+
elif args.dataset.lower() == "mnli" or args.dataset.lower() == "mnli_m":
101101
dataset = hub.dataset.GLUE("MNLI_m")
102102
module = hub.Module(name="ernie_v2_eng_base")
103103
metrics_choices = ["acc"]

paddlehub/reader/cv_reader.py

+4
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self,
4949
self.data_augmentation = data_augmentation
5050
self.images_std = images_std
5151
self.images_mean = images_mean
52+
self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
5253

5354
if self.images_mean is None:
5455
try:
@@ -80,12 +81,15 @@ def data_generator(self,
8081
raise ValueError("The dataset is none and it's not allowed!")
8182
if phase == "train":
8283
data = self.dataset.train_data(shuffle)
84+
self.num_examples['train'] = len(self.get_train_examples())
8385
elif phase == "test":
8486
shuffle = False
8587
data = self.dataset.test_data(shuffle)
88+
self.num_examples['test'] = len(self.get_test_examples())
8689
elif phase == "val" or phase == "dev":
8790
shuffle = False
8891
data = self.dataset.validate_data(shuffle)
92+
self.num_examples['dev'] = len(self.get_dev_examples())
8993
elif phase == "predict":
9094
data = data
9195

0 commit comments

Comments
 (0)