Skip to content

Commit 96ae475

Browse files
committed
Add kernels as an extra hub-kernels
Also add it to `testing`, so that the kernel replacement gets tested when using CUDA in CI.
1 parent 2083474 commit 96ae475

File tree

3 files changed

+51
-21
lines changed

3 files changed

+51
-21
lines changed

setup.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,9 @@ def run(self):
302302
extras["optuna"] = deps_list("optuna")
303303
extras["ray"] = deps_list("ray[tune]")
304304
extras["sigopt"] = deps_list("sigopt")
305+
extras["hub-kernels"] = deps_list("kernels")
305306

306-
extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
307+
extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]
307308

308309
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
309310
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm")
@@ -412,7 +413,6 @@ def run(self):
412413
"filelock",
413414
"huggingface-hub",
414415
"importlib_metadata",
415-
"kernels",
416416
"numpy",
417417
"packaging",
418418
"protobuf",
@@ -434,7 +434,6 @@ def run(self):
434434
install_requires = [
435435
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads
436436
deps["huggingface-hub"],
437-
deps["kernels"], # download kernels from the Hub
438437
deps["numpy"],
439438
deps["packaging"], # utilities from PyPA to e.g., compare versions
440439
deps["pyyaml"], # used for the model cards metadata

src/transformers/dependency_versions_check.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
"requests",
3030
"packaging",
3131
"filelock",
32-
"kernels",
3332
"numpy",
3433
"tokenizers",
3534
"huggingface-hub",

src/transformers/integrations/hub_kernels.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,60 @@
1313
# limitations under the License.
1414
from typing import Dict, Union
1515

16-
from kernels import (
17-
Device,
18-
LayerRepository,
19-
register_kernel_mapping,
20-
replace_kernel_forward_from_hub,
21-
use_kernel_forward_from_hub,
22-
)
23-
24-
25-
_KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
26-
"MultiScaleDeformableAttention": {
27-
"cuda": LayerRepository(
28-
repo_id="kernels-community/deformable-detr",
29-
layer_name="MultiScaleDeformableAttention",
30-
)
16+
17+
try:
18+
from kernels import (
19+
Device,
20+
LayerRepository,
21+
register_kernel_mapping,
22+
replace_kernel_forward_from_hub,
23+
use_kernel_forward_from_hub,
24+
)
25+
26+
_hub_kernels_available = True
27+
28+
_KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
29+
"MultiScaleDeformableAttention": {
30+
"cuda": LayerRepository(
31+
repo_id="kernels-community/deformable-detr",
32+
layer_name="MultiScaleDeformableAttention",
33+
)
34+
}
3135
}
32-
}
3336

34-
register_kernel_mapping(_KERNEL_MAPPING)
37+
register_kernel_mapping(_KERNEL_MAPPING)
38+
39+
except ImportError:
40+
# Stub to make decorators int transformers work when `kernels`
41+
# is not installed.
42+
def use_kernel_forward_from_hub(*args, **kwargs):
43+
def decorator(cls):
44+
return cls
45+
46+
return decorator
47+
48+
class LayerRepository:
49+
def __init__(self, *args, **kwargs):
50+
raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
51+
52+
def replace_kernel_forward_from_hub(*args, **kwargs):
53+
raise RuntimeError(
54+
"replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
55+
)
56+
57+
def register_kernel_mapping(*args, **kwargs):
58+
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
59+
60+
_hub_kernels_available = False
61+
62+
63+
def is_hub_kernels_available():
64+
return _hub_kernels_available
65+
3566

3667
__all__ = [
3768
"LayerRepository",
69+
"is_hub_kernels_available",
3870
"use_kernel_forward_from_hub",
3971
"register_kernel_mapping",
4072
"replace_kernel_forward_from_hub",

0 commit comments

Comments
 (0)