Skip to content

Commit 7b42065

Browse files
authored
Fix uncaught exception in MCP server (#967)
1 parent 1eb1bba commit 7b42065

File tree

2 files changed

+137
-83
lines changed

2 files changed

+137
-83
lines changed

src/mcp/shared/session.py

Lines changed: 100 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -333,90 +333,107 @@ async def _receive_loop(self) -> None:
333333
self._read_stream,
334334
self._write_stream,
335335
):
336-
async for message in self._read_stream:
337-
if isinstance(message, Exception):
338-
await self._handle_incoming(message)
339-
elif isinstance(message.message.root, JSONRPCRequest):
340-
try:
341-
validated_request = self._receive_request_type.model_validate(
342-
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
343-
)
344-
responder = RequestResponder(
345-
request_id=message.message.root.id,
346-
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
347-
request=validated_request,
348-
session=self,
349-
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
350-
message_metadata=message.metadata,
351-
)
352-
self._in_flight[responder.request_id] = responder
353-
await self._received_request(responder)
354-
355-
if not responder._completed: # type: ignore[reportPrivateUsage]
356-
await self._handle_incoming(responder)
357-
except Exception as e:
358-
# For request validation errors, send a proper JSON-RPC error
359-
# response instead of crashing the server
360-
logging.warning(f"Failed to validate request: {e}")
361-
logging.debug(f"Message that failed validation: {message.message.root}")
362-
error_response = JSONRPCError(
363-
jsonrpc="2.0",
364-
id=message.message.root.id,
365-
error=ErrorData(
366-
code=INVALID_PARAMS,
367-
message="Invalid request parameters",
368-
data="",
369-
),
370-
)
371-
session_message = SessionMessage(message=JSONRPCMessage(error_response))
372-
await self._write_stream.send(session_message)
373-
374-
elif isinstance(message.message.root, JSONRPCNotification):
375-
try:
376-
notification = self._receive_notification_type.model_validate(
377-
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
378-
)
379-
# Handle cancellation notifications
380-
if isinstance(notification.root, CancelledNotification):
381-
cancelled_id = notification.root.params.requestId
382-
if cancelled_id in self._in_flight:
383-
await self._in_flight[cancelled_id].cancel()
336+
try:
337+
async for message in self._read_stream:
338+
if isinstance(message, Exception):
339+
await self._handle_incoming(message)
340+
elif isinstance(message.message.root, JSONRPCRequest):
341+
try:
342+
validated_request = self._receive_request_type.model_validate(
343+
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
344+
)
345+
responder = RequestResponder(
346+
request_id=message.message.root.id,
347+
request_meta=validated_request.root.params.meta
348+
if validated_request.root.params
349+
else None,
350+
request=validated_request,
351+
session=self,
352+
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
353+
message_metadata=message.metadata,
354+
)
355+
self._in_flight[responder.request_id] = responder
356+
await self._received_request(responder)
357+
358+
if not responder._completed: # type: ignore[reportPrivateUsage]
359+
await self._handle_incoming(responder)
360+
except Exception as e:
361+
# For request validation errors, send a proper JSON-RPC error
362+
# response instead of crashing the server
363+
logging.warning(f"Failed to validate request: {e}")
364+
logging.debug(f"Message that failed validation: {message.message.root}")
365+
error_response = JSONRPCError(
366+
jsonrpc="2.0",
367+
id=message.message.root.id,
368+
error=ErrorData(
369+
code=INVALID_PARAMS,
370+
message="Invalid request parameters",
371+
data="",
372+
),
373+
)
374+
session_message = SessionMessage(message=JSONRPCMessage(error_response))
375+
await self._write_stream.send(session_message)
376+
377+
elif isinstance(message.message.root, JSONRPCNotification):
378+
try:
379+
notification = self._receive_notification_type.model_validate(
380+
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
381+
)
382+
# Handle cancellation notifications
383+
if isinstance(notification.root, CancelledNotification):
384+
cancelled_id = notification.root.params.requestId
385+
if cancelled_id in self._in_flight:
386+
await self._in_flight[cancelled_id].cancel()
387+
else:
388+
# Handle progress notifications callback
389+
if isinstance(notification.root, ProgressNotification):
390+
progress_token = notification.root.params.progressToken
391+
# If there is a progress callback for this token,
392+
# call it with the progress information
393+
if progress_token in self._progress_callbacks:
394+
callback = self._progress_callbacks[progress_token]
395+
await callback(
396+
notification.root.params.progress,
397+
notification.root.params.total,
398+
notification.root.params.message,
399+
)
400+
await self._received_notification(notification)
401+
await self._handle_incoming(notification)
402+
except Exception as e:
403+
# For other validation errors, log and continue
404+
logging.warning(
405+
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
406+
)
407+
else: # Response or error
408+
stream = self._response_streams.pop(message.message.root.id, None)
409+
if stream:
410+
await stream.send(message.message.root)
384411
else:
385-
# Handle progress notifications callback
386-
if isinstance(notification.root, ProgressNotification):
387-
progress_token = notification.root.params.progressToken
388-
# If there is a progress callback for this token,
389-
# call it with the progress information
390-
if progress_token in self._progress_callbacks:
391-
callback = self._progress_callbacks[progress_token]
392-
await callback(
393-
notification.root.params.progress,
394-
notification.root.params.total,
395-
notification.root.params.message,
396-
)
397-
await self._received_notification(notification)
398-
await self._handle_incoming(notification)
399-
except Exception as e:
400-
# For other validation errors, log and continue
401-
logging.warning(
402-
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
403-
)
404-
else: # Response or error
405-
stream = self._response_streams.pop(message.message.root.id, None)
406-
if stream:
407-
await stream.send(message.message.root)
408-
else:
409-
await self._handle_incoming(
410-
RuntimeError("Received response with an unknown " f"request ID: {message}")
411-
)
412-
413-
# after the read stream is closed, we need to send errors
414-
# to any pending requests
415-
for id, stream in self._response_streams.items():
416-
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
417-
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
418-
await stream.aclose()
419-
self._response_streams.clear()
412+
await self._handle_incoming(
413+
RuntimeError("Received response with an unknown " f"request ID: {message}")
414+
)
415+
416+
except anyio.ClosedResourceError:
417+
# This is expected when the client disconnects abruptly.
418+
# Without this handler, the exception would propagate up and
419+
# crash the server's task group.
420+
logging.debug("Read stream closed by client")
421+
except Exception as e:
422+
# Other exceptions are not expected and should be logged. We purposefully
423+
# catch all exceptions here to avoid crashing the server.
424+
logging.exception(f"Unhandled exception in receive loop: {e}")
425+
finally:
426+
# after the read stream is closed, we need to send errors
427+
# to any pending requests
428+
for id, stream in self._response_streams.items():
429+
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
430+
try:
431+
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
432+
await stream.aclose()
433+
except Exception:
434+
# Stream might already be closed
435+
pass
436+
self._response_streams.clear()
420437

421438
async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
422439
"""

tests/shared/test_streamable_http.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,3 +1521,40 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_
15211521
)
15221522
assert response.status_code == 200 # Should succeed for backwards compatibility
15231523
assert response.headers.get("Content-Type") == "text/event-stream"
1524+
1525+
1526+
@pytest.mark.anyio
1527+
async def test_client_crash_handled(basic_server, basic_server_url):
1528+
"""Test that cases where the client crashes are handled gracefully."""
1529+
1530+
# Simulate bad client that crashes after init
1531+
async def bad_client():
1532+
"""Client that triggers ClosedResourceError"""
1533+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1534+
read_stream,
1535+
write_stream,
1536+
_,
1537+
):
1538+
async with ClientSession(read_stream, write_stream) as session:
1539+
await session.initialize()
1540+
raise Exception("client crash")
1541+
1542+
# Run bad client a few times to trigger the crash
1543+
for _ in range(3):
1544+
try:
1545+
await bad_client()
1546+
except Exception:
1547+
pass
1548+
await anyio.sleep(0.1)
1549+
1550+
# Try a good client, it should still be able to connect and list tools
1551+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1552+
read_stream,
1553+
write_stream,
1554+
_,
1555+
):
1556+
async with ClientSession(read_stream, write_stream) as session:
1557+
result = await session.initialize()
1558+
assert isinstance(result, InitializeResult)
1559+
tools = await session.list_tools()
1560+
assert tools.tools

0 commit comments

Comments
 (0)