@@ -28,7 +28,14 @@ class FastSafetensors(unittest.TestCase):
28
28
def setUp (self ):
29
29
super ().setUp ()
30
30
self .weigth_map = {}
31
- tensors = [([10 , 10 ], "float32" ), ([8 ], "float16" ), ([5 , 5 , 5 ], "int32" )]
31
+ tensors = [
32
+ ([10 , 1 , 10 ], "float32" ),
33
+ ([1 , 1 , 10 ], "float32" ),
34
+ ([1 , 1 , 1 , 10 ], "float32" ),
35
+ ([10 , 10 ], "float32" ),
36
+ ([8 ], "float16" ),
37
+ ([5 , 5 , 5 ], "int32" ),
38
+ ]
32
39
count = 0
33
40
for shape , dtype in tensors :
34
41
self .weigth_map [f"weight_{ count } " ] = (np .random .random (shape ) * 100 ).astype (dtype )
@@ -53,5 +60,10 @@ def test_safe_open(self):
53
60
with fast_safe_open (path , framework = "np" ) as f :
54
61
for key in f .keys ():
55
62
safe_slice = f .get_slice (key )
63
+ # np.testing.assert_equal(self.weigth_map[key][2:1, ...], safe_slice[2:1, ...])
64
+ np .testing .assert_equal (self .weigth_map [key ][0 , ...], safe_slice [0 , ...])
65
+ np .testing .assert_equal (self .weigth_map [key ][0 :1 , ...], safe_slice [0 :1 , ...])
66
+ np .testing .assert_equal (self .weigth_map [key ][..., 2 :], safe_slice [..., 2 :])
67
+ np .testing .assert_equal (self .weigth_map [key ][..., 1 ], safe_slice [..., 1 ])
56
68
np .testing .assert_equal (self .weigth_map [key ][:2 , ...], safe_slice [:2 , ...])
57
69
np .testing .assert_equal (self .weigth_map [key ][..., :4 ], safe_slice [..., :4 ])
0 commit comments