
Commit 0ba33c4

Introduce AgentRunner class (#886)
## Summary

- Rename `DefaultRunner` to `DefaultAgentRunner` and extend the new `AgentRunner` base class
- Remove the `set_default_runner`/`get_default_runner` helpers
- Rename the abstract methods on `AgentRunner` (no leading underscore)
- Update tests and imports for the new API

## Testing

- `make format`
- `make lint`
- `make mypy`
- `make tests`

------
https://chatgpt.com/codex/tasks/task_i_6851851acce8832099adc70d8197016c
1 parent 83251c3 commit 0ba33c4
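For orientation, here is a minimal sketch of the post-commit surface, assuming only the names visible in the diffs below; `LoggingRunner` and the example agent are illustrative, not part of the SDK:

```python
# Sketch of the post-commit API; LoggingRunner and the example agent are
# illustrative, not part of the SDK.
import asyncio

from agents import Agent, AgentRunner, DefaultAgentRunner, Runner, RunResult


class LoggingRunner(DefaultAgentRunner):
    """Extend DefaultAgentRunner directly instead of registering a global default."""

    async def run(self, starting_agent, input, **kwargs) -> RunResult:  # type: ignore[override]
        print(f"starting run with agent {starting_agent.name!r}")
        return await super().run(starting_agent, input, **kwargs)


async def main() -> None:
    agent = Agent(name="assistant", instructions="Reply concisely.")

    # The static entry point still works; it now builds a DefaultAgentRunner per call.
    result = await Runner.run(agent, "Hello")
    print(result.final_output)

    # Or drive a custom runner explicitly.
    result = await LoggingRunner().run(agent, "Hello", max_turns=5)
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
```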

4 files changed: +64 −106 lines changed

src/agents/__init__.py

Lines changed: 3 additions & 4 deletions

```diff
@@ -47,7 +47,7 @@
 from .models.openai_responses import OpenAIResponsesModel
 from .repl import run_demo_loop
 from .result import RunResult, RunResultStreaming
-from .run import DefaultRunner, RunConfig, Runner, set_default_runner
+from .run import AgentRunner, DefaultAgentRunner, RunConfig, Runner
 from .run_context import RunContextWrapper, TContext
 from .stream_events import (
     AgentUpdatedStreamEvent,
@@ -162,7 +162,8 @@ def enable_verbose_stdout_logging():
     "ToolsToFinalOutputFunction",
     "ToolsToFinalOutputResult",
     "Runner",
-    "DefaultRunner",
+    "AgentRunner",
+    "DefaultAgentRunner",
     "run_demo_loop",
     "Model",
     "ModelProvider",
@@ -241,7 +242,6 @@ def enable_verbose_stdout_logging():
     "generation_span",
     "get_current_span",
     "get_current_trace",
-    "get_default_runner",
     "guardrail_span",
     "handoff_span",
     "set_trace_processors",
@@ -270,7 +270,6 @@ def enable_verbose_stdout_logging():
     "set_default_openai_key",
     "set_default_openai_client",
     "set_default_openai_api",
-    "set_default_runner",
     "set_tracing_export_api_key",
     "enable_verbose_stdout_logging",
     "gen_trace_id",
```

src/agents/run.py

Lines changed: 61 additions & 70 deletions

```diff
@@ -4,9 +4,10 @@
 import asyncio
 import copy
 from dataclasses import dataclass, field
-from typing import Any, cast
+from typing import Any, Generic, cast
 
 from openai.types.responses import ResponseCompletedEvent
+from typing_extensions import TypedDict, Unpack
 
 from ._run_impl import (
     AgentToolUseTracker,
@@ -47,23 +48,6 @@
 from .util import _coro, _error_tracing
 
 DEFAULT_MAX_TURNS = 10
-DEFAULT_RUNNER: Runner = None  # type: ignore
-# assigned at the end of the module initialization
-
-
-def set_default_runner(runner: Runner | None) -> None:
-    """
-    Set the default runner to use for the agent run.
-    """
-    global DEFAULT_RUNNER
-    DEFAULT_RUNNER = runner or DefaultRunner()
-
-def get_default_runner() -> Runner | None:
-    """
-    Get the default runner to use for the agent run.
-    """
-    global DEFAULT_RUNNER
-    return DEFAULT_RUNNER
 
 
 @dataclass
@@ -125,48 +109,57 @@ class RunConfig:
     """
 
 
-class Runner(abc.ABC):
+class AgentRunnerParams(TypedDict, Generic[TContext]):
+    """Arguments for ``AgentRunner`` methods."""
+
+    context: TContext | None
+    """The context for the run."""
+
+    max_turns: int
+    """The maximum number of turns to run for."""
+
+    hooks: RunHooks[TContext] | None
+    """Lifecycle hooks for the run."""
+
+    run_config: RunConfig | None
+    """Run configuration."""
+
+    previous_response_id: str | None
+    """The ID of the previous response, if any."""
+
+
+class AgentRunner(abc.ABC):
     @abc.abstractmethod
-    async def _run_impl(
+    async def run(
         self,
         starting_agent: Agent[TContext],
         input: str | list[TResponseInputItem],
-        *,
-        context: TContext | None = None,
-        max_turns: int = DEFAULT_MAX_TURNS,
-        hooks: RunHooks[TContext] | None = None,
-        run_config: RunConfig | None = None,
-        previous_response_id: str | None = None,
+        **kwargs: Unpack[AgentRunnerParams[TContext]],
     ) -> RunResult:
         pass
 
     @abc.abstractmethod
-    def _run_sync_impl(
+    def run_sync(
         self,
         starting_agent: Agent[TContext],
         input: str | list[TResponseInputItem],
-        *,
-        context: TContext | None = None,
-        max_turns: int = DEFAULT_MAX_TURNS,
-        hooks: RunHooks[TContext] | None = None,
-        run_config: RunConfig | None = None,
-        previous_response_id: str | None = None,
+        **kwargs: Unpack[AgentRunnerParams[TContext]],
    ) -> RunResult:
         pass
 
     @abc.abstractmethod
-    def _run_streamed_impl(
+    def run_streamed(
         self,
         starting_agent: Agent[TContext],
         input: str | list[TResponseInputItem],
-        context: TContext | None = None,
-        max_turns: int = DEFAULT_MAX_TURNS,
-        hooks: RunHooks[TContext] | None = None,
-        run_config: RunConfig | None = None,
-        previous_response_id: str | None = None,
+        **kwargs: Unpack[AgentRunnerParams[TContext]],
     ) -> RunResultStreaming:
         pass
 
+
+class Runner:
+    pass
+
     @classmethod
     async def run(
         cls,
@@ -205,8 +198,8 @@ async def run(
             A run result containing all the inputs, guardrail results and the output of the last
             agent. Agents may perform handoffs, so we don't know the specific type of the output.
         """
-        runner = DEFAULT_RUNNER
-        return await runner._run_impl(
+        runner = DefaultAgentRunner()
+        return await runner.run(
             starting_agent,
             input,
             context=context,
@@ -257,8 +250,8 @@ def run_sync(
             A run result containing all the inputs, guardrail results and the output of the last
             agent. Agents may perform handoffs, so we don't know the specific type of the output.
         """
-        runner = DEFAULT_RUNNER
-        return runner._run_sync_impl(
+        runner = DefaultAgentRunner()
+        return runner.run_sync(
             starting_agent,
             input,
             context=context,
@@ -305,8 +298,8 @@ def run_streamed(
         Returns:
            A result object that contains data about the run, as well as a method to stream events.
         """
-        runner = DEFAULT_RUNNER
-        return runner._run_streamed_impl(
+        runner = DefaultAgentRunner()
+        return runner.run_streamed(
             starting_agent,
             input,
             context=context,
@@ -316,7 +309,6 @@ def run_streamed(
             previous_response_id=previous_response_id,
         )
 
-
     @classmethod
     def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None:
         if agent.output_type is None or agent.output_type is str:
@@ -353,18 +345,19 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
 
         return run_config.model_provider.get_model(agent.model)
 
-class DefaultRunner(Runner):
-    async def _run_impl(
+
+class DefaultAgentRunner(AgentRunner, Runner):
+    async def run(  # type: ignore[override]
         self,
         starting_agent: Agent[TContext],
         input: str | list[TResponseInputItem],
-        *,
-        context: TContext | None = None,
-        max_turns: int = DEFAULT_MAX_TURNS,
-        hooks: RunHooks[TContext] | None = None,
-        run_config: RunConfig | None = None,
-        previous_response_id: str | None = None,
+        **kwargs: Unpack[AgentRunnerParams[TContext]],
     ) -> RunResult:
+        context = kwargs.get("context")
+        max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
+        hooks = kwargs.get("hooks")
+        run_config = kwargs.get("run_config")
+        previous_response_id = kwargs.get("previous_response_id")
         if hooks is None:
             hooks = RunHooks[Any]()
         if run_config is None:
@@ -514,17 +507,17 @@ async def _run_impl(
             if current_span:
                 current_span.finish(reset_current=True)
 
-    def _run_sync_impl(
+    def run_sync(  # type: ignore[override]
         self,
         starting_agent: Agent[TContext],
         input: str | list[TResponseInputItem],
-        *,
-        context: TContext | None = None,
-        max_turns: int = DEFAULT_MAX_TURNS,
-        hooks: RunHooks[TContext] | None = None,
-        run_config: RunConfig | None = None,
-        previous_response_id: str | None = None,
+        **kwargs: Unpack[AgentRunnerParams[TContext]],
     ) -> RunResult:
+        context = kwargs.get("context")
+        max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
+        hooks = kwargs.get("hooks")
+        run_config = kwargs.get("run_config")
+        previous_response_id = kwargs.get("previous_response_id")
         return asyncio.get_event_loop().run_until_complete(
             self.run(
                 starting_agent,
@@ -537,16 +530,17 @@ def _run_sync_impl(
             )
         )
 
-    def _run_streamed_impl(
+    def run_streamed(  # type: ignore[override]
         self,
         starting_agent: Agent[TContext],
         input: str | list[TResponseInputItem],
-        context: TContext | None = None,
-        max_turns: int = DEFAULT_MAX_TURNS,
-        hooks: RunHooks[TContext] | None = None,
-        run_config: RunConfig | None = None,
-        previous_response_id: str | None = None,
+        **kwargs: Unpack[AgentRunnerParams[TContext]],
     ) -> RunResultStreaming:
+        context = kwargs.get("context")
+        max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
+        hooks = kwargs.get("hooks")
+        run_config = kwargs.get("run_config")
+        previous_response_id = kwargs.get("previous_response_id")
         if hooks is None:
             hooks = RunHooks[Any]()
         if run_config is None:
@@ -1108,6 +1102,3 @@ async def _get_new_response(
         context_wrapper.usage.add(new_response.usage)
 
         return new_response
-
-
-DEFAULT_RUNNER = DefaultRunner()
```
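The `**kwargs: Unpack[AgentRunnerParams[TContext]]` signatures above use the PEP 692 pattern: a `TypedDict` describes the keyword arguments once, and `typing_extensions.Unpack` applies it to `**kwargs`, so a type checker can validate keys and value types without repeating the parameter list on every override. A standalone sketch of the same pattern, with illustrative names that are not from the SDK:

```python
from __future__ import annotations

from typing_extensions import TypedDict, Unpack


class RunParams(TypedDict, total=False):
    """Keyword arguments for run(); total=False makes every key optional here."""

    max_turns: int
    previous_response_id: str | None


def run(prompt: str, **kwargs: Unpack[RunParams]) -> str:
    # kwargs is typed: a checker flags unknown keys or wrongly typed values.
    max_turns = kwargs.get("max_turns", 10)
    return f"{prompt} (at most {max_turns} turns)"


print(run("hello", max_turns=3))   # OK
# run("hello", max_turn=3)         # type error: unexpected keyword argument
```

As the diff shows, the SDK's `AgentRunnerParams` is declared without `total=False`, and the defaults are recovered at runtime via `kwargs.get(...)` inside each `DefaultAgentRunner` method.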

tests/conftest.py

Lines changed: 0 additions & 6 deletions

```diff
@@ -5,7 +5,6 @@
 from agents.models import _openai_shared
 from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
 from agents.models.openai_responses import OpenAIResponsesModel
-from agents.run import set_default_runner
 from agents.tracing import set_trace_processors
 from agents.tracing.setup import get_trace_provider
 
@@ -34,11 +33,6 @@ def clear_openai_settings():
     _openai_shared._use_responses_by_default = True
 
 
-@pytest.fixture(autouse=True)
-def clear_default_runner():
-    set_default_runner(None)
-
-
 # This fixture will run after all tests end
 @pytest.fixture(autouse=True, scope="session")
 def shutdown_trace_provider():
```

tests/test_run.py

Lines changed: 0 additions & 26 deletions
This file was deleted.
