Skip to content

Commit 49322cc

Browse files
committed
Updated cudf-spill test
This updates the cudf-spilling test to rely on a few less environment variables.
1 parent c2ba834 commit 49322cc

File tree

2 files changed

+47
-28
lines changed

2 files changed

+47
-28
lines changed
Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,36 @@
11
from __future__ import annotations
22

3-
import asyncio
4-
import os
3+
import sys
54

65
import pytest
76

8-
from distributed.utils_test import gen_cluster
7+
import dask
8+
from dask.distributed import worker
99

10-
pytestmark = [
11-
pytest.mark.gpu,
12-
pytest.mark.skipif(
13-
os.environ.get("CUDF_SPILL", "off") != "on"
14-
or os.environ.get("CUDF_SPILL_STATS", "0") != "1"
15-
or os.environ.get("DASK_DISTRIBUTED__DIAGNOSTICS__CUDF", "0") != "1",
16-
reason="cuDF spill stats monitoring must be enabled manually",
17-
),
18-
]
10+
from distributed.utils_test import async_poll_for, gen_cluster
1911

2012
cudf = pytest.importorskip("cudf")
2113

2214

15+
@pytest.fixture
16+
def cudf_spill():
17+
"""
18+
Configures cuDF options to enable spilling.
19+
20+
Returns the settings to their original values after the test.
21+
"""
22+
spill = cudf.get_option("spill")
23+
spill_stats = cudf.get_option("spill_stats")
24+
25+
cudf.set_option("spill", True)
26+
cudf.set_option("spill_stats", 1)
27+
28+
yield
29+
30+
cudf.set_option("spill", spill)
31+
cudf.set_option("spill_stats", spill_stats)
32+
33+
2334
def force_spill():
2435
from cudf.core.buffer.spill_manager import get_global_manager
2536

@@ -37,15 +48,27 @@ def force_spill():
3748
@gen_cluster(
3849
client=True,
3950
nthreads=[("127.0.0.1", 1)],
51+
# whether worker.cudf_metric is in DEFAULT_METRICS depends on the value
52+
# of distributed.diagnostics.cudf when distributed.worker is imported.
53+
worker_kwargs={
54+
"metrics": {**worker.DEFAULT_METRICS, "cudf": worker.cudf_metric},
55+
},
4056
)
57+
@pytest.mark.usefixtures("cudf_spill")
4158
async def test_cudf_metrics(c, s, *workers):
4259
w = list(s.workers.values())[0]
4360
assert "cudf" in w.metrics
4461
assert w.metrics["cudf"]["cudf-spilled"] == 0
4562

4663
spill_totals = (await c.run(force_spill, workers=[w.address]))[w.address]
4764
assert spill_totals > 0
48-
# We have to wait for the worker's metrics to update.
49-
# TODO: avoid sleep, is it possible to wait on the next update of metrics?
50-
await asyncio.sleep(1)
65+
await async_poll_for(lambda: w.metrics["cudf"]["cudf-spilled"] > 0, timeout=2)
5166
assert w.metrics["cudf"]["cudf-spilled"] == spill_totals
67+
68+
69+
def test_cudf_default_metrics(monkeypatch):
70+
with dask.config.set(**{"distributed.diagnostics.cudf": 1}):
71+
del sys.modules["distributed.worker"]
72+
import distributed.worker
73+
74+
assert "cudf" in distributed.worker.DEFAULT_METRICS

distributed/worker.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3224,21 +3224,17 @@ async def rmm_metric(worker):
32243224
DEFAULT_METRICS["rmm"] = rmm_metric
32253225
del _rmm
32263226

3227-
# avoid importing cuDF unless explicitly enabled
3228-
if dask.config.get("distributed.diagnostics.cudf"):
3229-
try:
3230-
import cudf as _cudf # noqa: F401
3231-
except Exception:
3232-
pass
3233-
else:
3234-
from distributed.diagnostics import cudf
32353227

3236-
async def cudf_metric(worker):
3237-
result = await offload(cudf.real_time)
3238-
return result
3228+
async def cudf_metric(worker):
3229+
# avoid importing optional cudf at the top-level
3230+
from distributed.diagnostics import cudf
32393231

3240-
DEFAULT_METRICS["cudf"] = cudf_metric
3241-
del _cudf
3232+
result = await offload(cudf.real_time)
3233+
return result
3234+
3235+
3236+
if dask.config.get("distributed.diagnostics.cudf"):
3237+
DEFAULT_METRICS["cudf"] = cudf_metric
32423238

32433239

32443240
def print(

0 commit comments

Comments
 (0)