|
45 | 45 | # Download dataset and use ClassifyReader to read dataset
|
46 | 46 | if args.dataset.lower() == "chnsenticorp":
|
47 | 47 | dataset = hub.dataset.ChnSentiCorp()
|
48 |
| - module = hub.Module(name="ernie") |
| 48 | + module = hub.Module(name="ernie_tiny") |
| 49 | + metrics_choices = ["acc"] |
| 50 | + elif args.dataset.lower() == "tnews": |
| 51 | + dataset = hub.dataset.TNews() |
| 52 | + module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16") |
49 | 53 | metrics_choices = ["acc"]
|
50 | 54 | elif args.dataset.lower() == "nlpcc_dbqa":
|
51 | 55 | dataset = hub.dataset.NLPCC_DBQA()
|
52 |
| - module = hub.Module(name="ernie") |
| 56 | + module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16") |
53 | 57 | metrics_choices = ["acc"]
|
54 | 58 | elif args.dataset.lower() == "lcqmc":
|
55 | 59 | dataset = hub.dataset.LCQMC()
|
56 |
| - module = hub.Module(name="ernie") |
| 60 | + module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16") |
| 61 | + metrics_choices = ["acc"] |
| 62 | + elif args.dataset.lower() == 'inews': |
| 63 | + dataset = hub.dataset.INews() |
| 64 | + module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16") |
| 65 | + metrics_choices = ["acc"] |
| 66 | + elif args.dataset.lower() == 'bq': |
| 67 | + dataset = hub.dataset.BQ() |
| 68 | + module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16") |
| 69 | + metrics_choices = ["acc"] |
| 70 | + elif args.dataset.lower() == 'thucnews': |
| 71 | + dataset = hub.dataset.THUCNEWS() |
| 72 | + module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16") |
| 73 | + metrics_choices = ["acc"] |
| 74 | + elif args.dataset.lower() == 'iflytek': |
| 75 | + dataset = hub.dataset.IFLYTEK() |
| 76 | + module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16") |
57 | 77 | metrics_choices = ["acc"]
|
58 | 78 | elif args.dataset.lower() == "mrpc":
|
59 | 79 | dataset = hub.dataset.GLUE("MRPC")
|
|
90 | 110 | metrics_choices = ["acc"]
|
91 | 111 | elif args.dataset.lower().startswith("xnli"):
|
92 | 112 | dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
|
93 |
| - module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12") |
| 113 | + module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16") |
94 | 114 | metrics_choices = ["acc"]
|
95 | 115 | else:
|
96 | 116 | raise ValueError("%s dataset is not defined" % args.dataset)
|
|
0 commit comments