Skip to content

Commit 1a2502c

Browse files
authored
[Safetensors] Fix mmap for Windows system (#8734)
* fix mmap for Windows system
1 parent fbe613b commit 1a2502c

File tree

4 files changed

+30
-7
lines changed

4 files changed

+30
-7
lines changed

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818
import multiprocessing
1919
import os
20+
import sys
2021

2122
import numpy as np
2223
import paddle
@@ -67,10 +68,12 @@
6768
from paddlenlp.utils.tools import get_env_device
6869

6970
if is_safetensors_available():
70-
# from safetensors import safe_open
7171
from safetensors.numpy import save_file as safe_save_file
7272

73-
from paddlenlp.utils.safetensors import fast_safe_open as safe_open
73+
if sys.platform.startswith("win"):
74+
from safetensors import safe_open
75+
else:
76+
from paddlenlp.utils.safetensors import fast_safe_open as safe_open
7477

7578
FP32_MASTER = "fp32_master_0"
7679
optimizer_scalar_name = [

paddlenlp/transformers/model_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import json
2121
import os
2222
import re
23+
import sys
2324
import tempfile
2425
import warnings
2526
from contextlib import contextmanager
@@ -108,12 +109,13 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):
108109

109110

110111
if is_safetensors_available():
111-
112-
# from safetensors import safe_open
113112
from safetensors.numpy import load_file as safe_load_file
114113
from safetensors.numpy import save_file as safe_save_file
115114

116-
from paddlenlp.utils.safetensors import fast_safe_open as safe_open
115+
if sys.platform.startswith("win"):
116+
from safetensors import safe_open
117+
else:
118+
from paddlenlp.utils.safetensors import fast_safe_open as safe_open
117119

118120

119121
def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:

tests/testing_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,22 @@ def decorator(func):
296296
return decorator
297297

298298

299+
def skip_platform(*platform):
300+
"""decorator which can detect that it will skip the specific platform
301+
302+
Args:
303+
platform (str): the name of platform, including win32, cygwin, linux, and darwin
304+
"""
305+
306+
def decorator(func):
307+
for plat in platform:
308+
if sys.platform.startswith(plat):
309+
return unittest.skip(f"platform<{plat}> matched, so to skip this test")(func)
310+
return func
311+
312+
return decorator
313+
314+
299315
def is_slow_test() -> bool:
300316
"""check whether is the slow test
301317

tests/transformers/test_safetensors.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
import unittest
1818

1919
import numpy as np
20-
21-
# from safetensors import safe_open
2220
from safetensors.numpy import load_file, save_file
2321

2422
from paddlenlp.utils.safetensors import fast_load_file, fast_safe_open
2523

24+
from ..testing_utils import skip_platform
25+
2626

2727
class FastSafetensors(unittest.TestCase):
2828
def setUp(self):
@@ -42,6 +42,7 @@ def setUp(self):
4242
count += 1
4343
print(self.weigth_map)
4444

45+
@skip_platform("win32", "cygwin")
4546
def test_load_file(self):
4647
with tempfile.TemporaryDirectory() as tmpdirname:
4748
path = os.path.join(tmpdirname, "test.safetensors")
@@ -52,6 +53,7 @@ def test_load_file(self):
5253
np.testing.assert_equal(v, sf_load[k])
5354
np.testing.assert_equal(v, fs_sf_load[k])
5455

56+
@skip_platform("win32", "cygwin")
5557
def test_safe_open(self):
5658
with tempfile.TemporaryDirectory() as tmpdirname:
5759
path = os.path.join(tmpdirname, "test.safetensors")

0 commit comments

Comments
 (0)