diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index 76c93a298..cb2752e4f 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -15,6 +15,7 @@ from .strict_schema import ensure_strict_json_schema from .tracing.spans import SpanError from .util import _error_tracing, _json, _transforms +from .util._types import MaybeAwaitable if TYPE_CHECKING: from .agent import Agent @@ -99,6 +100,11 @@ class Handoff(Generic[TContext]): True, as it increases the likelihood of correct JSON input. """ + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True + """Whether the handoff is enabled. Either a bool or a Callable that takes the run context and + agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable + a handoff based on your context/state.""" + def get_transfer_message(self, agent: Agent[Any]) -> str: return json.dumps({"assistant": agent.name}) @@ -121,6 +127,7 @@ def handoff( tool_name_override: str | None = None, tool_description_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -133,6 +140,7 @@ def handoff( tool_description_override: str | None = None, tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -144,6 +152,7 @@ def handoff( tool_description_override: str | None = None, tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -154,6 +163,7 @@ def handoff( on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None, input_type: type[THandoffInput] | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: """Create a handoff from an agent. @@ -166,6 +176,9 @@ def handoff( input_type: the type of the input to the handoff. If provided, the input will be validated against this type. Only relevant if you pass a function that takes an input. input_filter: a function that filters the inputs that are passed to the next agent. + is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run + context and agent and returns whether the handoff is enabled. Disabled handoffs are + hidden from the LLM at runtime. """ assert (on_handoff and input_type) or not (on_handoff and input_type), ( "You must provide either both on_handoff and input_type, or neither" @@ -233,4 +246,5 @@ async def _invoke_handoff( on_invoke_handoff=_invoke_handoff, input_filter=input_filter, agent_name=agent.name, + is_enabled=is_enabled, ) diff --git a/src/agents/run.py b/src/agents/run.py index 8a44a0e54..e5f9378ec 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -2,6 +2,7 @@ import asyncio import copy +import inspect from dataclasses import dataclass, field from typing import Any, Generic, cast @@ -361,7 +362,8 @@ async def run( # agent changes, or if the agent loop ends. if current_span is None: handoff_names = [ - h.agent_name for h in AgentRunner._get_handoffs(current_agent) + h.agent_name + for h in await AgentRunner._get_handoffs(current_agent, context_wrapper) ] if output_schema := AgentRunner._get_output_schema(current_agent): output_type_name = output_schema.name() @@ -641,7 +643,10 @@ async def _start_streaming( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. if current_span is None: - handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] + handoff_names = [ + h.agent_name + for h in await cls._get_handoffs(current_agent, context_wrapper) + ] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.name() else: @@ -798,7 +803,7 @@ async def _run_single_turn_streamed( agent.get_prompt(context_wrapper), ) - handoffs = cls._get_handoffs(agent) + handoffs = await cls._get_handoffs(agent, context_wrapper) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) @@ -898,7 +903,7 @@ async def _run_single_turn( ) output_schema = cls._get_output_schema(agent) - handoffs = cls._get_handoffs(agent) + handoffs = await cls._get_handoffs(agent, context_wrapper) input = ItemHelpers.input_to_new_input_list(original_input) input.extend([generated_item.to_input_item() for generated_item in generated_items]) @@ -1091,14 +1096,28 @@ def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: return AgentOutputSchema(agent.output_type) @classmethod - def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: + async def _get_handoffs( + cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] + ) -> list[Handoff]: handoffs = [] for handoff_item in agent.handoffs: if isinstance(handoff_item, Handoff): handoffs.append(handoff_item) elif isinstance(handoff_item, Agent): handoffs.append(handoff(handoff_item)) - return handoffs + + async def _check_handoff_enabled(handoff_obj: Handoff) -> bool: + attr = handoff_obj.is_enabled + if isinstance(attr, bool): + return attr + res = attr(context_wrapper, agent) + if inspect.isawaitable(res): + return bool(await res) + return bool(res) + + results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs)) + enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok] + return enabled @classmethod async def _get_all_tools( diff --git a/tests/test_agent_config.py b/tests/test_agent_config.py index f9423619d..a985fd60d 100644 --- a/tests/test_agent_config.py +++ b/tests/test_agent_config.py @@ -43,7 +43,7 @@ async def test_handoff_with_agents(): handoffs=[agent_1, agent_2], ) - handoffs = AgentRunner._get_handoffs(agent_3) + handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" @@ -78,7 +78,7 @@ async def test_handoff_with_handoff_obj(): ], ) - handoffs = AgentRunner._get_handoffs(agent_3) + handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" @@ -112,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent(): handoffs=[handoff(agent_1), agent_2], ) - handoffs = AgentRunner._get_handoffs(agent_3) + handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" diff --git a/tests/test_handoff_tool.py b/tests/test_handoff_tool.py index a1b5b80ba..0f7fc2166 100644 --- a/tests/test_handoff_tool.py +++ b/tests/test_handoff_tool.py @@ -38,16 +38,17 @@ def get_len(data: HandoffInputData) -> int: return input_len + pre_handoff_len + new_items_len -def test_single_handoff_setup(): +@pytest.mark.asyncio +async def test_single_handoff_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2", handoffs=[agent_1]) assert not agent_1.handoffs assert agent_2.handoffs == [agent_1] - assert not AgentRunner._get_handoffs(agent_1) + assert not (await AgentRunner._get_handoffs(agent_1, RunContextWrapper(agent_1))) - handoff_objects = AgentRunner._get_handoffs(agent_2) + handoff_objects = await AgentRunner._get_handoffs(agent_2, RunContextWrapper(agent_2)) assert len(handoff_objects) == 1 obj = handoff_objects[0] assert obj.tool_name == Handoff.default_tool_name(agent_1) @@ -55,7 +56,8 @@ def test_single_handoff_setup(): assert obj.agent_name == agent_1.name -def test_multiple_handoffs_setup(): +@pytest.mark.asyncio +async def test_multiple_handoffs_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2]) @@ -64,7 +66,7 @@ def test_multiple_handoffs_setup(): assert not agent_1.handoffs assert not agent_2.handoffs - handoff_objects = AgentRunner._get_handoffs(agent_3) + handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3)) assert len(handoff_objects) == 2 assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1) assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2) @@ -76,7 +78,8 @@ def test_multiple_handoffs_setup(): assert handoff_objects[1].agent_name == agent_2.name -def test_custom_handoff_setup(): +@pytest.mark.asyncio +async def test_custom_handoff_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent( @@ -95,7 +98,7 @@ def test_custom_handoff_setup(): assert not agent_1.handoffs assert not agent_2.handoffs - handoff_objects = AgentRunner._get_handoffs(agent_3) + handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3)) assert len(handoff_objects) == 2 first_handoff = handoff_objects[0] @@ -284,3 +287,86 @@ def test_get_transfer_message_is_valid_json() -> None: obj = handoff(agent) transfer = obj.get_transfer_message(agent) assert json.loads(transfer) == {"assistant": agent.name} + + +def test_handoff_is_enabled_bool(): + """Test that handoff respects is_enabled boolean parameter.""" + agent = Agent(name="test") + + # Test enabled handoff (default) + handoff_enabled = handoff(agent) + assert handoff_enabled.is_enabled is True + + # Test explicitly enabled handoff + handoff_explicit_enabled = handoff(agent, is_enabled=True) + assert handoff_explicit_enabled.is_enabled is True + + # Test disabled handoff + handoff_disabled = handoff(agent, is_enabled=False) + assert handoff_disabled.is_enabled is False + + +@pytest.mark.asyncio +async def test_handoff_is_enabled_callable(): + """Test that handoff respects is_enabled callable parameter.""" + agent = Agent(name="test") + + # Test callable that returns True + def always_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: + return True + + handoff_callable_enabled = handoff(agent, is_enabled=always_enabled) + assert callable(handoff_callable_enabled.is_enabled) + result = handoff_callable_enabled.is_enabled(RunContextWrapper(agent), agent) + assert result is True + + # Test callable that returns False + def always_disabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: + return False + + handoff_callable_disabled = handoff(agent, is_enabled=always_disabled) + assert callable(handoff_callable_disabled.is_enabled) + result = handoff_callable_disabled.is_enabled(RunContextWrapper(agent), agent) + assert result is False + + # Test async callable + async def async_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: + return True + + handoff_async_enabled = handoff(agent, is_enabled=async_enabled) + assert callable(handoff_async_enabled.is_enabled) + result = await handoff_async_enabled.is_enabled(RunContextWrapper(agent), agent) # type: ignore + assert result is True + + +@pytest.mark.asyncio +async def test_handoff_is_enabled_filtering_integration(): + """Integration test that disabled handoffs are filtered out by the runner.""" + + # Set up agents + agent_1 = Agent(name="agent_1") + agent_2 = Agent(name="agent_2") + agent_3 = Agent(name="agent_3") + + # Create main agent with mixed enabled/disabled handoffs + main_agent = Agent( + name="main_agent", + handoffs=[ + handoff(agent_1, is_enabled=True), # enabled + handoff(agent_2, is_enabled=False), # disabled + handoff(agent_3, is_enabled=lambda ctx, agent: True), # enabled callable + ], + ) + + context_wrapper = RunContextWrapper(main_agent) + + # Get filtered handoffs using the runner's method + filtered_handoffs = await AgentRunner._get_handoffs(main_agent, context_wrapper) + + # Should only have 2 handoffs (agent_1 and agent_3), agent_2 should be filtered out + assert len(filtered_handoffs) == 2 + + # Check that the correct agents are present + agent_names = {h.agent_name for h in filtered_handoffs} + assert agent_names == {"agent_1", "agent_3"} + assert "agent_2" not in agent_names diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 2454a4462..4cf9ae832 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -325,7 +325,7 @@ async def get_execute_result( run_config: RunConfig | None = None, ) -> SingleStepResult: output_schema = AgentRunner._get_output_schema(agent) - handoffs = AgentRunner._get_handoffs(agent) + handoffs = await AgentRunner._get_handoffs(agent, context_wrapper or RunContextWrapper(None)) processed_response = RunImpl.process_model_response( agent=agent, diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 5a75ec837..6a2904791 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert len(result.handoffs) == 1, "Should have a handoff here" @@ -216,7 +216,7 @@ async def test_missing_handoff_fails(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) @@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert len(result.handoffs) == 2, "Should have multiple handoffs here" @@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly(): agent=agent_3, response=response, output_schema=None, - handoffs=AgentRunner._get_handoffs(agent_3), + handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert result.functions and len(result.functions) == 1