Skip to content

Commit 2286799

Browse files
authored
[Safetensors] Fix fast safe open slice. (#8512)
* fix more case. * fix slice. * fix error slice.
1 parent fd2a39e commit 2286799

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

paddlenlp/utils/safetensors.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,16 +157,16 @@ def __getitem__(self, index):
157157

158158
out_start, out_stop, out_step = copy.deepcopy((self.start, self.stop, self.step))
159159
for i, (start, stop, step, slice_) in enumerate(zip(self.start, self.stop, self.step, index)):
160-
out_start[i] = slice_.start or 0
161-
out_step[i] = slice_.step or 1
162-
out_stop[i] = slice_.stop or stop - start
160+
out_start[i] = slice_.start if slice_.start is not None else 0
161+
out_step[i] = slice_.step if slice_.step is not None else 1
162+
out_stop[i] = slice_.stop if slice_.stop is not None else stop - start
163163
out_stop[i] = min(stop, out_stop[i])
164164

165165
target_shape = []
166-
for x, y, z in zip(out_start, out_stop, out_step):
166+
for x, y, z, sli in zip(out_start, out_stop, out_step, index):
167167
assert z == 1, "only support step = 1"
168-
if y - x > 1:
169-
target_shape.append(int(y - x))
168+
if y - x > 1 or sli.step is None:
169+
target_shape.append(max(int(y - x), 0))
170170

171171
if len(target_shape) == 0:
172172
if self.shape == [1]:

tests/transformers/test_safetensors.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,14 @@ class FastSafetensors(unittest.TestCase):
2828
def setUp(self):
2929
super().setUp()
3030
self.weigth_map = {}
31-
tensors = [([10, 10], "float32"), ([8], "float16"), ([5, 5, 5], "int32")]
31+
tensors = [
32+
([10, 1, 10], "float32"),
33+
([1, 1, 10], "float32"),
34+
([1, 1, 1, 10], "float32"),
35+
([10, 10], "float32"),
36+
([8], "float16"),
37+
([5, 5, 5], "int32"),
38+
]
3239
count = 0
3340
for shape, dtype in tensors:
3441
self.weigth_map[f"weight_{count}"] = (np.random.random(shape) * 100).astype(dtype)
@@ -53,5 +60,10 @@ def test_safe_open(self):
5360
with fast_safe_open(path, framework="np") as f:
5461
for key in f.keys():
5562
safe_slice = f.get_slice(key)
63+
# np.testing.assert_equal(self.weigth_map[key][2:1, ...], safe_slice[2:1, ...])
64+
np.testing.assert_equal(self.weigth_map[key][0, ...], safe_slice[0, ...])
65+
np.testing.assert_equal(self.weigth_map[key][0:1, ...], safe_slice[0:1, ...])
66+
np.testing.assert_equal(self.weigth_map[key][..., 2:], safe_slice[..., 2:])
67+
np.testing.assert_equal(self.weigth_map[key][..., 1], safe_slice[..., 1])
5668
np.testing.assert_equal(self.weigth_map[key][:2, ...], safe_slice[:2, ...])
5769
np.testing.assert_equal(self.weigth_map[key][..., :4], safe_slice[..., :4])

0 commit comments

Comments
 (0)