Skip to content

Commit 237aa59

Browse files
Revert "Add is_enabled to handoffs (openai#925)"
This reverts commit c1466c6.
1 parent 18cb55e commit 237aa59

File tree

6 files changed

+21
-140
lines changed

6 files changed

+21
-140
lines changed

src/agents/handoffs.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from .strict_schema import ensure_strict_json_schema
1616
from .tracing.spans import SpanError
1717
from .util import _error_tracing, _json, _transforms
18-
from .util._types import MaybeAwaitable
1918

2019
if TYPE_CHECKING:
2120
from .agent import Agent
@@ -100,11 +99,6 @@ class Handoff(Generic[TContext]):
10099
True, as it increases the likelihood of correct JSON input.
101100
"""
102101

103-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
104-
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
105-
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
106-
a handoff based on your context/state."""
107-
108102
def get_transfer_message(self, agent: Agent[Any]) -> str:
109103
return json.dumps({"assistant": agent.name})
110104

@@ -127,7 +121,6 @@ def handoff(
127121
tool_name_override: str | None = None,
128122
tool_description_override: str | None = None,
129123
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
130-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
131124
) -> Handoff[TContext]: ...
132125

133126

@@ -140,7 +133,6 @@ def handoff(
140133
tool_description_override: str | None = None,
141134
tool_name_override: str | None = None,
142135
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
143-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
144136
) -> Handoff[TContext]: ...
145137

146138

@@ -152,7 +144,6 @@ def handoff(
152144
tool_description_override: str | None = None,
153145
tool_name_override: str | None = None,
154146
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
155-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
156147
) -> Handoff[TContext]: ...
157148

158149

@@ -163,7 +154,6 @@ def handoff(
163154
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
164155
input_type: type[THandoffInput] | None = None,
165156
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
166-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
167157
) -> Handoff[TContext]:
168158
"""Create a handoff from an agent.
169159
@@ -176,9 +166,6 @@ def handoff(
176166
input_type: the type of the input to the handoff. If provided, the input will be validated
177167
against this type. Only relevant if you pass a function that takes an input.
178168
input_filter: a function that filters the inputs that are passed to the next agent.
179-
is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run
180-
context and agent and returns whether the handoff is enabled. Disabled handoffs are
181-
hidden from the LLM at runtime.
182169
"""
183170
assert (on_handoff and input_type) or not (on_handoff and input_type), (
184171
"You must provide either both on_handoff and input_type, or neither"
@@ -246,5 +233,4 @@ async def _invoke_handoff(
246233
on_invoke_handoff=_invoke_handoff,
247234
input_filter=input_filter,
248235
agent_name=agent.name,
249-
is_enabled=is_enabled,
250236
)

src/agents/run.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import asyncio
44
import copy
5-
import inspect
65
from dataclasses import dataclass, field
76
from typing import Any, Generic, cast
87

@@ -362,8 +361,7 @@ async def run(
362361
# agent changes, or if the agent loop ends.
363362
if current_span is None:
364363
handoff_names = [
365-
h.agent_name
366-
for h in await AgentRunner._get_handoffs(current_agent, context_wrapper)
364+
h.agent_name for h in AgentRunner._get_handoffs(current_agent)
367365
]
368366
if output_schema := AgentRunner._get_output_schema(current_agent):
369367
output_type_name = output_schema.name()
@@ -643,10 +641,7 @@ async def _start_streaming(
643641
# Start an agent span if we don't have one. This span is ended if the current
644642
# agent changes, or if the agent loop ends.
645643
if current_span is None:
646-
handoff_names = [
647-
h.agent_name
648-
for h in await cls._get_handoffs(current_agent, context_wrapper)
649-
]
644+
handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)]
650645
if output_schema := cls._get_output_schema(current_agent):
651646
output_type_name = output_schema.name()
652647
else:
@@ -803,7 +798,7 @@ async def _run_single_turn_streamed(
803798
agent.get_prompt(context_wrapper),
804799
)
805800

806-
handoffs = await cls._get_handoffs(agent, context_wrapper)
801+
handoffs = cls._get_handoffs(agent)
807802
model = cls._get_model(agent, run_config)
808803
model_settings = agent.model_settings.resolve(run_config.model_settings)
809804
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
@@ -903,7 +898,7 @@ async def _run_single_turn(
903898
)
904899

905900
output_schema = cls._get_output_schema(agent)
906-
handoffs = await cls._get_handoffs(agent, context_wrapper)
901+
handoffs = cls._get_handoffs(agent)
907902
input = ItemHelpers.input_to_new_input_list(original_input)
908903
input.extend([generated_item.to_input_item() for generated_item in generated_items])
909904

@@ -1096,28 +1091,14 @@ def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None:
10961091
return AgentOutputSchema(agent.output_type)
10971092

10981093
@classmethod
1099-
async def _get_handoffs(
1100-
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
1101-
) -> list[Handoff]:
1094+
def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
11021095
handoffs = []
11031096
for handoff_item in agent.handoffs:
11041097
if isinstance(handoff_item, Handoff):
11051098
handoffs.append(handoff_item)
11061099
elif isinstance(handoff_item, Agent):
11071100
handoffs.append(handoff(handoff_item))
1108-
1109-
async def _check_handoff_enabled(handoff_obj: Handoff) -> bool:
1110-
attr = handoff_obj.is_enabled
1111-
if isinstance(attr, bool):
1112-
return attr
1113-
res = attr(context_wrapper, agent)
1114-
if inspect.isawaitable(res):
1115-
return bool(await res)
1116-
return bool(res)
1117-
1118-
results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs))
1119-
enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok]
1120-
return enabled
1101+
return handoffs
11211102

11221103
@classmethod
11231104
async def _get_all_tools(

tests/test_agent_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def test_handoff_with_agents():
4343
handoffs=[agent_1, agent_2],
4444
)
4545

46-
handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
46+
handoffs = AgentRunner._get_handoffs(agent_3)
4747
assert len(handoffs) == 2
4848

4949
assert handoffs[0].agent_name == "agent_1"
@@ -78,7 +78,7 @@ async def test_handoff_with_handoff_obj():
7878
],
7979
)
8080

81-
handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
81+
handoffs = AgentRunner._get_handoffs(agent_3)
8282
assert len(handoffs) == 2
8383

8484
assert handoffs[0].agent_name == "agent_1"
@@ -112,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent():
112112
handoffs=[handoff(agent_1), agent_2],
113113
)
114114

115-
handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None))
115+
handoffs = AgentRunner._get_handoffs(agent_3)
116116
assert len(handoffs) == 2
117117

118118
assert handoffs[0].agent_name == "agent_1"

tests/test_handoff_tool.py

Lines changed: 7 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -38,26 +38,24 @@ def get_len(data: HandoffInputData) -> int:
3838
return input_len + pre_handoff_len + new_items_len
3939

4040

41-
@pytest.mark.asyncio
42-
async def test_single_handoff_setup():
41+
def test_single_handoff_setup():
4342
agent_1 = Agent(name="test_1")
4443
agent_2 = Agent(name="test_2", handoffs=[agent_1])
4544

4645
assert not agent_1.handoffs
4746
assert agent_2.handoffs == [agent_1]
4847

49-
assert not (await AgentRunner._get_handoffs(agent_1, RunContextWrapper(agent_1)))
48+
assert not AgentRunner._get_handoffs(agent_1)
5049

51-
handoff_objects = await AgentRunner._get_handoffs(agent_2, RunContextWrapper(agent_2))
50+
handoff_objects = AgentRunner._get_handoffs(agent_2)
5251
assert len(handoff_objects) == 1
5352
obj = handoff_objects[0]
5453
assert obj.tool_name == Handoff.default_tool_name(agent_1)
5554
assert obj.tool_description == Handoff.default_tool_description(agent_1)
5655
assert obj.agent_name == agent_1.name
5756

5857

59-
@pytest.mark.asyncio
60-
async def test_multiple_handoffs_setup():
58+
def test_multiple_handoffs_setup():
6159
agent_1 = Agent(name="test_1")
6260
agent_2 = Agent(name="test_2")
6361
agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2])
@@ -66,7 +64,7 @@ async def test_multiple_handoffs_setup():
6664
assert not agent_1.handoffs
6765
assert not agent_2.handoffs
6866

69-
handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3))
67+
handoff_objects = AgentRunner._get_handoffs(agent_3)
7068
assert len(handoff_objects) == 2
7169
assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1)
7270
assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2)
@@ -78,8 +76,7 @@ async def test_multiple_handoffs_setup():
7876
assert handoff_objects[1].agent_name == agent_2.name
7977

8078

81-
@pytest.mark.asyncio
82-
async def test_custom_handoff_setup():
79+
def test_custom_handoff_setup():
8380
agent_1 = Agent(name="test_1")
8481
agent_2 = Agent(name="test_2")
8582
agent_3 = Agent(
@@ -98,7 +95,7 @@ async def test_custom_handoff_setup():
9895
assert not agent_1.handoffs
9996
assert not agent_2.handoffs
10097

101-
handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3))
98+
handoff_objects = AgentRunner._get_handoffs(agent_3)
10299
assert len(handoff_objects) == 2
103100

104101
first_handoff = handoff_objects[0]
@@ -287,86 +284,3 @@ def test_get_transfer_message_is_valid_json() -> None:
287284
obj = handoff(agent)
288285
transfer = obj.get_transfer_message(agent)
289286
assert json.loads(transfer) == {"assistant": agent.name}
290-
291-
292-
def test_handoff_is_enabled_bool():
293-
"""Test that handoff respects is_enabled boolean parameter."""
294-
agent = Agent(name="test")
295-
296-
# Test enabled handoff (default)
297-
handoff_enabled = handoff(agent)
298-
assert handoff_enabled.is_enabled is True
299-
300-
# Test explicitly enabled handoff
301-
handoff_explicit_enabled = handoff(agent, is_enabled=True)
302-
assert handoff_explicit_enabled.is_enabled is True
303-
304-
# Test disabled handoff
305-
handoff_disabled = handoff(agent, is_enabled=False)
306-
assert handoff_disabled.is_enabled is False
307-
308-
309-
@pytest.mark.asyncio
310-
async def test_handoff_is_enabled_callable():
311-
"""Test that handoff respects is_enabled callable parameter."""
312-
agent = Agent(name="test")
313-
314-
# Test callable that returns True
315-
def always_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
316-
return True
317-
318-
handoff_callable_enabled = handoff(agent, is_enabled=always_enabled)
319-
assert callable(handoff_callable_enabled.is_enabled)
320-
result = handoff_callable_enabled.is_enabled(RunContextWrapper(agent), agent)
321-
assert result is True
322-
323-
# Test callable that returns False
324-
def always_disabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
325-
return False
326-
327-
handoff_callable_disabled = handoff(agent, is_enabled=always_disabled)
328-
assert callable(handoff_callable_disabled.is_enabled)
329-
result = handoff_callable_disabled.is_enabled(RunContextWrapper(agent), agent)
330-
assert result is False
331-
332-
# Test async callable
333-
async def async_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool:
334-
return True
335-
336-
handoff_async_enabled = handoff(agent, is_enabled=async_enabled)
337-
assert callable(handoff_async_enabled.is_enabled)
338-
result = await handoff_async_enabled.is_enabled(RunContextWrapper(agent), agent) # type: ignore
339-
assert result is True
340-
341-
342-
@pytest.mark.asyncio
343-
async def test_handoff_is_enabled_filtering_integration():
344-
"""Integration test that disabled handoffs are filtered out by the runner."""
345-
346-
# Set up agents
347-
agent_1 = Agent(name="agent_1")
348-
agent_2 = Agent(name="agent_2")
349-
agent_3 = Agent(name="agent_3")
350-
351-
# Create main agent with mixed enabled/disabled handoffs
352-
main_agent = Agent(
353-
name="main_agent",
354-
handoffs=[
355-
handoff(agent_1, is_enabled=True), # enabled
356-
handoff(agent_2, is_enabled=False), # disabled
357-
handoff(agent_3, is_enabled=lambda ctx, agent: True), # enabled callable
358-
],
359-
)
360-
361-
context_wrapper = RunContextWrapper(main_agent)
362-
363-
# Get filtered handoffs using the runner's method
364-
filtered_handoffs = await AgentRunner._get_handoffs(main_agent, context_wrapper)
365-
366-
# Should only have 2 handoffs (agent_1 and agent_3), agent_2 should be filtered out
367-
assert len(filtered_handoffs) == 2
368-
369-
# Check that the correct agents are present
370-
agent_names = {h.agent_name for h in filtered_handoffs}
371-
assert agent_names == {"agent_1", "agent_3"}
372-
assert "agent_2" not in agent_names

tests/test_run_step_execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ async def get_execute_result(
325325
run_config: RunConfig | None = None,
326326
) -> SingleStepResult:
327327
output_schema = AgentRunner._get_output_schema(agent)
328-
handoffs = await AgentRunner._get_handoffs(agent, context_wrapper or RunContextWrapper(None))
328+
handoffs = AgentRunner._get_handoffs(agent)
329329

330330
processed_response = RunImpl.process_model_response(
331331
agent=agent,

tests/test_run_step_processing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly():
186186
agent=agent_3,
187187
response=response,
188188
output_schema=None,
189-
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
189+
handoffs=AgentRunner._get_handoffs(agent_3),
190190
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
191191
)
192192
assert len(result.handoffs) == 1, "Should have a handoff here"
@@ -216,7 +216,7 @@ async def test_missing_handoff_fails():
216216
agent=agent_3,
217217
response=response,
218218
output_schema=None,
219-
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
219+
handoffs=AgentRunner._get_handoffs(agent_3),
220220
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
221221
)
222222

@@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error():
239239
agent=agent_3,
240240
response=response,
241241
output_schema=None,
242-
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
242+
handoffs=AgentRunner._get_handoffs(agent_3),
243243
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
244244
)
245245
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
@@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly():
471471
agent=agent_3,
472472
response=response,
473473
output_schema=None,
474-
handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()),
474+
handoffs=AgentRunner._get_handoffs(agent_3),
475475
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
476476
)
477477
assert result.functions and len(result.functions) == 1

0 commit comments

Comments
 (0)