
Commit 6e75a25

clarify where neo_train.py came from
1 parent: 3abb3ee

File tree

1 file changed

+27 -70 lines changed


ICLR2023/src/neo_train.py

Lines changed: 27 additions & 70 deletions
@@ -15,21 +15,25 @@
 # limitations under the License.
 """
 Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
+
 Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
 https://huggingface.co/models?filter=causal-lm
 """
 # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
 
+"""This file is based on: https://github.com/huggingface/transformers/blob/1b5ce1e63b7bd4382cd1b4fdcca72d50f8b29494/examples/language-modeling/run_clm.py
+
+There were only two lines changed, both have the comment # CHANGED: added
+"""
+
 import logging
 import math
 import os
-
 import sys
 from dataclasses import dataclass, field
 from typing import Optional
-from pathlib import Path
 
-from datasets import load_dataset, Dataset
+from datasets import load_dataset
 
 import transformers
 from transformers import (
@@ -73,36 +77,25 @@ class ModelArguments:
     )
     model_type: Optional[str] = field(
         default=None,
-        metadata={
-            "help": "If training from scratch, pass a model type from the list: "
-            + ", ".join(MODEL_TYPES)
-        },
+        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
     )
     config_name: Optional[str] = field(
-        default=None,
-        metadata={"help": "Pretrained config name or path if not the same as model_name"},
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
     tokenizer_name: Optional[str] = field(
-        default=None,
-        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )
     cache_dir: Optional[str] = field(
         default=None,
-        metadata={
-            "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
-        },
+        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
     )
     use_fast_tokenizer: bool = field(
         default=True,
-        metadata={
-            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
-        },
+        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
     )
     model_revision: str = field(
         default="main",
-        metadata={
-            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
-        },
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
     )
     use_auth_token: bool = field(
         default=False,
@@ -120,23 +113,15 @@ class DataTrainingArguments:
     """
 
     dataset_name: Optional[str] = field(
-        default=None,
-        metadata={"help": "The name of the dataset to use (via the datasets library)."},
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
     dataset_config_name: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "The configuration name of the dataset to use (via the datasets library)."
-        },
-    )
-    train_file: Optional[str] = field(
-        default=None, metadata={"help": "The input training data file (a text file)."}
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
+    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
-        metadata={
-            "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
-        },
+        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
     )
     max_train_samples: Optional[int] = field(
         default=None,
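
For context, these field(metadata={"help": ...}) declarations feed transformers' HfArgumentParser, which turns every dataclass field into a command-line flag and uses the metadata as its help text. A minimal runnable sketch of that pattern; ToyArguments and the example flag value are hypothetical, not part of this commit:

# Hypothetical, trimmed-down version of the dataclass-to-CLI pattern above.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ToyArguments:
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )


parser = HfArgumentParser(ToyArguments)
# Parse an explicit argv list instead of sys.argv, just for the demo.
(args,) = parser.parse_args_into_dataclasses(args=["--train_file", "data.txt"])
print(args.train_file)  # -> data.txt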
@@ -181,18 +166,10 @@ def __post_init__(self):
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
-                assert extension in [
-                    "csv",
-                    "json",
-                    "txt",
-                ], "`train_file` should be a csv, a json or a txt file."
+                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
-                assert extension in [
-                    "csv",
-                    "json",
-                    "txt",
-                ], "`validation_file` should be a csv, a json or a txt file."
+                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
 
 
 def main():
@@ -204,19 +181,13 @@ def main():
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
-        model_args, data_args, training_args = parser.parse_json_file(
-            json_file=os.path.abspath(sys.argv[1])
-        )
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
     # Detecting last checkpoint.
     last_checkpoint = None
-    if (
-        os.path.isdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
             raise ValueError(
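
This is the standard HfArgumentParser convention: a single trailing .json argument is parsed with parse_json_file, anything else with parse_args_into_dataclasses. A hedged sketch of the JSON branch in isolation; the file name and its contents are made up:

import json

from transformers import HfArgumentParser, TrainingArguments

# Write a hypothetical arguments file, then parse it the same way the script does.
with open("args.json", "w") as f:
    json.dump({"output_dir": "outputs", "do_train": True}, f)

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_json_file(json_file="args.json")
print(training_args.output_dir)  # -> outputs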
@@ -288,8 +259,6 @@ def main():
         )
         if extension == "txt":
             extension = "text"
-        print(extension)
-        print(data_files)
         datasets = load_dataset(extension, data_files=data_files)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
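
For context, load_dataset with a generic builder name ("text", "csv", or "json") plus data_files is all the script needs to read local files. A small sketch with made-up file names:

from datasets import load_dataset

# Made-up local files; the "text" builder yields one {"text": ...} record per line.
data_files = {"train": "train.txt", "validation": "valid.txt"}
datasets = load_dataset("text", data_files=data_files)
print(datasets["train"][0])  # e.g. {"text": "first line of train.txt"}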
@@ -313,12 +282,8 @@ def main():
         config = CONFIG_MAPPING[model_args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")
 
-    # Things that were changed from the huggingface file
-
-    config.gradient_checkpointing = True
-    config.use_cache = False
-
-    #
+    config.gradient_checkpointing = True  # CHANGED: added
+    config.use_cache = False  # CHANGED: added
 
     tokenizer_kwargs = {
         "cache_dir": model_args.cache_dir,
@@ -445,9 +410,7 @@ def group_texts(examples):
     if training_args.do_train:
         if last_checkpoint is not None:
             checkpoint = last_checkpoint
-        elif model_args.model_name_or_path is not None and os.path.isdir(
-            model_args.model_name_or_path
-        ):
+        elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
             checkpoint = model_args.model_name_or_path
         else:
             checkpoint = None
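
For reference, get_last_checkpoint (imported from transformers.trainer_utils earlier in the script) is what feeds last_checkpoint above: it returns the newest checkpoint-<step> subdirectory of the output directory, or None. A small sketch with a made-up path:

import os

from transformers.trainer_utils import get_last_checkpoint

output_dir = "./outputs"  # made-up path
last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
print(last_checkpoint)  # e.g. "./outputs/checkpoint-500", or None on a fresh run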
@@ -457,9 +420,7 @@ def group_texts(examples):
         metrics = train_result.metrics
 
         max_train_samples = (
-            data_args.max_train_samples
-            if data_args.max_train_samples is not None
-            else len(train_dataset)
+            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
         )
         metrics["train_samples"] = min(max_train_samples, len(train_dataset))
 
@@ -473,11 +434,7 @@ def group_texts(examples):
 
         metrics = trainer.evaluate()
 
-        max_val_samples = (
-            data_args.max_val_samples
-            if data_args.max_val_samples is not None
-            else len(eval_dataset)
-        )
+        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
         perplexity = math.exp(metrics["eval_loss"])
         metrics["perplexity"] = perplexity
@@ -492,4 +449,4 @@ def _mp_fn(index):
 
 
 if __name__ == "__main__":
-    main()
+    main()
