Skip to content

Commit 83251c3

Browse files
committed
move utility methods (_get_output_schema, _get_handoffs, _get_all_tools, _get_model) from DefaultRunner to Runner base class
1 parent b25beb6 commit 83251c3

File tree

1 file changed

+37
-35
lines changed

1 file changed

+37
-35
lines changed

src/agents/run.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,43 @@ def run_streamed(
316316
previous_response_id=previous_response_id,
317317
)
318318

319+
320+
@classmethod
321+
def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None:
322+
if agent.output_type is None or agent.output_type is str:
323+
return None
324+
elif isinstance(agent.output_type, AgentOutputSchemaBase):
325+
return agent.output_type
326+
327+
return AgentOutputSchema(agent.output_type)
328+
329+
@classmethod
330+
def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
331+
handoffs = []
332+
for handoff_item in agent.handoffs:
333+
if isinstance(handoff_item, Handoff):
334+
handoffs.append(handoff_item)
335+
elif isinstance(handoff_item, Agent):
336+
handoffs.append(handoff(handoff_item))
337+
return handoffs
338+
339+
@classmethod
340+
async def _get_all_tools(
341+
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
342+
) -> list[Tool]:
343+
return await agent.get_all_tools(context_wrapper)
344+
345+
@classmethod
346+
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
347+
if isinstance(run_config.model, Model):
348+
return run_config.model
349+
elif isinstance(run_config.model, str):
350+
return run_config.model_provider.get_model(run_config.model)
351+
elif isinstance(agent.model, Model):
352+
return agent.model
353+
354+
return run_config.model_provider.get_model(agent.model)
355+
319356
class DefaultRunner(Runner):
320357
async def _run_impl(
321358
self,
@@ -1072,40 +1109,5 @@ async def _get_new_response(
10721109

10731110
return new_response
10741111

1075-
@classmethod
1076-
def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None:
1077-
if agent.output_type is None or agent.output_type is str:
1078-
return None
1079-
elif isinstance(agent.output_type, AgentOutputSchemaBase):
1080-
return agent.output_type
1081-
1082-
return AgentOutputSchema(agent.output_type)
1083-
1084-
@classmethod
1085-
def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
1086-
handoffs = []
1087-
for handoff_item in agent.handoffs:
1088-
if isinstance(handoff_item, Handoff):
1089-
handoffs.append(handoff_item)
1090-
elif isinstance(handoff_item, Agent):
1091-
handoffs.append(handoff(handoff_item))
1092-
return handoffs
1093-
1094-
@classmethod
1095-
async def _get_all_tools(
1096-
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
1097-
) -> list[Tool]:
1098-
return await agent.get_all_tools(context_wrapper)
1099-
1100-
@classmethod
1101-
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
1102-
if isinstance(run_config.model, Model):
1103-
return run_config.model
1104-
elif isinstance(run_config.model, str):
1105-
return run_config.model_provider.get_model(run_config.model)
1106-
elif isinstance(agent.model, Model):
1107-
return agent.model
1108-
1109-
return run_config.model_provider.get_model(agent.model)
11101112

11111113
DEFAULT_RUNNER = DefaultRunner()

0 commit comments

Comments
 (0)