Skip to content

Commit c1466c6

Browse files
authored
Add is_enabled to handoffs (#925)
Was added to function tools before, now handoffs. Towards #918
1 parent 91c62c1 commit c1466c6

File tree

6 files changed

+140
-21
lines changed

6 files changed

+140
-21
lines changed

src/agents/handoffs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .strict_schema import ensure_strict_json_schema
1616
from .tracing.spans import SpanError
1717
from .util import _error_tracing, _json, _transforms
18+
from .util._types import MaybeAwaitable
1819

1920
if TYPE_CHECKING:
2021
from .agent import Agent
@@ -99,6 +100,11 @@ class Handoff(Generic[TContext]):
99100
True, as it increases the likelihood of correct JSON input.
100101
"""
101102

103+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
104+
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
105+
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
106+
a handoff based on your context/state."""
107+
102108
def get_transfer_message(self, agent: Agent[Any]) -> str:
103109
return json.dumps({"assistant": agent.name})
104110

@@ -121,6 +127,7 @@ def handoff(
121127
tool_name_override: str | None = None,
122128
tool_description_override: str | None = None,
123129
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
130+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
124131
) -> Handoff[TContext]: ...
125132

126133

@@ -133,6 +140,7 @@ def handoff(
133140
tool_description_override: str | None = None,
134141
tool_name_override: str | None = None,
135142
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
143+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
136144
) -> Handoff[TContext]: ...
137145

138146

@@ -144,6 +152,7 @@ def handoff(
144152
tool_description_override: str | None = None,
145153
tool_name_override: str | None = None,
146154
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
155+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
147156
) -> Handoff[TContext]: ...
148157

149158

@@ -154,6 +163,7 @@ def handoff(
154163
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
155164
input_type: type[THandoffInput] | None = None,
156165
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
166+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
157167
) -> Handoff[TContext]:
158168
"""Create a handoff from an agent.
159169
@@ -166,6 +176,9 @@ def handoff(
166176
input_type: the type of the input to the handoff. If provided, the input will be validated
167177
against this type. Only relevant if you pass a function that takes an input.
168178
input_filter: a function that filters the inputs that are passed to the next agent.
179+
is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run
180+
context and agent and returns whether the handoff is enabled. Disabled handoffs are
181+
hidden from the LLM at runtime.
169182
"""
170183
assert (on_handoff and input_type) or not (on_handoff and input_type), (
171184
"You must provide either both on_handoff and input_type, or neither"
@@ -233,4 +246,5 @@ async def _invoke_handoff(
233246
on_invoke_handoff=_invoke_handoff,
234247
input_filter=input_filter,
235248
agent_name=agent.name,
249+
is_enabled=is_enabled,
236250
)

src/agents/run.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import copy
5+
import inspect
56
from dataclasses import dataclass, field
67
from typing import Any, Generic, cast
78

@@ -361,7 +362,8 @@ async def run(
361362
# agent changes, or if the agent loop ends.
362363
if current_span is None:
363364
handoff_names = [
364-
h.agent_name for h in AgentRunner._get_handoffs(current_agent)
365+
h.agent_name
366+
for h in await AgentRunner._get_handoffs(current_agent, context_wrapper)
365367
]
366368
if output_schema := AgentRunner._get_output_schema(current_agent):
367369
output_type_name = output_schema.name()
@@ -641,7 +643,10 @@ async def _start_streaming(
641643
# Start an agent span if we don't have one. This span is ended if the current
642644
# agent changes, or if the agent loop ends.
643645
if current_span is None:
644-
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
646+
handoff_names = [
647+
h.agent_name
648+
for h in await cls._get_handoffs(current_agent, context_wrapper)
649+
]
645650
if output_schema := cls._get_output_schema(current_agent):
646651
output_type_name = output_schema.name()
647652
else:
@@ -798,7 +803,7 @@ async def _run_single_turn_streamed(
798803
agent.get_prompt(context_wrapper),
799804
)
800805

801-
handoffs = cls._get_handoffs(agent)
806+
handoffs = await cls._get_handoffs(agent, context_wrapper)
802807
model = cls._get_model(agent, run_config)
803808
model_settings = agent.model_settings.resolve(run_config.model_settings)
804809
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
@@ -898,7 +903,7 @@ async def _run_single_turn(
898903
)
899904

900905
output_schema = cls._get_output_schema(agent)
901-
handoffs = cls._get_handoffs(agent)
906+
handoffs = await cls._get_handoffs(agent, context_wrapper)
902907
input = ItemHelpers.input_to_new_input_list(original_input)
903908
input.extend([generated_item.to_input_item() for generated_item in generated_items])
904909

@@ -1091,14 +1096,28 @@ def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None:
10911096
return AgentOutputSchema(agent.output_type)
10921097

10931098
@classmethod
1094-
def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
1099+
async def _get_handoffs(
1100+
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
1101+
) -> list[Handoff]:
10951102
handoffs = []
10961103
for handoff_item in agent.handoffs:
10971104
if isinstance(handoff_item, Handoff):
10981105
handoffs.append(handoff_item)
10991106
elif isinstance(handoff_item, Agent):
11001107
handoffs.append(handoff(handoff_item))
1101-
return handoffs
1108+
1109+
async def _check_handoff_enabled(handoff_obj: Handoff) -> bool:
1110+
attr = handoff_obj.is_enabled
1111+
if isinstance(attr, bool):
1112+
return attr
1113+
res = attr(context_wrapper, agent)
1114+
if inspect.isawaitable(res):
1115+
return bool(await res)
1116+
return bool(res)
1117+
1118+
results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs))
1119+
enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok]
1120+
return enabled
11021121

11031122
@classmethod
11041123
async def _get_all_tools(

tests/test_agent_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def test_handoff_with_agents():
4343
handoffs=[agent_1, agent_2],
4444
)
4545

46-
handoffs = AgentRunner._get_handoffs(agent_3)
46+
handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
4747
assert len(handoffs) == 2
4848

4949
assert handoffs[0].agent_name == "agent_1"
@@ -78,7 +78,7 @@ async def test_handoff_with_handoff_obj():
7878
],
7979
)
8080

81-
handoffs = AgentRunner._get_handoffs(agent_3)
81+
handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
8282
assert len(handoffs) == 2
8383

8484
assert handoffs[0].agent_name == "agent_1"
@@ -112,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent():
112112
handoffs=[handoff(agent_1), agent_2],
113113
)
114114

115-
handoffs = AgentRunner._get_handoffs(agent_3)
115+
handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
116116
assert len(handoffs) == 2
117117

118118
assert handoffs[0].agent_name == "agent_1"

tests/test_handoff_tool.py

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,26 @@ def get_len(data: HandoffInputData) -> int:
3838
return input_len + pre_handoff_len + new_items_len
3939

4040

41-
def test_single_handoff_setup():
41+
@pytest.mark.asyncio
42+
async def test_single_handoff_setup():
4243
agent_1 = Agent(name="test_1")
4344
agent_2 = Agent(name="test_2", handoffs=[agent_1])
4445

4546
assert not agent_1.handoffs
4647
assert agent_2.handoffs == [agent_1]
4748

48-
assert not AgentRunner._get_handoffs(agent_1)
49+
assert not (await AgentRunner._get_handoffs(agent_1, RunContextWrapper(agent_1)))
4950

50-
handoff_objects = AgentRunner._get_handoffs(agent_2)
51+
handoff_objects = await AgentRunner._get_handoffs(agent_2, RunContextWrapper(agent_2))
5152
assert len(handoff_objects) == 1
5253
obj = handoff_objects[0]
5354
assert obj.tool_name == Handoff.default_tool_name(agent_1)
5455
assert obj.tool_description == Handoff.default_tool_description(agent_1)
5556
assert obj.agent_name == agent_1.name
5657

5758

58-
def test_multiple_handoffs_setup():
59+
@pytest.mark.asyncio
60+
async def test_multiple_handoffs_setup():
5961
agent_1 = Agent(name="test_1")
6062
agent_2 = Agent(name="test_2")
6163
agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2])
@@ -64,7 +66,7 @@ def test_multiple_handoffs_setup():
6466
assert not agent_1.handoffs
6567
assert not agent_2.handoffs
6668

67-
handoff_objects = AgentRunner._get_handoffs(agent_3)
69+
handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3))
6870
assert len(handoff_objects) == 2
6971
assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1)
7072
assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2)
@@ -76,7 +78,8 @@ def test_multiple_handoffs_setup():
7678
assert handoff_objects[1].agent_name == agent_2.name
7779

7880

79-
def test_custom_handoff_setup():
81+
@pytest.mark.asyncio
82+
async def test_custom_handoff_setup():
8083
agent_1 = Agent(name="test_1")
8184
agent_2 = Agent(name="test_2")
8285
agent_3 = Agent(
@@ -95,7 +98,7 @@ def test_custom_handoff_setup():
9598
assert not agent_1.handoffs
9699
assert not agent_2.handoffs
97100

98-
handoff_objects = AgentRunner._get_handoffs(agent_3)
101+
handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3))
99102
assert len(handoff_objects) == 2
100103

101104
first_handoff = handoff_objects[0]
@@ -284,3 +287,86 @@ def test_get_transfer_message_is_valid_json() -> None:
284287
obj = handoff(agent)
285288
transfer = obj.get_transfer_message(agent)
286289
assert json.loads(transfer) == {"assistant": agent.name}
290+
291+
292+
def test_handoff_is_enabled_bool():
293+
"""Test that handoff respects is_enabled boolean parameter."""
294+
agent = Agent(name="test")
295+
296+
# Test enabled handoff (default)
297+
handoff_enabled = handoff(agent)
298+
assert handoff_enabled.is_enabled is True
299+
300+
# Test explicitly enabled handoff
301+
handoff_explicit_enabled = handoff(agent, is_enabled=True)
302+
assert handoff_explicit_enabled.is_enabled is True
303+
304+
# Test disabled handoff
305+
handoff_disabled = handoff(agent, is_enabled=False)
306+
assert handoff_disabled.is_enabled is False
307+
308+
309+
@pytest.mark.asyncio
310+
async def test_handoff_is_enabled_callable():
311+
"""Test that handoff respects is_enabled callable parameter."""
312+
agent = Agent(name="test")
313+
314+
# Test callable that returns True
315+
def always_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
316+
return True
317+
318+
handoff_callable_enabled = handoff(agent, is_enabled=always_enabled)
319+
assert callable(handoff_callable_enabled.is_enabled)
320+
result = handoff_callable_enabled.is_enabled(RunContextWrapper(agent), agent)
321+
assert result is True
322+
323+
# Test callable that returns False
324+
def always_disabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
325+
return False
326+
327+
handoff_callable_disabled = handoff(agent, is_enabled=always_disabled)
328+
assert callable(handoff_callable_disabled.is_enabled)
329+
result = handoff_callable_disabled.is_enabled(RunContextWrapper(agent), agent)
330+
assert result is False
331+
332+
# Test async callable
333+
async def async_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
334+
return True
335+
336+
handoff_async_enabled = handoff(agent, is_enabled=async_enabled)
337+
assert callable(handoff_async_enabled.is_enabled)
338+
result = await handoff_async_enabled.is_enabled(RunContextWrapper(agent), agent) # type: ignore
339+
assert result is True
340+
341+
342+
@pytest.mark.asyncio
343+
async def test_handoff_is_enabled_filtering_integration():
344+
"""Integration test that disabled handoffs are filtered out by the runner."""
345+
346+
# Set up agents
347+
agent_1 = Agent(name="agent_1")
348+
agent_2 = Agent(name="agent_2")
349+
agent_3 = Agent(name="agent_3")
350+
351+
# Create main agent with mixed enabled/disabled handoffs
352+
main_agent = Agent(
353+
name="main_agent",
354+
handoffs=[
355+
handoff(agent_1, is_enabled=True), # enabled
356+
handoff(agent_2, is_enabled=False), # disabled
357+
handoff(agent_3, is_enabled=lambda ctx, agent: True), # enabled callable
358+
],
359+
)
360+
361+
context_wrapper = RunContextWrapper(main_agent)
362+
363+
# Get filtered handoffs using the runner's method
364+
filtered_handoffs = await AgentRunner._get_handoffs(main_agent, context_wrapper)
365+
366+
# Should only have 2 handoffs (agent_1 and agent_3), agent_2 should be filtered out
367+
assert len(filtered_handoffs) == 2
368+
369+
# Check that the correct agents are present
370+
agent_names = {h.agent_name for h in filtered_handoffs}
371+
assert agent_names == {"agent_1", "agent_3"}
372+
assert "agent_2" not in agent_names

tests/test_run_step_execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ async def get_execute_result(
325325
run_config: RunConfig | None = None,
326326
) -> SingleStepResult:
327327
output_schema = AgentRunner._get_output_schema(agent)
328-
handoffs = AgentRunner._get_handoffs(agent)
328+
handoffs = await AgentRunner._get_handoffs(agent, context_wrapper or RunContextWrapper(None))
329329

330330
processed_response = RunImpl.process_model_response(
331331
agent=agent,

tests/test_run_step_processing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly():
186186
agent=agent_3,
187187
response=response,
188188
output_schema=None,
189-
handoffs=AgentRunner._get_handoffs(agent_3),
189+
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
190190
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
191191
)
192192
assert len(result.handoffs) == 1, "Should have a handoff here"
@@ -216,7 +216,7 @@ async def test_missing_handoff_fails():
216216
agent=agent_3,
217217
response=response,
218218
output_schema=None,
219-
handoffs=AgentRunner._get_handoffs(agent_3),
219+
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
220220
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
221221
)
222222

@@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error():
239239
agent=agent_3,
240240
response=response,
241241
output_schema=None,
242-
handoffs=AgentRunner._get_handoffs(agent_3),
242+
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
243243
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
244244
)
245245
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
@@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly():
471471
agent=agent_3,
472472
response=response,
473473
output_schema=None,
474-
handoffs=AgentRunner._get_handoffs(agent_3),
474+
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
475475
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
476476
)
477477
assert result.functions and len(result.functions) == 1

0 commit comments

Comments
 (0)