Skip to content

Commit f54e80d

Browse files
authored
[Function optimization] add unittest for downloading with file-lock (#4972)
* add unittest for file lock * add file
1 parent c436823 commit f54e80d

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

paddlenlp/utils/downloader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
139139

140140

141141
def get_path_from_url_with_filelock(
142-
url: str, root_dir: str, md5sum: Optional[str] = None, check_exist: bool = True
142+
url: str, root_dir: str, md5sum: Optional[str] = None, check_exist: bool = True, timeout: float = -1
143143
) -> str:
144144
"""construct `get_path_from_url` for `model_utils` to enable downloading multiprocess-safe
145145
@@ -148,6 +148,7 @@ def get_path_from_url_with_filelock(
148148
root_dir (str): the local download path
149149
md5sum (str, optional): md5sum string for file. Defaults to None.
150150
check_exist (bool, optional): whether check the file is exist. Defaults to True.
151+
timeout (int, optional): the timeout for downloading. Defaults to -1.
151152
152153
Returns:
153154
str: the path of downloaded file
@@ -163,7 +164,7 @@ def get_path_from_url_with_filelock(
163164

164165
os.makedirs(os.path.dirname(lock_file_path), exist_ok=True)
165166

166-
with FileLock(lock_file_path):
167+
with FileLock(lock_file_path, timeout=timeout):
167168
result = get_path_from_url(url=url, root_dir=root_dir, md5sum=md5sum, check_exist=check_exist)
168169
return result
169170

tests/utils/test_downloader.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import hashlib
18+
import os
19+
import unittest
20+
from tempfile import TemporaryDirectory
21+
22+
from paddlenlp.utils.downloader import get_path_from_url_with_filelock
23+
24+
25+
class LockFileTest(unittest.TestCase):
26+
test_url = (
27+
"https://bj.bcebos.com/paddlenlp/models/transformers/roformerv2/roformer_v2_chinese_char_small/vocab.txt"
28+
)
29+
30+
def test_downloading_with_exist_file(self):
31+
32+
with TemporaryDirectory() as tempdir:
33+
lock_file_name = hashlib.md5((self.test_url + tempdir).encode("utf-8")).hexdigest()
34+
lock_file_path = os.path.join(tempdir, ".lock", lock_file_name)
35+
os.makedirs(os.path.dirname(lock_file_path), exist_ok=True)
36+
37+
# create lock file
38+
with open(lock_file_path, "w", encoding="utf-8") as f:
39+
f.write("temp test")
40+
41+
# downloading with exist lock file
42+
config_file = get_path_from_url_with_filelock(self.test_url, root_dir=tempdir)
43+
self.assertIsNotNone(config_file)
44+
45+
def test_downloading_with_opened_exist_file(self):
46+
47+
with TemporaryDirectory() as tempdir:
48+
lock_file_name = hashlib.md5((self.test_url + tempdir).encode("utf-8")).hexdigest()
49+
lock_file_path = os.path.join(tempdir, ".lock", lock_file_name)
50+
os.makedirs(os.path.dirname(lock_file_path), exist_ok=True)
51+
52+
# create lock file
53+
with open(lock_file_path, "w", encoding="utf-8") as f:
54+
f.write("temp test")
55+
56+
# downloading with opened lock file
57+
open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC
58+
_ = os.open(lock_file_path, open_mode)
59+
config_file = get_path_from_url_with_filelock(self.test_url, root_dir=tempdir)
60+
self.assertIsNotNone(config_file)

0 commit comments

Comments
 (0)