Skip to content

Commit 05b6ee7

Browse files
committed
Use deformable_detr kernel from the Hub
Remove the `deformable_detr` kernel from `kernels/` and use the pre-built kernel from the Hub instead.
1 parent cf8091c commit 05b6ee7

22 files changed

+361
-3833
lines changed

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
130130
"keras>2.9,<2.16",
131131
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
132+
"kernels>=0.3.2,<0.4",
132133
"librosa",
133134
"natten>=0.14.6,<0.15.0",
134135
"nltk<=3.8.1",
@@ -411,6 +412,7 @@ def run(self):
411412
"filelock",
412413
"huggingface-hub",
413414
"importlib_metadata",
415+
"kernels",
414416
"numpy",
415417
"packaging",
416418
"protobuf",
@@ -432,6 +434,7 @@ def run(self):
432434
install_requires = [
433435
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads
434436
deps["huggingface-hub"],
437+
deps["kernels"], # download kernels from the Hub
435438
deps["numpy"],
436439
deps["packaging"], # utilities from PyPA to e.g., compare versions
437440
deps["pyyaml"], # used for the model cards metadata

src/transformers/dependency_versions_check.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"requests",
3030
"packaging",
3131
"filelock",
32+
"kernels",
3233
"numpy",
3334
"tokenizers",
3435
"huggingface-hub",

src/transformers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"kenlm": "kenlm",
3636
"keras": "keras>2.9,<2.16",
3737
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
38+
"kernels": "kernels>=0.3.2,<0.4",
3839
"librosa": "librosa",
3940
"natten": "natten>=0.14.6,<0.15.0",
4041
"nltk": "nltk<=3.8.1",

src/transformers/integrations/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@
7070
"replace_with_higgs_linear",
7171
],
7272
"hqq": ["prepare_for_hqq_linear"],
73+
"hub_kernels": [
74+
"LayerRepository",
75+
"register_kernel_mapping",
76+
"replace_kernel_forward_from_hub",
77+
"use_kernel_forward_from_hub",
78+
],
7379
"integration_utils": [
7480
"INTEGRATION_TO_CALLBACK",
7581
"AzureMLCallback",
@@ -198,6 +204,12 @@
198204
)
199205
from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
200206
from .hqq import prepare_for_hqq_linear
207+
from .hub_kernels import (
208+
LayerRepository,
209+
register_kernel_mapping,
210+
replace_kernel_forward_from_hub,
211+
use_kernel_forward_from_hub,
212+
)
201213
from .integration_utils import (
202214
INTEGRATION_TO_CALLBACK,
203215
AzureMLCallback,
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Dict, Union
2+
3+
from kernels import (
4+
Device,
5+
LayerRepository,
6+
register_kernel_mapping,
7+
replace_kernel_forward_from_hub,
8+
use_kernel_forward_from_hub,
9+
)
10+
11+
12+
_KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
13+
"MultiScaleDeformableAttention": {
14+
"cuda": LayerRepository(
15+
repo_id="kernels-community/deformable-detr",
16+
layer_name="MultiScaleDeformableAttention",
17+
)
18+
}
19+
}
20+
21+
register_kernel_mapping(_KERNEL_MAPPING)
22+
23+
__all__ = [
24+
"LayerRepository",
25+
"use_kernel_forward_from_hub",
26+
"register_kernel_mapping",
27+
"replace_kernel_forward_from_hub",
28+
]

src/transformers/kernels/deformable_detr/cpu/ms_deform_attn_cpu.cpp

Lines changed: 0 additions & 40 deletions
This file was deleted.

src/transformers/kernels/deformable_detr/cpu/ms_deform_attn_cpu.h

Lines changed: 0 additions & 32 deletions
This file was deleted.

src/transformers/kernels/deformable_detr/cuda/ms_deform_attn_cuda.cu

Lines changed: 0 additions & 159 deletions
This file was deleted.

0 commit comments

Comments
 (0)