Commit f1c697d

Merge pull request #608 from GuoxiaWang/replace_url
Simple implementation of file download and caching tools
2 parents 46fe570 + 4fb839b

4 files changed: +134 −9 lines
fleetx/data/data_tools/gpt/README.md (+3 −3)
````diff
@@ -40,7 +40,7 @@
 First, download the sample data:
 ```
 mkdir data && cd data
-wget https://bj.bcebos.com/paddlefleetx/models/transformers/data_tools/baike.txt
+wget https://fleet.bj.bcebos.com/datasets/gpt/wikitext-103-en.txt
 cd ..
 ```
 ### Converting the raw data to jsonl format
@@ -96,7 +96,7 @@ optional arguments:
   -h, --help            show this help message and exit
   --model_name MODEL_NAME
                         What model to use.
-                        Required; e.g. ernie-1.0-base-zh. See the available model names at https://paddlenlp.readthedocs.io/zh/latest/model_zoo/index.html#transformer
+                        Required; e.g. gpt2
   --tokenizer_name {ErnieTokenizer,BertTokenizer,GPTTokenizer,GPTChineseTokenizer}
                         What type of tokenizer to use.
                         The tokenizer matching the model; currently only Ernie, Bert, and GPT are supported.
@@ -142,7 +142,7 @@ common config:
                         Number of processes used to tokenize the text into ids.
 ```
 Running the script below produces the processed pretraining data: token ids in `wikitext_103_en.npy` and article index information in `wikitext_103_en.npz`.
-When using `GPTTokenizer`, the files `gpt2-vocab.json` and `gpt2-merges.txt` are required; if they have not been downloaded and cached before, the script downloads and caches them automatically. If you run into network problems, you can download the two files yourself and place them under `~/.cache/cached_path/`.
+When using `GPTTokenizer`, the files `gpt2-vocab.json` and `gpt2-merges.txt` are required; if they have not been downloaded and cached before, the script downloads and caches them automatically. If you run into network problems, you can download the two files yourself and place them under `~/.cache/fleetx/`.
 ```
 python -u preprocess_data.py \
 --model_name gpt2 \
````

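If the automatic download fails, the README's manual fallback can be scripted. Below is a minimal sketch, assuming the vocab/merges URLs from the archive maps in `gpt_tokenizer.py` (shown in the next file) and the new default cache directory `~/.cache/fleetx/`:

```python
# Minimal sketch: pre-seed the fleetx cache so GPTTokenizer finds the files
# locally and skips the download. URLs are taken from the archive maps in
# gpt_tokenizer.py below; the cache directory is the new default.
import os
import requests

CACHE_DIR = os.path.expanduser("~/.cache/fleetx/")
URLS = [
    "https://fleet.bj.bcebos.com/datasets/gpt/gpt2-vocab.json",
    "https://fleet.bj.bcebos.com/datasets/gpt/gpt2-merges.txt",
]

os.makedirs(CACHE_DIR, exist_ok=True)
for url in URLS:
    target = os.path.join(CACHE_DIR, os.path.basename(url))
    if not os.path.exists(target):
        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
        with open(target, "wb") as f:
            f.write(resp.content)
```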
fleetx/data/tokenizers/gpt_tokenizer.py (+4 −5)
```diff
@@ -25,6 +25,8 @@
 import regex as re
 from io import open

+from fleetx.utils.download import cached_path
+
 try:
     from functools import lru_cache
 except ImportError:
@@ -38,12 +40,10 @@ def lru_cache():
 logger = logging.getLogger(__name__)

 PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'gpt2':
-    "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+    'gpt2': "https://fleet.bj.bcebos.com/datasets/gpt/gpt2-vocab.json",
 }
 PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'gpt2':
-    "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+    'gpt2': "https://fleet.bj.bcebos.com/datasets/gpt/gpt2-merges.txt",
 }
 PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {'gpt2': 1024, }
 VOCAB_NAME = 'vocab.json'
@@ -124,7 +124,6 @@ def from_pretrained(cls,
                 special_tokens_file))
         # redirect to the cache, if necessary
         try:
-            from cached_path import cached_path
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
             resolved_merges_file = cached_path(
                 merges_file, cache_dir=cache_dir)
```

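The substance of this change: `cached_path` now comes from the in-repo `fleetx.utils.download` module (added below) rather than the third-party `cached_path` package, so the vocab and merges files resolve against the Baidu BOS mirror. A minimal usage sketch, assuming the package layout from this commit:

```python
# Sketch: resolving the GPT-2 vocab/merges files through the new helper.
# The first call downloads into ~/.cache/fleetx/ and returns the local
# path; later calls find the cached copy and never touch the network.
from fleetx.utils.download import cached_path

vocab_file = cached_path(
    "https://fleet.bj.bcebos.com/datasets/gpt/gpt2-vocab.json")
merges_file = cached_path(
    "https://fleet.bj.bcebos.com/datasets/gpt/gpt2-merges.txt")
print(vocab_file, merges_file)
```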
fleetx/utils/download.py (+126, new file)
```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import requests
import shutil
from fleetx.utils import logger
from tqdm import tqdm
import paddle

DOWNLOAD_RETRY_LIMIT = 3


def is_url(path):
    """
    Whether path is URL.
    Args:
        path (string): URL string or not.
    """
    return path.startswith('http://') or path.startswith('https://')


def _map_path(url, root_dir):
    # Map a URL to its download destination under root_dir.
    fname = os.path.split(url)[-1]
    fpath = fname
    return os.path.join(root_dir, fpath)


def cached_path(url_or_path, cache_dir=None):
    if cache_dir is None:
        cache_dir = '~/.cache/fleetx/'

    cache_dir = os.path.expanduser(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    if is_url(url_or_path):
        path = _map_path(url_or_path, cache_dir)
        url = url_or_path
    else:
        path = url_or_path
        url = None

    if os.path.exists(path):
        logger.info(
            f"Found {os.path.split(path)[-1]} in cache_dir: {cache_dir}.")
        return path

    download(url, path)
    return path


def _download(url, fullname):
    """
    Download from url, save to fullname.
    url (str): download url
    fullname (str): local path to save the downloaded file to
    """
    retry_cnt = 0

    while not os.path.exists(fullname):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        logger.info("Downloading {}".format(url))

        try:
            req = requests.get(url, stream=True)
        except Exception as e:  # requests.exceptions.ConnectionError
            logger.info("Downloading {} failed {} times with exception {}".
                        format(url, retry_cnt, str(e)))
            time.sleep(1)
            continue

        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # To protect against interrupted downloads, write to tmp_fullname
        # first, then move tmp_fullname to fullname once the download
        # has finished.
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                    for chunk in req.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(1)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)

    return fullname


def download(url, path):
    # In distributed runs, only rank 0 downloads; the other ranks wait
    # until the file appears on disk.
    local_rank = 0
    world_size = 1
    if paddle.fluid.core.is_compiled_with_dist():
        local_rank = paddle.distributed.ParallelEnv().dev_id
        world_size = paddle.distributed.get_world_size()
    if world_size > 1 and local_rank != 0:
        while not os.path.exists(path):
            time.sleep(1)
    else:
        _download(url, path)
```

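To make the caching contract concrete, here is a small sketch of how the helpers compose (using the real merges URL from this commit; the local-path branch needs no network access). The `download` wrapper exists so multi-GPU runs fetch each file once: rank 0 downloads while the other ranks poll the filesystem until the file appears.

```python
# Sketch of the cached_path contract under this file's defaults.
from fleetx.utils.download import cached_path, is_url, _map_path

# A URL maps to <cache_dir>/<basename>; it is downloaded only on a cache miss.
url = "https://fleet.bj.bcebos.com/datasets/gpt/gpt2-merges.txt"
assert is_url(url)
assert _map_path(url, "/tmp/fleetx") == "/tmp/fleetx/gpt2-merges.txt"

# An existing local path is returned as-is, without touching the network.
assert cached_path(__file__) == __file__

# Caveat: a local path that does not exist falls through to
# download(None, path), so callers should pass either a URL or a
# path that already exists.
```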
requirements.txt (+1 −1)
```diff
@@ -1,6 +1,6 @@
 regex
 colorlog
 colorama
-cached_path >= 1.1.5
 inspect
 omegaconf
+tqdm
```
