1
1
from __future__ import annotations
2
2
3
- import asyncio
4
- import os
3
+ import sys
5
4
6
5
import pytest
7
6
8
- from distributed .utils_test import gen_cluster
7
+ import dask
8
+ from dask .distributed import worker
9
9
10
- pytestmark = [
11
- pytest .mark .gpu ,
12
- pytest .mark .skipif (
13
- os .environ .get ("CUDF_SPILL" , "off" ) != "on"
14
- or os .environ .get ("CUDF_SPILL_STATS" , "0" ) != "1"
15
- or os .environ .get ("DASK_DISTRIBUTED__DIAGNOSTICS__CUDF" , "0" ) != "1" ,
16
- reason = "cuDF spill stats monitoring must be enabled manually" ,
17
- ),
18
- ]
10
+ from distributed .utils_test import async_poll_for , gen_cluster
19
11
20
12
cudf = pytest .importorskip ("cudf" )
21
13
22
14
15
+ @pytest .fixture
16
+ def cudf_spill ():
17
+ """
18
+ Configures cuDF options to enable spilling.
19
+
20
+ Returns the settings to their original values after the test.
21
+ """
22
+ spill = cudf .get_option ("spill" )
23
+ spill_stats = cudf .get_option ("spill_stats" )
24
+
25
+ cudf .set_option ("spill" , True )
26
+ cudf .set_option ("spill_stats" , 1 )
27
+
28
+ yield
29
+
30
+ cudf .set_option ("spill" , spill )
31
+ cudf .set_option ("spill_stats" , spill_stats )
32
+
33
+
23
34
def force_spill ():
24
35
from cudf .core .buffer .spill_manager import get_global_manager
25
36
@@ -37,15 +48,27 @@ def force_spill():
37
48
@gen_cluster (
38
49
client = True ,
39
50
nthreads = [("127.0.0.1" , 1 )],
51
+ # whether worker.cudf_metric is in DEFAULT_METRICS depends on the value
52
+ # of distributed.diagnostics.cudf when distributed.worker is imported.
53
+ worker_kwargs = {
54
+ "metrics" : {** worker .DEFAULT_METRICS , "cudf" : worker .cudf_metric },
55
+ },
40
56
)
57
+ @pytest .mark .usefixtures ("cudf_spill" )
41
58
async def test_cudf_metrics (c , s , * workers ):
42
59
w = list (s .workers .values ())[0 ]
43
60
assert "cudf" in w .metrics
44
61
assert w .metrics ["cudf" ]["cudf-spilled" ] == 0
45
62
46
63
spill_totals = (await c .run (force_spill , workers = [w .address ]))[w .address ]
47
64
assert spill_totals > 0
48
- # We have to wait for the worker's metrics to update.
49
- # TODO: avoid sleep, is it possible to wait on the next update of metrics?
50
- await asyncio .sleep (1 )
65
+ await async_poll_for (lambda : w .metrics ["cudf" ]["cudf-spilled" ] > 0 , timeout = 2 )
51
66
assert w .metrics ["cudf" ]["cudf-spilled" ] == spill_totals
67
+
68
+
69
+ def test_cudf_default_metrics (monkeypatch ):
70
+ with dask .config .set (** {"distributed.diagnostics.cudf" : 1 }):
71
+ del sys .modules ["distributed.worker" ]
72
+ import distributed .worker
73
+
74
+ assert "cudf" in distributed .worker .DEFAULT_METRICS
0 commit comments