Skip to content

Commit 51b5ee8

Browse files
committed
Update
1 parent 67bfd9a commit 51b5ee8

File tree

4 files changed

+132
-84
lines changed

4 files changed

+132
-84
lines changed

README.md

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,11 @@ Request additional information from users during tool execution:
380380

381381
```python
382382
from mcp.server.fastmcp import FastMCP, Context
383+
from mcp.server.elicitation import (
384+
AcceptedElicitation,
385+
DeclinedElicitation,
386+
CancelledElicitation,
387+
)
383388
from pydantic import BaseModel, Field
384389

385390
mcp = FastMCP("Booking System")
@@ -398,13 +403,15 @@ async def book_table(date: str, party_size: int, ctx: Context) -> str:
398403
message=f"Confirm booking for {party_size} on {date}?", schema=ConfirmBooking
399404
)
400405

401-
if result.action == "accept" and result.data:
402-
if result.data.confirm:
403-
return f"Booked! Notes: {result.data.notes or 'None'}"
404-
return "Booking cancelled"
405-
406-
# User declined or cancelled
407-
return f"Booking {result.action}"
406+
match result:
407+
case AcceptedElicitation(data=data):
408+
if data.confirm:
409+
return f"Booked! Notes: {data.notes or 'None'}"
410+
return "Booking cancelled"
411+
case DeclinedElicitation():
412+
return "Booking declined"
413+
case CancelledElicitation():
414+
return "Booking cancelled"
408415
```
409416

410417
The `elicit()` method returns an `ElicitationResult` with:

src/mcp/server/elicitation.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Elicitation utilities for MCP servers."""
2+
3+
from __future__ import annotations
4+
5+
import types
6+
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
7+
8+
from pydantic import BaseModel
9+
from pydantic.fields import FieldInfo
10+
11+
from mcp.server.session import ServerSession
12+
from mcp.types import RequestId
13+
14+
ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)
15+
16+
17+
class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]):
18+
"""Result when user accepts the elicitation."""
19+
20+
action: Literal["accept"] = "accept"
21+
data: ElicitSchemaModelT
22+
23+
24+
class DeclinedElicitation(BaseModel):
25+
"""Result when user declines the elicitation."""
26+
27+
action: Literal["decline"] = "decline"
28+
29+
30+
class CancelledElicitation(BaseModel):
31+
"""Result when user cancels the elicitation."""
32+
33+
action: Literal["cancel"] = "cancel"
34+
35+
36+
ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation
37+
38+
39+
# Primitive types allowed in elicitation schemas
40+
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)
41+
42+
43+
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
44+
"""Validate that a Pydantic model only contains primitive field types."""
45+
for field_name, field_info in schema.model_fields.items():
46+
if not _is_primitive_field(field_info):
47+
raise TypeError(
48+
f"Elicitation schema field '{field_name}' must be a primitive type "
49+
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
50+
f"Complex types like lists, dicts, or nested models are not allowed."
51+
)
52+
53+
54+
def _is_primitive_field(field_info: FieldInfo) -> bool:
55+
"""Check if a field is a primitive type allowed in elicitation schemas."""
56+
annotation = field_info.annotation
57+
58+
# Handle None type
59+
if annotation is types.NoneType:
60+
return True
61+
62+
# Handle basic primitive types
63+
if annotation in _ELICITATION_PRIMITIVE_TYPES:
64+
return True
65+
66+
# Handle Union types
67+
origin = get_origin(annotation)
68+
if origin is Union or origin is types.UnionType:
69+
args = get_args(annotation)
70+
# All args must be primitive types or None
71+
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
72+
73+
return False
74+
75+
76+
async def elicit_with_validation(
77+
session: ServerSession,
78+
message: str,
79+
schema: type[ElicitSchemaModelT],
80+
related_request_id: RequestId | None = None,
81+
) -> ElicitationResult[ElicitSchemaModelT]:
82+
"""Elicit information from the client/user with schema validation.
83+
84+
This method can be used to interactively ask for additional information from the
85+
client within a tool's execution. The client might display the message to the
86+
user and collect a response according to the provided schema. Or in case a
87+
client is an agent, it might decide how to handle the elicitation -- either by asking
88+
the user or automatically generating a response.
89+
"""
90+
# Validate that schema only contains primitive types and fail loudly if not
91+
_validate_elicitation_schema(schema)
92+
93+
json_schema = schema.model_json_schema()
94+
95+
result = await session.elicit(
96+
message=message,
97+
requestedSchema=json_schema,
98+
related_request_id=related_request_id,
99+
)
100+
101+
if result.action == "accept" and result.content:
102+
# Validate and parse the content using the schema
103+
validated_data = schema.model_validate(result.content)
104+
return AcceptedElicitation(data=validated_data)
105+
elif result.action == "decline":
106+
return DeclinedElicitation()
107+
elif result.action == "cancel":
108+
return CancelledElicitation()
109+
else:
110+
# This should never happen, but handle it just in case
111+
raise ValueError(f"Unexpected elicitation action: {result.action}")

src/mcp/server/fastmcp/server.py

Lines changed: 5 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,17 @@
44

55
import inspect
66
import re
7-
import types
87
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
98
from contextlib import (
109
AbstractAsyncContextManager,
1110
asynccontextmanager,
1211
)
1312
from itertools import chain
14-
from typing import Any, Generic, Literal, TypeVar, Union, get_args, get_origin
13+
from typing import Any, Generic, Literal
1514

1615
import anyio
1716
import pydantic_core
18-
from pydantic import BaseModel, Field, ValidationError
19-
from pydantic.fields import FieldInfo
17+
from pydantic import BaseModel, Field
2018
from pydantic.networks import AnyUrl
2119
from pydantic_settings import BaseSettings, SettingsConfigDict
2220
from starlette.applications import Starlette
@@ -36,6 +34,7 @@
3634
from mcp.server.auth.settings import (
3735
AuthSettings,
3836
)
37+
from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation
3938
from mcp.server.fastmcp.exceptions import ResourceError
4039
from mcp.server.fastmcp.prompts import Prompt, PromptManager
4140
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
@@ -67,21 +66,6 @@
6766

6867
logger = get_logger(__name__)
6968

70-
ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)
71-
72-
73-
class ElicitationResult(BaseModel, Generic[ElicitSchemaModelT]):
74-
"""Result of an elicitation request."""
75-
76-
action: Literal["accept", "decline", "cancel"]
77-
"""The user's action in response to the elicitation."""
78-
79-
data: ElicitSchemaModelT | None = None
80-
"""The validated data if action is 'accept', None otherwise."""
81-
82-
validation_error: str | None = None
83-
"""Validation error message if data failed to validate."""
84-
8569

8670
class Settings(BaseSettings, Generic[LifespanResultT]):
8771
"""FastMCP server settings.
@@ -893,43 +877,6 @@ def _convert_to_content(
893877
return [TextContent(type="text", text=result)]
894878

895879

896-
# Primitive types allowed in elicitation schemas
897-
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)
898-
899-
900-
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
901-
"""Validate that a Pydantic model only contains primitive field types."""
902-
for field_name, field_info in schema.model_fields.items():
903-
if not _is_primitive_field(field_info):
904-
raise TypeError(
905-
f"Elicitation schema field '{field_name}' must be a primitive type "
906-
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
907-
f"Complex types like lists, dicts, or nested models are not allowed."
908-
)
909-
910-
911-
def _is_primitive_field(field_info: FieldInfo) -> bool:
912-
"""Check if a field is a primitive type allowed in elicitation schemas."""
913-
annotation = field_info.annotation
914-
915-
# Handle None type
916-
if annotation is types.NoneType:
917-
return True
918-
919-
# Handle basic primitive types
920-
if annotation in _ELICITATION_PRIMITIVE_TYPES:
921-
return True
922-
923-
# Handle Union types
924-
origin = get_origin(annotation)
925-
if origin is Union or origin is types.UnionType:
926-
args = get_args(annotation)
927-
# All args must be primitive types or None
928-
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
929-
930-
return False
931-
932-
933880
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
934881
"""Context object providing access to MCP capabilities.
935882
@@ -1053,27 +1000,10 @@ async def elicit(
10531000
The result.data will only be populated if action is "accept" and validation succeeded.
10541001
"""
10551002

1056-
# Validate that schema only contains primitive types and fail loudly if not
1057-
_validate_elicitation_schema(schema)
1058-
1059-
json_schema = schema.model_json_schema()
1060-
1061-
result = await self.request_context.session.elicit(
1062-
message=message,
1063-
requestedSchema=json_schema,
1064-
related_request_id=self.request_id,
1003+
return await elicit_with_validation(
1004+
session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id
10651005
)
10661006

1067-
if result.action == "accept" and result.content:
1068-
# Validate and parse the content using the schema
1069-
try:
1070-
validated_data = schema.model_validate(result.content)
1071-
return ElicitationResult(action="accept", data=validated_data)
1072-
except ValidationError as e:
1073-
return ElicitationResult(action="accept", validation_error=str(e))
1074-
else:
1075-
return ElicitationResult(action=result.action)
1076-
10771007
async def log(
10781008
self,
10791009
level: Literal["debug", "info", "warning", "error"],

tests/server/fastmcp/test_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,8 @@ class AlternativeDateSchema(BaseModel):
313313
elif result.action in ("decline", "cancel"):
314314
return "❌ Booking cancelled"
315315
else:
316-
# Validation error
317-
return f"❌ Invalid input: {result.validation_error}"
316+
# Handle case where action is "accept" but data is None
317+
return "❌ No booking data received"
318318
else:
319319
# Available - book directly
320320
return f"✅ Booked table for {party_size} on {date} at {time}"

0 commit comments

Comments
 (0)