4
4
import asyncio
5
5
import copy
6
6
from dataclasses import dataclass , field
7
- from typing import Any , cast
7
+ from typing import Any , Generic , cast
8
8
9
9
from openai .types .responses import ResponseCompletedEvent
10
+ from typing_extensions import TypedDict , Unpack
10
11
11
12
from ._run_impl import (
12
13
AgentToolUseTracker ,
47
48
from .util import _coro , _error_tracing
48
49
49
50
DEFAULT_MAX_TURNS = 10
50
- DEFAULT_RUNNER : Runner = None # type: ignore
51
- # assigned at the end of the module initialization
52
-
53
-
54
- def set_default_runner (runner : Runner | None ) -> None :
55
- """
56
- Set the default runner to use for the agent run.
57
- """
58
- global DEFAULT_RUNNER
59
- DEFAULT_RUNNER = runner or DefaultRunner ()
60
-
61
- def get_default_runner () -> Runner | None :
62
- """
63
- Get the default runner to use for the agent run.
64
- """
65
- global DEFAULT_RUNNER
66
- return DEFAULT_RUNNER
67
51
68
52
69
53
@dataclass
@@ -125,48 +109,57 @@ class RunConfig:
125
109
"""
126
110
127
111
128
- class Runner (abc .ABC ):
112
+ class AgentRunnerParams (TypedDict , Generic [TContext ]):
113
+ """Arguments for ``AgentRunner`` methods."""
114
+
115
+ context : TContext | None
116
+ """The context for the run."""
117
+
118
+ max_turns : int
119
+ """The maximum number of turns to run for."""
120
+
121
+ hooks : RunHooks [TContext ] | None
122
+ """Lifecycle hooks for the run."""
123
+
124
+ run_config : RunConfig | None
125
+ """Run configuration."""
126
+
127
+ previous_response_id : str | None
128
+ """The ID of the previous response, if any."""
129
+
130
+
131
+ class AgentRunner (abc .ABC ):
129
132
@abc .abstractmethod
130
- async def _run_impl (
133
+ async def run (
131
134
self ,
132
135
starting_agent : Agent [TContext ],
133
136
input : str | list [TResponseInputItem ],
134
- * ,
135
- context : TContext | None = None ,
136
- max_turns : int = DEFAULT_MAX_TURNS ,
137
- hooks : RunHooks [TContext ] | None = None ,
138
- run_config : RunConfig | None = None ,
139
- previous_response_id : str | None = None ,
137
+ ** kwargs : Unpack [AgentRunnerParams [TContext ]],
140
138
) -> RunResult :
141
139
pass
142
140
143
141
@abc .abstractmethod
144
- def _run_sync_impl (
142
+ def run_sync (
145
143
self ,
146
144
starting_agent : Agent [TContext ],
147
145
input : str | list [TResponseInputItem ],
148
- * ,
149
- context : TContext | None = None ,
150
- max_turns : int = DEFAULT_MAX_TURNS ,
151
- hooks : RunHooks [TContext ] | None = None ,
152
- run_config : RunConfig | None = None ,
153
- previous_response_id : str | None = None ,
146
+ ** kwargs : Unpack [AgentRunnerParams [TContext ]],
154
147
) -> RunResult :
155
148
pass
156
149
157
150
@abc .abstractmethod
158
- def _run_streamed_impl (
151
+ def run_streamed (
159
152
self ,
160
153
starting_agent : Agent [TContext ],
161
154
input : str | list [TResponseInputItem ],
162
- context : TContext | None = None ,
163
- max_turns : int = DEFAULT_MAX_TURNS ,
164
- hooks : RunHooks [TContext ] | None = None ,
165
- run_config : RunConfig | None = None ,
166
- previous_response_id : str | None = None ,
155
+ ** kwargs : Unpack [AgentRunnerParams [TContext ]],
167
156
) -> RunResultStreaming :
168
157
pass
169
158
159
+
160
+ class Runner :
161
+ pass
162
+
170
163
@classmethod
171
164
async def run (
172
165
cls ,
@@ -205,8 +198,8 @@ async def run(
205
198
A run result containing all the inputs, guardrail results and the output of the last
206
199
agent. Agents may perform handoffs, so we don't know the specific type of the output.
207
200
"""
208
- runner = DEFAULT_RUNNER
209
- return await runner ._run_impl (
201
+ runner = DefaultAgentRunner ()
202
+ return await runner .run (
210
203
starting_agent ,
211
204
input ,
212
205
context = context ,
@@ -257,8 +250,8 @@ def run_sync(
257
250
A run result containing all the inputs, guardrail results and the output of the last
258
251
agent. Agents may perform handoffs, so we don't know the specific type of the output.
259
252
"""
260
- runner = DEFAULT_RUNNER
261
- return runner ._run_sync_impl (
253
+ runner = DefaultAgentRunner ()
254
+ return runner .run_sync (
262
255
starting_agent ,
263
256
input ,
264
257
context = context ,
@@ -305,8 +298,8 @@ def run_streamed(
305
298
Returns:
306
299
A result object that contains data about the run, as well as a method to stream events.
307
300
"""
308
- runner = DEFAULT_RUNNER
309
- return runner ._run_streamed_impl (
301
+ runner = DefaultAgentRunner ()
302
+ return runner .run_streamed (
310
303
starting_agent ,
311
304
input ,
312
305
context = context ,
@@ -316,7 +309,6 @@ def run_streamed(
316
309
previous_response_id = previous_response_id ,
317
310
)
318
311
319
-
320
312
@classmethod
321
313
def _get_output_schema (cls , agent : Agent [Any ]) -> AgentOutputSchemaBase | None :
322
314
if agent .output_type is None or agent .output_type is str :
@@ -353,18 +345,19 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
353
345
354
346
return run_config .model_provider .get_model (agent .model )
355
347
356
- class DefaultRunner (Runner ):
357
- async def _run_impl (
348
+
349
+ class DefaultAgentRunner (AgentRunner , Runner ):
350
+ async def run ( # type: ignore[override]
358
351
self ,
359
352
starting_agent : Agent [TContext ],
360
353
input : str | list [TResponseInputItem ],
361
- * ,
362
- context : TContext | None = None ,
363
- max_turns : int = DEFAULT_MAX_TURNS ,
364
- hooks : RunHooks [TContext ] | None = None ,
365
- run_config : RunConfig | None = None ,
366
- previous_response_id : str | None = None ,
354
+ ** kwargs : Unpack [AgentRunnerParams [TContext ]],
367
355
) -> RunResult :
356
+ context = kwargs .get ("context" )
357
+ max_turns = kwargs .get ("max_turns" , DEFAULT_MAX_TURNS )
358
+ hooks = kwargs .get ("hooks" )
359
+ run_config = kwargs .get ("run_config" )
360
+ previous_response_id = kwargs .get ("previous_response_id" )
368
361
if hooks is None :
369
362
hooks = RunHooks [Any ]()
370
363
if run_config is None :
@@ -514,17 +507,17 @@ async def _run_impl(
514
507
if current_span :
515
508
current_span .finish (reset_current = True )
516
509
517
- def _run_sync_impl (
510
+ def run_sync ( # type: ignore[override]
518
511
self ,
519
512
starting_agent : Agent [TContext ],
520
513
input : str | list [TResponseInputItem ],
521
- * ,
522
- context : TContext | None = None ,
523
- max_turns : int = DEFAULT_MAX_TURNS ,
524
- hooks : RunHooks [TContext ] | None = None ,
525
- run_config : RunConfig | None = None ,
526
- previous_response_id : str | None = None ,
514
+ ** kwargs : Unpack [AgentRunnerParams [TContext ]],
527
515
) -> RunResult :
516
+ context = kwargs .get ("context" )
517
+ max_turns = kwargs .get ("max_turns" , DEFAULT_MAX_TURNS )
518
+ hooks = kwargs .get ("hooks" )
519
+ run_config = kwargs .get ("run_config" )
520
+ previous_response_id = kwargs .get ("previous_response_id" )
528
521
return asyncio .get_event_loop ().run_until_complete (
529
522
self .run (
530
523
starting_agent ,
@@ -537,16 +530,17 @@ def _run_sync_impl(
537
530
)
538
531
)
539
532
540
- def _run_streamed_impl (
533
+ def run_streamed ( # type: ignore[override]
541
534
self ,
542
535
starting_agent : Agent [TContext ],
543
536
input : str | list [TResponseInputItem ],
544
- context : TContext | None = None ,
545
- max_turns : int = DEFAULT_MAX_TURNS ,
546
- hooks : RunHooks [TContext ] | None = None ,
547
- run_config : RunConfig | None = None ,
548
- previous_response_id : str | None = None ,
537
+ ** kwargs : Unpack [AgentRunnerParams [TContext ]],
549
538
) -> RunResultStreaming :
539
+ context = kwargs .get ("context" )
540
+ max_turns = kwargs .get ("max_turns" , DEFAULT_MAX_TURNS )
541
+ hooks = kwargs .get ("hooks" )
542
+ run_config = kwargs .get ("run_config" )
543
+ previous_response_id = kwargs .get ("previous_response_id" )
550
544
if hooks is None :
551
545
hooks = RunHooks [Any ]()
552
546
if run_config is None :
@@ -1108,6 +1102,3 @@ async def _get_new_response(
1108
1102
context_wrapper .usage .add (new_response .usage )
1109
1103
1110
1104
return new_response
1111
-
1112
-
1113
- DEFAULT_RUNNER = DefaultRunner ()
0 commit comments