From 27ccf18d460118b404dee082fb9a6a1d5a3d8092 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Tue, 17 Jun 2025 08:30:43 -0700 Subject: [PATCH 1/2] Refine AgentRunner API --- src/agents/__init__.py | 7 +- src/agents/run.py | 160 ++++++++++++++++++++--------------------- tests/conftest.py | 4 +- tests/test_run.py | 13 ++-- 4 files changed, 89 insertions(+), 95 deletions(-) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index afa578b5e..d2e0857e5 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -47,7 +47,7 @@ from .models.openai_responses import OpenAIResponsesModel from .repl import run_demo_loop from .result import RunResult, RunResultStreaming -from .run import DefaultRunner, RunConfig, Runner, set_default_runner +from .run import AgentRunner, DefaultAgentRunner, RunConfig, Runner from .run_context import RunContextWrapper, TContext from .stream_events import ( AgentUpdatedStreamEvent, @@ -162,7 +162,8 @@ def enable_verbose_stdout_logging(): "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", "Runner", - "DefaultRunner", + "AgentRunner", + "DefaultAgentRunner", "run_demo_loop", "Model", "ModelProvider", @@ -241,7 +242,6 @@ def enable_verbose_stdout_logging(): "generation_span", "get_current_span", "get_current_trace", - "get_default_runner", "guardrail_span", "handoff_span", "set_trace_processors", @@ -270,7 +270,6 @@ def enable_verbose_stdout_logging(): "set_default_openai_key", "set_default_openai_client", "set_default_openai_api", - "set_default_runner", "set_tracing_export_api_key", "enable_verbose_stdout_logging", "gen_trace_id", diff --git a/src/agents/run.py b/src/agents/run.py index 1c301cb00..eb6aa8c9e 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -4,9 +4,10 @@ import asyncio import copy from dataclasses import dataclass, field -from typing import Any, cast +from typing import Any, Generic, cast from openai.types.responses import ResponseCompletedEvent +from typing_extensions import TypedDict, Unpack from ._run_impl import ( AgentToolUseTracker, @@ -47,23 +48,8 @@ from .util import _coro, _error_tracing DEFAULT_MAX_TURNS = 10 -DEFAULT_RUNNER: Runner = None # type: ignore -# assigned at the end of the module initialization - -def set_default_runner(runner: Runner | None) -> None: - """ - Set the default runner to use for the agent run. - """ - global DEFAULT_RUNNER - DEFAULT_RUNNER = runner or DefaultRunner() - -def get_default_runner() -> Runner | None: - """ - Get the default runner to use for the agent run. - """ - global DEFAULT_RUNNER - return DEFAULT_RUNNER +DEFAULT_RUNNER: AgentRunner = None # type: ignore @dataclass @@ -125,48 +111,57 @@ class RunConfig: """ -class Runner(abc.ABC): +class AgentRunnerParams(TypedDict, Generic[TContext]): + """Arguments for ``AgentRunner`` methods.""" + + starting_agent: Agent[TContext] + """The starting agent to run.""" + + input: str | list[TResponseInputItem] + """The initial input passed to the agent.""" + + context: TContext | None + """The context for the run.""" + + max_turns: int + """The maximum number of turns to run for.""" + + hooks: RunHooks[TContext] | None + """Lifecycle hooks for the run.""" + + run_config: RunConfig | None + """Run configuration.""" + + previous_response_id: str | None + """The ID of the previous response, if any.""" + + +class AgentRunner(abc.ABC): @abc.abstractmethod - async def _run_impl( + async def run_impl( self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: pass @abc.abstractmethod - def _run_sync_impl( + def run_sync_impl( self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: pass @abc.abstractmethod - def _run_streamed_impl( + def run_streamed_impl( self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResultStreaming: pass + +class Runner: + pass + @classmethod async def run( cls, @@ -206,9 +201,9 @@ async def run( agent. Agents may perform handoffs, so we don't know the specific type of the output. """ runner = DEFAULT_RUNNER - return await runner._run_impl( - starting_agent, - input, + return await runner.run_impl( + starting_agent=starting_agent, + input=input, context=context, max_turns=max_turns, hooks=hooks, @@ -258,9 +253,9 @@ def run_sync( agent. Agents may perform handoffs, so we don't know the specific type of the output. """ runner = DEFAULT_RUNNER - return runner._run_sync_impl( - starting_agent, - input, + return runner.run_sync_impl( + starting_agent=starting_agent, + input=input, context=context, max_turns=max_turns, hooks=hooks, @@ -306,9 +301,9 @@ def run_streamed( A result object that contains data about the run, as well as a method to stream events. """ runner = DEFAULT_RUNNER - return runner._run_streamed_impl( - starting_agent, - input, + return runner.run_streamed_impl( + starting_agent=starting_agent, + input=input, context=context, max_turns=max_turns, hooks=hooks, @@ -316,7 +311,6 @@ def run_streamed( previous_response_id=previous_response_id, ) - @classmethod def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: if agent.output_type is None or agent.output_type is str: @@ -353,18 +347,19 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: return run_config.model_provider.get_model(agent.model) -class DefaultRunner(Runner): - async def _run_impl( + +class DefaultAgentRunner(AgentRunner, Runner): + async def run_impl( self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: + starting_agent = kwargs["starting_agent"] + input = kwargs["input"] + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = kwargs.get("hooks") + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") if hooks is None: hooks = RunHooks[Any]() if run_config is None: @@ -514,17 +509,17 @@ async def _run_impl( if current_span: current_span.finish(reset_current=True) - def _run_sync_impl( + def run_sync_impl( self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: + starting_agent = kwargs["starting_agent"] + input = kwargs["input"] + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = kwargs.get("hooks") + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") return asyncio.get_event_loop().run_until_complete( self.run( starting_agent, @@ -537,16 +532,17 @@ def _run_sync_impl( ) ) - def _run_streamed_impl( + def run_streamed_impl( self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, + **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResultStreaming: + starting_agent = kwargs["starting_agent"] + input = kwargs["input"] + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = kwargs.get("hooks") + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") if hooks is None: hooks = RunHooks[Any]() if run_config is None: @@ -1110,4 +1106,4 @@ async def _get_new_response( return new_response -DEFAULT_RUNNER = DefaultRunner() +DEFAULT_RUNNER = DefaultAgentRunner() diff --git a/tests/conftest.py b/tests/conftest.py index f87e85594..beec12e0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,10 +2,10 @@ import pytest +from agents import run as run_module from agents.models import _openai_shared from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel from agents.models.openai_responses import OpenAIResponsesModel -from agents.run import set_default_runner from agents.tracing import set_trace_processors from agents.tracing.setup import get_trace_provider @@ -36,7 +36,7 @@ def clear_openai_settings(): @pytest.fixture(autouse=True) def clear_default_runner(): - set_default_runner(None) + run_module.DEFAULT_RUNNER = run_module.DefaultAgentRunner() # This fixture will run after all tests end diff --git a/tests/test_run.py b/tests/test_run.py index 57e33d50d..3dbe71ea4 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -4,23 +4,22 @@ import pytest -from agents import Agent, Runner -from agents.run import set_default_runner +from agents import Agent, AgentRunner, Runner, run as run_module from .fake_model import FakeModel @pytest.mark.asyncio async def test_static_run_methods_call_into_default_runner() -> None: - runner = mock.Mock(spec=Runner) - set_default_runner(runner) + runner = mock.Mock(spec=AgentRunner) + run_module.DEFAULT_RUNNER = runner agent = Agent(name="test", model=FakeModel()) await Runner.run(agent, input="test") - runner._run_impl.assert_called_once() + runner.run_impl.assert_called_once() Runner.run_streamed(agent, input="test") - runner._run_streamed_impl.assert_called_once() + runner.run_streamed_impl.assert_called_once() Runner.run_sync(agent, input="test") - runner._run_sync_impl.assert_called_once() + runner.run_sync_impl.assert_called_once() From 798a94132a3323edf729b845830a91a0eb74a210 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Tue, 17 Jun 2025 08:46:56 -0700 Subject: [PATCH 2/2] Update AgentRunner interface --- src/agents/run.py | 65 ++++++++++++++++++++++------------------------- tests/conftest.py | 6 ----- tests/test_run.py | 25 ------------------ 3 files changed, 30 insertions(+), 66 deletions(-) delete mode 100644 tests/test_run.py diff --git a/src/agents/run.py b/src/agents/run.py index eb6aa8c9e..ce48d1dc7 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -49,8 +49,6 @@ DEFAULT_MAX_TURNS = 10 -DEFAULT_RUNNER: AgentRunner = None # type: ignore - @dataclass class RunConfig: @@ -114,12 +112,6 @@ class RunConfig: class AgentRunnerParams(TypedDict, Generic[TContext]): """Arguments for ``AgentRunner`` methods.""" - starting_agent: Agent[TContext] - """The starting agent to run.""" - - input: str | list[TResponseInputItem] - """The initial input passed to the agent.""" - context: TContext | None """The context for the run.""" @@ -138,22 +130,28 @@ class AgentRunnerParams(TypedDict, Generic[TContext]): class AgentRunner(abc.ABC): @abc.abstractmethod - async def run_impl( + async def run( self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: pass @abc.abstractmethod - def run_sync_impl( + def run_sync( self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: pass @abc.abstractmethod - def run_streamed_impl( + def run_streamed( self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResultStreaming: pass @@ -200,10 +198,10 @@ async def run( A run result containing all the inputs, guardrail results and the output of the last agent. Agents may perform handoffs, so we don't know the specific type of the output. """ - runner = DEFAULT_RUNNER - return await runner.run_impl( - starting_agent=starting_agent, - input=input, + runner = DefaultAgentRunner() + return await runner.run( + starting_agent, + input, context=context, max_turns=max_turns, hooks=hooks, @@ -252,10 +250,10 @@ def run_sync( A run result containing all the inputs, guardrail results and the output of the last agent. Agents may perform handoffs, so we don't know the specific type of the output. """ - runner = DEFAULT_RUNNER - return runner.run_sync_impl( - starting_agent=starting_agent, - input=input, + runner = DefaultAgentRunner() + return runner.run_sync( + starting_agent, + input, context=context, max_turns=max_turns, hooks=hooks, @@ -300,10 +298,10 @@ def run_streamed( Returns: A result object that contains data about the run, as well as a method to stream events. """ - runner = DEFAULT_RUNNER - return runner.run_streamed_impl( - starting_agent=starting_agent, - input=input, + runner = DefaultAgentRunner() + return runner.run_streamed( + starting_agent, + input, context=context, max_turns=max_turns, hooks=hooks, @@ -349,12 +347,12 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: class DefaultAgentRunner(AgentRunner, Runner): - async def run_impl( + async def run( # type: ignore[override] self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: - starting_agent = kwargs["starting_agent"] - input = kwargs["input"] context = kwargs.get("context") max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) hooks = kwargs.get("hooks") @@ -509,12 +507,12 @@ async def run_impl( if current_span: current_span.finish(reset_current=True) - def run_sync_impl( + def run_sync( # type: ignore[override] self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResult: - starting_agent = kwargs["starting_agent"] - input = kwargs["input"] context = kwargs.get("context") max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) hooks = kwargs.get("hooks") @@ -532,12 +530,12 @@ def run_sync_impl( ) ) - def run_streamed_impl( + def run_streamed( # type: ignore[override] self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], **kwargs: Unpack[AgentRunnerParams[TContext]], ) -> RunResultStreaming: - starting_agent = kwargs["starting_agent"] - input = kwargs["input"] context = kwargs.get("context") max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) hooks = kwargs.get("hooks") @@ -1104,6 +1102,3 @@ async def _get_new_response( context_wrapper.usage.add(new_response.usage) return new_response - - -DEFAULT_RUNNER = DefaultAgentRunner() diff --git a/tests/conftest.py b/tests/conftest.py index beec12e0c..7527e11b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ import pytest -from agents import run as run_module from agents.models import _openai_shared from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel from agents.models.openai_responses import OpenAIResponsesModel @@ -34,11 +33,6 @@ def clear_openai_settings(): _openai_shared._use_responses_by_default = True -@pytest.fixture(autouse=True) -def clear_default_runner(): - run_module.DEFAULT_RUNNER = run_module.DefaultAgentRunner() - - # This fixture will run after all tests end @pytest.fixture(autouse=True, scope="session") def shutdown_trace_provider(): diff --git a/tests/test_run.py b/tests/test_run.py deleted file mode 100644 index 3dbe71ea4..000000000 --- a/tests/test_run.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -from unittest import mock - -import pytest - -from agents import Agent, AgentRunner, Runner, run as run_module - -from .fake_model import FakeModel - - -@pytest.mark.asyncio -async def test_static_run_methods_call_into_default_runner() -> None: - runner = mock.Mock(spec=AgentRunner) - run_module.DEFAULT_RUNNER = runner - - agent = Agent(name="test", model=FakeModel()) - await Runner.run(agent, input="test") - runner.run_impl.assert_called_once() - - Runner.run_streamed(agent, input="test") - runner.run_streamed_impl.assert_called_once() - - Runner.run_sync(agent, input="test") - runner.run_sync_impl.assert_called_once()