From 41c2ffbf8fba9f13b1550991cb2c01222aa04763 Mon Sep 17 00:00:00 2001
From: Zain Memon
Date: Mon, 30 Jun 2025 08:28:15 +1000
Subject: [PATCH] Add logprobs support for responses

---
 docs/ja/models/index.md                    | 16 ++++-
 docs/models/index.md                       | 14 ++++
 src/agents/model_settings.py               | 17 ++++--
 src/agents/models/openai_responses.py      | 16 ++++--
 tests/model_settings/test_serialization.py |  4 +-
 tests/test_logprobs.py                     | 51 +++++++++++++++++++++
 6 files changed, 106 insertions(+), 12 deletions(-)
 create mode 100644 tests/test_logprobs.py

diff --git a/docs/ja/models/index.md b/docs/ja/models/index.md
index 410c01676..daa03eeb3 100644
--- a/docs/ja/models/index.md
+++ b/docs/ja/models/index.md
@@ -103,7 +103,7 @@ OpenAI の Responses API を使用する場合、`user` や `service_tier` な
 ```python
 from agents import Agent, ModelSettings
 
-english_agent = Agent(
+english_agent = Agent(
     name="English agent",
     instructions="You only speak English",
     model="gpt-4o",
@@ -114,6 +114,20 @@ english_agent = Agent(
 )
 ```
 
+Responses API でトークンの対数確率を取得したい場合は、
+`ModelSettings` の `top_logprobs` を設定してください。
+
+```python
+from agents import Agent, ModelSettings
+
+agent = Agent(
+    name="English agent",
+    instructions="You only speak English",
+    model="gpt-4o",
+    model_settings=ModelSettings(top_logprobs=2),
+)
+```
+
 ## 他の LLM プロバイダー使用時の一般的な問題
 
 ### Tracing クライアントの 401 エラー
diff --git a/docs/models/index.md b/docs/models/index.md
index b3b2b7f0b..ed7b09f37 100644
--- a/docs/models/index.md
+++ b/docs/models/index.md
@@ -109,6 +109,20 @@ english_agent = Agent(
 )
 ```
 
+You can also request token log probabilities when using the Responses API by
+setting `top_logprobs` in `ModelSettings`.
+
+```python
+from agents import Agent, ModelSettings
+
+agent = Agent(
+    name="English agent",
+    instructions="You only speak English",
+    model="gpt-4o",
+    model_settings=ModelSettings(top_logprobs=2),
+)
+```
+
 ## Common issues with using other LLM providers
 
 ### Tracing client error 401
diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py
index 26af94ba3..d06aa3db9 100644
--- a/src/agents/model_settings.py
+++ b/src/agents/model_settings.py
@@ -17,9 +17,9 @@ class _OmitTypeAnnotation:
     @classmethod
     def __get_pydantic_core_schema__(
-        cls,
-        _source_type: Any,
-        _handler: GetCoreSchemaHandler,
+        cls,
+        _source_type: Any,
+        _handler: GetCoreSchemaHandler,
     ) -> core_schema.CoreSchema:
         def validate_from_none(value: None) -> _Omit:
             return _Omit()
 
@@ -39,13 +39,14 @@ def validate_from_none(value: None) -> _Omit:
                     from_none_schema,
                 ]
             ),
-            serialization=core_schema.plain_serializer_function_ser_schema(
-                lambda instance: None
-            ),
+            serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
         )
+
+
 Omit = Annotated[_Omit, _OmitTypeAnnotation]
 Headers: TypeAlias = Mapping[str, Union[str, Omit]]
 
+
 @dataclass
 class ModelSettings:
     """Settings to use when calling an LLM.
@@ -107,6 +108,10 @@ class ModelSettings:
     """Additional output data to include in the model response.
    [include parameter](https://platform.openai.com/docs/api-reference/responses/create#responses-create-include)"""
 
+    top_logprobs: int | None = None
+    """Number of top tokens to return logprobs for. Setting this will
+    automatically include ``"message.output_text.logprobs"`` in the response."""
+
     extra_query: Query | None = None
     """Additional query fields to provide with the request.
    Defaults to None if not provided."""
diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py
index a7ce62983..e43559c6b 100644
--- a/src/agents/models/openai_responses.py
+++ b/src/agents/models/openai_responses.py
@@ -3,7 +3,7 @@
 import json
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Literal, overload
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
 
 from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
 from openai.types import ChatModel
@@ -246,9 +246,12 @@ async def _fetch_response(
         converted_tools = Converter.convert_tools(tools, handoffs)
         response_format = Converter.get_response_format(output_schema)
 
-        include: list[ResponseIncludable] = converted_tools.includes
+        include_set: set[str] = set(converted_tools.includes)
         if model_settings.response_include is not None:
-            include = list({*include, *model_settings.response_include})
+            include_set.update(model_settings.response_include)
+        if model_settings.top_logprobs is not None:
+            include_set.add("message.output_text.logprobs")
+        include = cast(list[ResponseIncludable], list(include_set))
 
         if _debug.DONT_LOG_MODEL_DATA:
             logger.debug("Calling LLM")
@@ -263,6 +266,11 @@ async def _fetch_response(
             f"Previous response id: {previous_response_id}\n"
         )
 
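+        # Merge any caller-supplied extra_args, then layer top_logprobs on top.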
+        extra_args = dict(model_settings.extra_args or {})
+        if model_settings.top_logprobs is not None:
+            extra_args["top_logprobs"] = model_settings.top_logprobs
+
         return await self._client.responses.create(
             previous_response_id=self._non_null_or_not_given(previous_response_id),
             instructions=self._non_null_or_not_given(system_instructions),
@@ -285,7 +293,7 @@ async def _fetch_response(
             store=self._non_null_or_not_given(model_settings.store),
             reasoning=self._non_null_or_not_given(model_settings.reasoning),
             metadata=self._non_null_or_not_given(model_settings.metadata),
-            **(model_settings.extra_args or {}),
+            **extra_args,
         )
 
     def _get_client(self) -> AsyncOpenAI:
diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py
index 94d11def3..ba405553c 100644
--- a/tests/model_settings/test_serialization.py
+++ b/tests/model_settings/test_serialization.py
@@ -47,6 +47,7 @@ def test_all_fields_serialization() -> None:
         store=False,
         include_usage=False,
         response_include=["reasoning.encrypted_content"],
+        top_logprobs=1,
         extra_query={"foo": "bar"},
         extra_body={"foo": "bar"},
         extra_headers={"foo": "bar"},
@@ -135,8 +136,8 @@ def test_extra_args_resolve_both_none() -> None:
     assert resolved.temperature == 0.5
     assert resolved.top_p == 0.9
 
-def test_pydantic_serialization() -> None:
+def test_pydantic_serialization() -> None:
     """Tests whether ModelSettings can be serialized with Pydantic."""
 
     # First, lets create a ModelSettings instance
     model_settings = ModelSettings(
@@ -153,6 +154,7 @@ def test_pydantic_serialization() -> None:
         metadata={"foo": "bar"},
         store=False,
         include_usage=False,
+        top_logprobs=1,
         extra_query={"foo": "bar"},
         extra_body={"foo": "bar"},
         extra_headers={"foo": "bar"},
diff --git a/tests/test_logprobs.py b/tests/test_logprobs.py
new file mode 100644
index 000000000..aa5bb06f8
--- /dev/null
+++ b/tests/test_logprobs.py
@@ -0,0 +1,51 @@
+import pytest
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
+
+from agents import ModelSettings, ModelTracing, OpenAIResponsesModel
+
+
+class DummyResponses:
+    async def create(self, **kwargs):
+        self.kwargs = kwargs
+
+        class DummyResponse:
+            id = "dummy"
+            output = []
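+            # Zeroed usage stub with the fields get_response copies into Usage.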
+            usage = type(
+                "Usage",
+                (),
+                {
+                    "input_tokens": 0,
+                    "output_tokens": 0,
+                    "total_tokens": 0,
+                    "input_tokens_details": InputTokensDetails(cached_tokens=0),
+                    "output_tokens_details": OutputTokensDetails(reasoning_tokens=0),
+                },
+            )()
+
+        return DummyResponse()
+
+
+class DummyClient:
+    def __init__(self):
+        self.responses = DummyResponses()
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_top_logprobs_param_passed():
+    client = DummyClient()
+    model = OpenAIResponsesModel(model="gpt-4", openai_client=client)  # type: ignore
+    await model.get_response(
+        system_instructions=None,
+        input="hi",
+        model_settings=ModelSettings(top_logprobs=2),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+    assert client.responses.kwargs["top_logprobs"] == 2
+    assert "message.output_text.logprobs" in client.responses.kwargs["include"]
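
Not part of the patch: a minimal consumer-side sketch of the new setting. It assumes the installed `openai` package exposes a `logprobs` list on `ResponseOutputText` parts (the data the `"message.output_text.logprobs"` include populates); the `getattr` fallbacks keep the loop safe on output items that carry no text content.

```python
import asyncio

from agents import Agent, ModelSettings, Runner


async def main() -> None:
    agent = Agent(
        name="English agent",
        instructions="You only speak English",
        model="gpt-4o",
        # top_logprobs adds the logprobs include to the Responses API call.
        model_settings=ModelSettings(top_logprobs=2),
    )
    result = await Runner.run(agent, "Say hello.")

    # Message items in the raw response hold output_text parts; with the
    # include active, each part carries per-token logprob entries.
    for item in result.raw_responses[0].output:
        for part in getattr(item, "content", None) or []:
            for lp in getattr(part, "logprobs", None) or []:
                print(lp.token, lp.logprob)


asyncio.run(main())
```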