You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Copy file name to clipboard. Expand all lines: ICLR2023/src/neo_train.py
+27 -70 Lines changed: 27 additions & 70 deletions
Original file line number
Diff line number
Diff line change
@@ -15,21 +15,25 @@
15
15
# limitations under the License.
16
16
"""
17
17
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
18
19
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
19
20
https://huggingface.co/models?filter=causal-lm
20
21
"""
21
22
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
22
23
24
+
"""This file is based on: https://github.com/huggingface/transformers/blob/1b5ce1e63b7bd4382cd1b4fdcca72d50f8b29494/examples/language-modeling/run_clm.py
25
+
26
+
There were only two lines changed, both have the comment # CHANGED: added
27
+
"""
28
+
23
29
import logging
24
30
import math
25
31
import os
26
-
27
32
import sys
28
33
from dataclasses import dataclass, field
29
34
from typing import Optional
30
-
from pathlib import Path
31
35
32
-
from datasets import load_dataset, Dataset
36
+
from datasets import load_dataset
33
37
34
38
import transformers
35
39
from transformers import (
@@ -73,36 +77,25 @@ class ModelArguments:
73
77
)
74
78
model_type: Optional[str] =field(
75
79
default=None,
76
-
metadata={
77
-
"help": "If training from scratch, pass a model type from the list: "
78
-
+", ".join(MODEL_TYPES)
79
-
},
80
+
metadata={"help": "If training from scratch, pass a model type from the list: "+", ".join(MODEL_TYPES)},
80
81
)
81
82
config_name: Optional[str] =field(
82
-
default=None,
83
-
metadata={"help": "Pretrained config name or path if not the same as model_name"},
83
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
84
84
)
85
85
tokenizer_name: Optional[str] =field(
86
-
default=None,
87
-
metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
86
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
88
87
)
89
88
cache_dir: Optional[str] =field(
90
89
default=None,
91
-
metadata={
92
-
"help": "Where do you want to store the pretrained models downloaded from huggingface.co"
93
-
},
90
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
94
91
)
95
92
use_fast_tokenizer: bool=field(
96
93
default=True,
97
-
metadata={
98
-
"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
99
-
},
94
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
100
95
)
101
96
model_revision: str=field(
102
97
default="main",
103
-
metadata={
104
-
"help": "The specific model version to use (can be a branch name, tag name or commit id)."
105
-
},
98
+
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
106
99
)
107
100
use_auth_token: bool=field(
108
101
default=False,
@@ -120,23 +113,15 @@ class DataTrainingArguments:
120
113
"""
121
114
122
115
dataset_name: Optional[str] =field(
123
-
default=None,
124
-
metadata={"help": "The name of the dataset to use (via the datasets library)."},
116
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
125
117
)
126
118
dataset_config_name: Optional[str] =field(
127
-
default=None,
128
-
metadata={
129
-
"help": "The configuration name of the dataset to use (via the datasets library)."
130
-
},
131
-
)
132
-
train_file: Optional[str] =field(
133
-
default=None, metadata={"help": "The input training data file (a text file)."}
119
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
134
120
)
121
+
train_file: Optional[str] =field(default=None, metadata={"help": "The input training data file (a text file)."})
135
122
validation_file: Optional[str] =field(
136
123
default=None,
137
-
metadata={
138
-
"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
139
-
},
124
+
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
140
125
)
141
126
max_train_samples: Optional[int] =field(
142
127
default=None,
@@ -181,18 +166,10 @@ def __post_init__(self):
181
166
else:
182
167
if self.train_file is not None:
183
168
extension=self.train_file.split(".")[-1]
184
-
assert extension in [
185
-
"csv",
186
-
"json",
187
-
"txt",
188
-
], "`train_file` should be a csv, a json or a txt file."
169
+
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
189
170
if self.validation_file is not None:
190
171
extension=self.validation_file.split(".")[-1]
191
-
assert extension in [
192
-
"csv",
193
-
"json",
194
-
"txt",
195
-
], "`validation_file` should be a csv, a json or a txt file."
172
+
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
0 commit comments