Skip to content

Commit 1322ec5

Browse files
authored
Simplified Testing Interface (huggingface#289)
1 parent 48e9818 commit 1322ec5

File tree

41 files changed

+326
-1009
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+326
-1009
lines changed

.github/workflows/test-models.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ jobs:
3232
suite: gpu
3333
- os: MacStudio
3434
suite: cpu
35-
- os: MacStudio
36-
suite: vulkan
3735

3836
runs-on: ${{ matrix.os }}
3937

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ transformers==4.18.0
1414
pytest
1515
pytest-xdist
1616
Pillow
17+
parameterized

tank/MiniLM-L12-H384-uncased/MiniLM-L12-H384-uncased_test.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from shark.shark_inference import SharkInference
33
from shark.shark_downloader import download_tf_model
44
from shark.parser import shark_args
5+
from tank.test_utils import get_valid_test_params, shark_test_name_func
6+
from parameterized import parameterized
57

68
import iree.compiler as ireec
79
import unittest
@@ -67,25 +69,10 @@ def configure(self, pytestconfig):
6769
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
6870
self.module_tester.onnx_bench = pytestconfig.getoption("onnx_bench")
6971

70-
def test_module_static_cpu(self):
71-
dynamic = False
72-
device = "cpu"
73-
self.module_tester.create_and_check_module(dynamic, device)
74-
75-
@pytest.mark.skipif(
76-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
77-
)
78-
def test_module_static_gpu(self):
79-
dynamic = False
80-
device = "gpu"
81-
self.module_tester.create_and_check_module(dynamic, device)
72+
param_list = get_valid_test_params()
8273

83-
@pytest.mark.skipif(
84-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
85-
)
86-
def test_module_static_vulkan(self):
87-
dynamic = False
88-
device = "vulkan"
74+
@parameterized.expand(param_list, name_func=shark_test_name_func)
75+
def test_module(self, dynamic, device):
8976
self.module_tester.create_and_check_module(dynamic, device)
9077

9178

tank/MiniLM-L12-H384-uncased_torch/MiniLM-L12-H384-uncased_torch_test.py

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from tank.model_utils import compare_tensors
44
from shark.shark_downloader import download_torch_model
55
from shark.parser import shark_args
6+
from tank.test_utils import get_valid_test_params, shark_test_name_func
7+
from parameterized import parameterized
68

79
import unittest
810
import numpy as np
@@ -59,46 +61,10 @@ def configure(self, pytestconfig):
5961
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
6062
self.module_tester.onnx_bench = pytestconfig.getoption("onnx_bench")
6163

62-
def test_module_static_cpu(self):
63-
dynamic = False
64-
device = "cpu"
65-
self.module_tester.create_and_check_module(dynamic, device)
66-
67-
def test_module_dynamic_cpu(self):
68-
dynamic = True
69-
device = "cpu"
70-
self.module_tester.create_and_check_module(dynamic, device)
71-
72-
@pytest.mark.skipif(
73-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
74-
)
75-
def test_module_static_gpu(self):
76-
dynamic = False
77-
device = "gpu"
78-
self.module_tester.create_and_check_module(dynamic, device)
79-
80-
@pytest.mark.skipif(
81-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
82-
)
83-
def test_module_dynamic_gpu(self):
84-
dynamic = True
85-
device = "gpu"
86-
self.module_tester.create_and_check_module(dynamic, device)
87-
88-
@pytest.mark.skipif(
89-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
90-
)
91-
def test_module_static_vulkan(self):
92-
dynamic = False
93-
device = "vulkan"
94-
self.module_tester.create_and_check_module(dynamic, device)
64+
param_list = get_valid_test_params()
9565

96-
@pytest.mark.skipif(
97-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
98-
)
99-
def test_module_dynamic_vulkan(self):
100-
dynamic = True
101-
device = "vulkan"
66+
@parameterized.expand(param_list, name_func=shark_test_name_func)
67+
def test_module(self, dynamic, device):
10268
self.module_tester.create_and_check_module(dynamic, device)
10369

10470

tank/albert-base-v2_tf/albert-base-v2_tf_test.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from shark.iree_utils._common import check_device_drivers, device_driver_info
22
from shark.shark_inference import SharkInference
33
from shark.shark_downloader import download_tf_model
4+
from tank.test_utils import get_valid_test_params, shark_test_name_func
5+
from parameterized import parameterized
46

57
import iree.compiler as ireec
68
import unittest
@@ -34,34 +36,10 @@ def configure(self, pytestconfig):
3436
self.module_tester = AlbertBaseModuleTester(self)
3537
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
3638

37-
def test_module_static_cpu(self):
38-
dynamic = False
39-
device = "cpu"
40-
self.module_tester.create_and_check_module(dynamic, device)
41-
42-
@pytest.mark.skipif(
43-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
44-
)
45-
def test_module_static_gpu(self):
46-
dynamic = False
47-
device = "gpu"
48-
self.module_tester.create_and_check_module(dynamic, device)
49-
50-
@pytest.mark.skipif(
51-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
52-
)
53-
def test_module_static_vulkan(self):
54-
dynamic = False
55-
device = "vulkan"
56-
self.module_tester.create_and_check_module(dynamic, device)
39+
param_list = get_valid_test_params()
5740

58-
@pytest.mark.skipif(
59-
check_device_drivers("intel-gpu"),
60-
reason=device_driver_info("intel-gpu"),
61-
)
62-
def test_module_static_intel_gpu(self):
63-
dynamic = False
64-
device = "intel-gpu"
41+
@parameterized.expand(param_list, name_func=shark_test_name_func)
42+
def test_module(self, dynamic, device):
6543
self.module_tester.create_and_check_module(dynamic, device)
6644

6745

tank/albert-base-v2_torch/albert-base-v2_torch_test.py

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from shark.iree_utils._common import check_device_drivers, device_driver_info
33
from tank.model_utils import compare_tensors
44
from shark.shark_downloader import download_torch_model
5+
from tank.test_utils import get_valid_test_params, shark_test_name_func
6+
from parameterized import parameterized
57

68
import unittest
79
import numpy as np
@@ -57,46 +59,10 @@ def configure(self, pytestconfig):
5759
self.module_tester = AlbertModuleTester(self)
5860
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
5961

60-
def test_module_static_cpu(self):
61-
dynamic = False
62-
device = "cpu"
63-
self.module_tester.create_and_check_module(dynamic, device)
64-
65-
def test_module_dynamic_cpu(self):
66-
dynamic = True
67-
device = "cpu"
68-
self.module_tester.create_and_check_module(dynamic, device)
69-
70-
@pytest.mark.skipif(
71-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
72-
)
73-
def test_module_static_gpu(self):
74-
dynamic = False
75-
device = "gpu"
76-
self.module_tester.create_and_check_module(dynamic, device)
77-
78-
@pytest.mark.skipif(
79-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
80-
)
81-
def test_module_dynamic_gpu(self):
82-
dynamic = True
83-
device = "gpu"
84-
self.module_tester.create_and_check_module(dynamic, device)
85-
86-
@pytest.mark.skipif(
87-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
88-
)
89-
def test_module_static_vulkan(self):
90-
dynamic = False
91-
device = "vulkan"
92-
self.module_tester.create_and_check_module(dynamic, device)
62+
param_list = get_valid_test_params()
9363

94-
@pytest.mark.skipif(
95-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
96-
)
97-
def test_module_dynamic_vulkan(self):
98-
dynamic = True
99-
device = "vulkan"
64+
@parameterized.expand(param_list, name_func=shark_test_name_func)
65+
def test_module(self, dynamic, device):
10066
self.module_tester.create_and_check_module(dynamic, device)
10167

10268

tank/alexnet_torch/alexnet_torch_test.py

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from shark.shark_inference import SharkInference
22
from shark.iree_utils._common import check_device_drivers, device_driver_info
33
from tank.model_utils import compare_tensors
4+
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
45
from shark.shark_downloader import download_torch_model
6+
from tank.test_utils import get_valid_test_params, shark_test_name_func
57

8+
from parameterized import parameterized
69
import unittest
710
import numpy as np
811
import pytest
@@ -57,49 +60,16 @@ def configure(self, pytestconfig):
5760
self.module_tester = AlexnetModuleTester(self)
5861
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
5962

60-
def test_module_static_cpu(self):
61-
dynamic = False
62-
device = "cpu"
63-
self.module_tester.create_and_check_module(dynamic, device)
64-
65-
def test_module_dynamic_cpu(self):
66-
dynamic = True
67-
device = "cpu"
68-
self.module_tester.create_and_check_module(dynamic, device)
69-
70-
@pytest.mark.skipif(
71-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
72-
)
73-
def test_module_static_gpu(self):
74-
dynamic = False
75-
device = "gpu"
76-
self.module_tester.create_and_check_module(dynamic, device)
77-
78-
@pytest.mark.skipif(
79-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
80-
)
81-
def test_module_dynamic_gpu(self):
82-
dynamic = True
83-
device = "gpu"
84-
self.module_tester.create_and_check_module(dynamic, device)
85-
86-
@pytest.mark.skipif(
87-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
88-
)
89-
@pytest.mark.xfail(
90-
reason="Issue known, WIP",
91-
)
92-
def test_module_static_vulkan(self):
93-
dynamic = False
94-
device = "vulkan"
95-
self.module_tester.create_and_check_module(dynamic, device)
63+
param_list = get_valid_test_params()
9664

97-
@pytest.mark.skipif(
98-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
99-
)
100-
def test_module_dynamic_vulkan(self):
101-
dynamic = True
102-
device = "vulkan"
65+
@parameterized.expand(param_list, name_func=shark_test_name_func)
66+
def test_module(self, dynamic, device):
67+
if device in ["metal", "vulkan"]:
68+
if dynamic == False:
69+
if "m1-moltenvk-macos" in get_vulkan_triple_flag():
70+
pytest.xfail(
71+
reason="Assert Error:https://github.com/iree-org/iree/issues/10075"
72+
)
10373
self.module_tester.create_and_check_module(dynamic, device)
10474

10575

tank/bert-base-cased_torch/bert-base-cased_torch_test.py

Lines changed: 10 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from shark.iree_utils._common import check_device_drivers, device_driver_info
33
from tank.model_utils import compare_tensors
44
from shark.shark_downloader import download_torch_model
5+
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
6+
from tank.test_utils import get_valid_test_params, shark_test_name_func
7+
from parameterized import parameterized
58

69
import torch
710
import unittest
@@ -62,46 +65,14 @@ def configure(self, pytestconfig):
6265
self.module_tester = BertBaseUncasedModuleTester(self)
6366
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
6467

65-
def test_module_static_cpu(self):
66-
dynamic = False
67-
device = "cpu"
68-
self.module_tester.create_and_check_module(dynamic, device)
69-
70-
def test_module_dynamic_cpu(self):
71-
dynamic = True
72-
device = "cpu"
73-
self.module_tester.create_and_check_module(dynamic, device)
74-
75-
@pytest.mark.skipif(
76-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
77-
)
78-
def test_module_static_gpu(self):
79-
dynamic = False
80-
device = "gpu"
81-
self.module_tester.create_and_check_module(dynamic, device)
82-
83-
@pytest.mark.skipif(
84-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
85-
)
86-
def test_module_dynamic_gpu(self):
87-
dynamic = True
88-
device = "gpu"
89-
self.module_tester.create_and_check_module(dynamic, device)
90-
91-
@pytest.mark.skipif(
92-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
93-
)
94-
def test_module_static_vulkan(self):
95-
dynamic = False
96-
device = "vulkan"
97-
self.module_tester.create_and_check_module(dynamic, device)
68+
param_list = get_valid_test_params()
9869

99-
@pytest.mark.skipif(
100-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
101-
)
102-
def test_module_dynamic_vulkan(self):
103-
dynamic = True
104-
device = "vulkan"
70+
@parameterized.expand(param_list, name_func=shark_test_name_func)
71+
def test_module(self, dynamic, device):
72+
if device in ["metal", "vulkan"]:
73+
if dynamic == False:
74+
if "m1-moltenvk-macos" in get_vulkan_triple_flag():
75+
pytest.xfail(reason="M1: CompilerToolError | M2: Pass")
10576
self.module_tester.create_and_check_module(dynamic, device)
10677

10778

tank/bert-base-uncased_tf/bert-base-uncased_tf_test.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from shark.shark_inference import SharkInference
33
from shark.shark_downloader import download_tf_model
44
from shark.parser import shark_args
5+
from tank.test_utils import get_valid_test_params, shark_test_name_func
6+
from parameterized import parameterized
57

68
import unittest
79
import pytest
@@ -37,25 +39,10 @@ def configure(self, pytestconfig):
3739
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
3840
self.module_tester.benchmark = pytestconfig.getoption("benchmark")
3941

40-
def test_module_static_cpu(self):
41-
dynamic = False
42-
device = "cpu"
43-
self.module_tester.create_and_check_module(dynamic, device)
44-
45-
@pytest.mark.skipif(
46-
check_device_drivers("gpu"), reason=device_driver_info("gpu")
47-
)
48-
def test_module_static_gpu(self):
49-
dynamic = False
50-
device = "gpu"
51-
self.module_tester.create_and_check_module(dynamic, device)
42+
param_list = get_valid_test_params()
5243

53-
@pytest.mark.skipif(
54-
check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
55-
)
56-
def test_module_static_vulkan(self):
57-
dynamic = False
58-
device = "vulkan"
44+
@parameterized.expand(param_list, name_func=shark_test_name_func)
45+
def test_module(self, dynamic, device):
5946
self.module_tester.create_and_check_module(dynamic, device)
6047

6148

0 commit comments

Comments
 (0)