This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit a988738

Fix Anthropic FIM with muxing. (#1304)
In the context of muxing, the code determining which mapper to use when routing requests to Anthropic relied only on `is_fim_request` and did not take into account whether the endpoint receiving the request was the legacy one (i.e. `/completions`) or the current one (i.e. `/chat/completions`). This caused the wrong mapper to be used, which led to empty text content in the FIM request. A better way to determine which mapper to use is to look at the effective request type, since that is the real source of truth for the translation.
1 parent 17fab51 commit a988738
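
For orientation, here is a minimal sketch of the mapper selection this commit moves to. The names `parsed`, `openai.LegacyCompletionRequest`, and the legacy mapper functions are taken from the router diff below; the surrounding routing context is assumed and the non-legacy branch is omitted.

# Sketch only: select mappers by the effective request type, not by
# is_fim_request alone (see src/codegate/muxing/router.py below).
match model_route.endpoint.provider_type:
    case ProviderType.anthropic:
        if isinstance(parsed, openai.LegacyCompletionRequest):
            # Legacy /completions payload: use the legacy mappers.
            completion_function = anthropic.acompletion
            from_openai = anthropic_from_legacy_openai
            to_openai = anthropic_to_legacy_openai
        # Other request types fall through to the /chat/completions
        # mappers (not shown in the hunk below).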

File tree

9 files changed: +120 additions, -56 deletions


src/codegate/muxing/router.py

Lines changed: 9 additions & 1 deletion
@@ -138,7 +138,15 @@ async def route_to_dest_provider(
         # TODO this should be improved
         match model_route.endpoint.provider_type:
             case ProviderType.anthropic:
-                if is_fim_request:
+                # Note: despite `is_fim_request` being true, our
+                # integration tests query the `/chat/completions`
+                # endpoint, which causes the
+                # `anthropic_from_legacy_openai` to incorrectly
+                # populate the struct.
+                #
+                # Checking for the actual type is a much more
+                # reliable way of determining the right mapper.
+                if isinstance(parsed, openai.LegacyCompletionRequest):
                     completion_function = anthropic.acompletion
                     from_openai = anthropic_from_legacy_openai
                     to_openai = anthropic_to_legacy_openai

src/codegate/providers/anthropic/provider.py

Lines changed: 28 additions & 9 deletions
@@ -11,7 +11,15 @@
 from codegate.providers.anthropic.completion_handler import AnthropicCompletion
 from codegate.providers.base import BaseProvider, ModelFetchError
 from codegate.providers.fim_analyzer import FIMAnalyzer
-from codegate.types.anthropic import ChatCompletionRequest, stream_generator
+from codegate.types.anthropic import (
+    ChatCompletionRequest,
+    single_message,
+    single_response,
+    stream_generator,
+)
+from codegate.types.generators import (
+    completion_handler_replacement,
+)
 
 logger = structlog.get_logger("codegate")
 
@@ -118,18 +126,29 @@ async def create_message(
         body = await request.body()
 
         if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
-            print(f"{create_message.__name__}: {body}")
+            print(f"{body.decode('utf-8')}")
 
         req = ChatCompletionRequest.model_validate_json(body)
         is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, req)
 
-        return await self.process_request(
-            req,
-            x_api_key,
-            self.base_url,
-            is_fim_request,
-            request.state.detected_client,
-        )
+        if req.stream:
+            return await self.process_request(
+                req,
+                x_api_key,
+                self.base_url,
+                is_fim_request,
+                request.state.detected_client,
+            )
+        else:
+            return await self.process_request(
+                req,
+                x_api_key,
+                self.base_url,
+                is_fim_request,
+                request.state.detected_client,
+                completion_handler=completion_handler_replacement(single_message),
+                stream_generator=single_response,
+            )
 
 
     async def dumper(stream):
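
A rough sketch of how the non-streaming branch above composes; `req`, `x_api_key`, and `base_url` are placeholders, and the exact call sequence inside `process_request` is assumed rather than shown in this diff.

# Illustrative only: how the replacement handler and single_response fit together.
handler = completion_handler_replacement(single_message)

async def example(req, x_api_key, base_url):
    # handler matches the (request, base_url, api_key, ...) signature expected
    # by the provider wiring and forwards to single_message(request, api_key,
    # base_url), which yields exactly one parsed Message (or MessageError).
    message_stream = await handler(req, base_url, x_api_key)
    # single_response serializes that single object as the response body.
    async for body in single_response(message_stream):
        print(body)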

src/codegate/providers/ollama/completion_handler.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ async def _ollama_dispatcher(  # noqa: C901
         stream = openai_stream_generator(prepend(first, stream))
 
     if isinstance(first, OpenAIChatCompletion):
-        stream = openai_single_response_generator(first, stream)
+        stream = openai_single_response_generator(first)
 
     async for item in stream:
         yield item

src/codegate/types/anthropic/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,8 @@
 from ._generators import (
     acompletion,
     message_wrapper,
+    single_message,
+    single_response,
     stream_generator,
 )
 from ._request_models import (
@@ -49,6 +51,8 @@
 __all__ = [
     "acompletion",
     "message_wrapper",
+    "single_message",
+    "single_response",
     "stream_generator",
     "AssistantMessage",
     "CacheControl",

src/codegate/types/anthropic/_generators.py

Lines changed: 58 additions & 3 deletions
@@ -12,6 +12,7 @@
     ContentBlockDelta,
     ContentBlockStart,
     ContentBlockStop,
+    Message,
     MessageDelta,
     MessageError,
     MessagePing,
@@ -27,7 +28,7 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
     try:
         async for chunk in stream:
             try:
-                body = chunk.json(exclude_defaults=True, exclude_unset=True)
+                body = chunk.json(exclude_unset=True)
             except Exception as e:
                 logger.error("failed serializing payload", exc_info=e)
                 err = MessageError(
@@ -37,7 +38,7 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
                         message=str(e),
                     ),
                 )
-                body = err.json(exclude_defaults=True, exclude_unset=True)
+                body = err.json(exclude_unset=True)
                 yield f"event: error\ndata: {body}\n\n"
 
             data = f"event: {chunk.type}\ndata: {body}\n\n"
@@ -55,10 +56,60 @@ async def stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
                 message=str(e),
             ),
         )
-        body = err.json(exclude_defaults=True, exclude_unset=True)
+        body = err.json(exclude_unset=True)
         yield f"event: error\ndata: {body}\n\n"
 
 
+async def single_response(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
+    """Wraps a single response object in an AsyncIterator. This is
+    meant to be used for non-streaming responses.
+
+    """
+    resp = await anext(stream)
+    yield resp.model_dump_json(exclude_unset=True)
+
+
+async def single_message(request, api_key, base_url, stream=None, is_fim_request=None):
+    headers = {
+        "anthropic-version": "2023-06-01",
+        "x-api-key": api_key,
+        "accept": "application/json",
+        "content-type": "application/json",
+    }
+    payload = request.model_dump_json(exclude_unset=True)
+
+    if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+        print(payload)
+
+    client = httpx.AsyncClient()
+    async with client.stream(
+        "POST",
+        f"{base_url}/v1/messages",
+        headers=headers,
+        content=payload,
+        timeout=60,  # TODO this should not be hardcoded
+    ) as resp:
+        match resp.status_code:
+            case 200:
+                text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
+                yield Message.model_validate_json(text)
+            case 400 | 401 | 403 | 404 | 413 | 429:
+                text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
+                yield MessageError.model_validate_json(text)
+            case 500 | 529:
+                text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
+                yield MessageError.model_validate_json(text)
+            case _:
+                logger.error(f"unexpected status code {resp.status_code}", provider="anthropic")
+                raise ValueError(f"unexpected status code {resp.status_code}", provider="anthropic")
+
+
 async def acompletion(request, api_key, base_url):
     headers = {
         "anthropic-version": "2023-06-01",
@@ -86,9 +137,13 @@ async def acompletion(request, api_key, base_url):
                     yield event
             case 400 | 401 | 403 | 404 | 413 | 429:
                 text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
                 yield MessageError.model_validate_json(text)
             case 500 | 529:
                 text = await resp.aread()
+                if os.getenv("CODEGATE_DEBUG_ANTHROPIC") is not None:
+                    print(text.decode("utf-8"))
                 yield MessageError.model_validate_json(text)
             case _:
                 logger.error(f"unexpected status code {resp.status_code}", provider="anthropic")

src/codegate/types/anthropic/_request_models.py

Lines changed: 1 addition & 0 deletions
@@ -155,6 +155,7 @@ class ToolDef(pydantic.BaseModel):
     Literal["auto"],
     Literal["any"],
     Literal["tool"],
+    Literal["none"],
 ]
 
 
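For context, a hedged example of a request body that exercises the new literal. The payload below is illustrative only; it assumes this `Literal` union types the `tool_choice` discriminator of the Anthropic Messages API.

# Illustrative request body: tool use explicitly disabled.
# Literal["none"] lets the {"type": "none"} value validate.
request_body = {
    "model": "claude-3-5-sonnet-latest",  # placeholder model id
    "max_tokens": 256,
    "messages": [{"role": "user", "content": "Hello"}],
    "tool_choice": {"type": "none"},
}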

src/codegate/types/generators.py

Lines changed: 18 additions & 28 deletions
@@ -1,37 +1,27 @@
-import os
 from typing import (
-    Any,
-    AsyncIterator,
+    Callable,
 )
 
-import pydantic
 import structlog
 
 logger = structlog.get_logger("codegate")
 
 
-# Since different providers typically use one of these formats for streaming
-# responses, we have a single stream generator for each format that is then plugged
-# into the adapter.
+def completion_handler_replacement(
+    completion_handler: Callable,
+):
+    async def _inner(
+        request,
+        base_url,
+        api_key,
+        stream=None,
+        is_fim_request=None,
+    ):
+        # Execute e.g. acompletion from Anthropic types
+        return completion_handler(
+            request,
+            api_key,
+            base_url,
+        )
 
-
-async def sse_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
-    """OpenAI-style SSE format"""
-    try:
-        async for chunk in stream:
-            if isinstance(chunk, pydantic.BaseModel):
-                # alternatively we might want to just dump the whole object
-                # this might even allow us to tighten the typing of the stream
-                chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
-            try:
-                if os.getenv("CODEGATE_DEBUG_OPENAI") is not None:
-                    print(chunk)
-                yield f"data: {chunk}\n\n"
-            except Exception as e:
-                logger.error("failed generating output payloads", exc_info=e)
-                yield f"data: {str(e)}\n\n"
-    except Exception as e:
-        logger.error("failed generating output payloads", exc_info=e)
-        yield f"data: {str(e)}\n\n"
-    finally:
-        yield "data: [DONE]\n\n"
+    return _inner
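
A brief, assumed usage example of the new helper. Note that `_inner` accepts `(request, base_url, api_key, ...)` but deliberately forwards `(request, api_key, base_url)`, matching the argument order of handlers such as `acompletion` and `single_message`.

# Sketch only: adapt single_message to the provider's completion-handler shape.
from codegate.types.anthropic import single_message
from codegate.types.generators import completion_handler_replacement

handler = completion_handler_replacement(single_message)
# Later, the provider calls something like:
#   stream = await handler(request, base_url, api_key, stream=None, is_fim_request=False)
# which returns the async generator produced by
#   single_message(request, api_key, base_url)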

src/codegate/types/ollama/_generators.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ async def stream_generator(
     try:
         async for chunk in stream:
             try:
-                body = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
+                body = chunk.model_dump_json(exclude_unset=True)
                 data = f"{body}\n"
 
                 if os.getenv("CODEGATE_DEBUG_OLLAMA") is not None:

src/codegate/types/openai/_generators.py

Lines changed: 0 additions & 13 deletions
@@ -50,26 +50,13 @@ async def stream_generator(stream: AsyncIterator[StreamingChatCompletion]) -> As
 
 async def single_response_generator(
     first: ChatCompletion,
-    stream: AsyncIterator[ChatCompletion],
 ) -> AsyncIterator[ChatCompletion]:
     """Wraps a single response object in an AsyncIterator. This is
     meant to be used for non-streaming responses.
 
     """
     yield first.model_dump_json(exclude_none=True, exclude_unset=True)
 
-    # Note: this async for loop is necessary to force Python to return
-    # an AsyncIterator. This is necessary because of the wiring at the
-    # Provider level expecting an AsyncIterator rather than a single
-    # response payload.
-    #
-    # Refactoring this means adding a code path specific for when we
-    # expect single response payloads rather than an SSE stream.
-    async for item in stream:
-        if item:
-            logger.error("no further items were expected", item=item)
-        yield item.model_dump_json(exclude_none=True, exclude_unset=True)
-
 
 async def completions_streaming(request, api_key, base_url):
     if base_url is None:
