|
4 | 4 |
|
5 | 5 | import inspect
|
6 | 6 | import re
|
7 |
| -import types |
8 | 7 | from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
|
9 | 8 | from contextlib import (
|
10 | 9 | AbstractAsyncContextManager,
|
11 | 10 | asynccontextmanager,
|
12 | 11 | )
|
13 | 12 | from itertools import chain
|
14 |
| -from typing import Any, Generic, Literal, TypeVar, Union, get_args, get_origin |
| 13 | +from typing import Any, Generic, Literal |
15 | 14 |
|
16 | 15 | import anyio
|
17 | 16 | import pydantic_core
|
18 |
| -from pydantic import BaseModel, Field, ValidationError |
19 |
| -from pydantic.fields import FieldInfo |
| 17 | +from pydantic import BaseModel, Field |
20 | 18 | from pydantic.networks import AnyUrl
|
21 | 19 | from pydantic_settings import BaseSettings, SettingsConfigDict
|
22 | 20 | from starlette.applications import Starlette
|
|
36 | 34 | from mcp.server.auth.settings import (
|
37 | 35 | AuthSettings,
|
38 | 36 | )
|
| 37 | +from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation |
39 | 38 | from mcp.server.fastmcp.exceptions import ResourceError
|
40 | 39 | from mcp.server.fastmcp.prompts import Prompt, PromptManager
|
41 | 40 | from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
|
|
67 | 66 |
|
68 | 67 | logger = get_logger(__name__)
|
69 | 68 |
|
70 |
| -ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) |
71 |
| - |
72 |
| - |
73 |
| -class ElicitationResult(BaseModel, Generic[ElicitSchemaModelT]): |
74 |
| - """Result of an elicitation request.""" |
75 |
| - |
76 |
| - action: Literal["accept", "decline", "cancel"] |
77 |
| - """The user's action in response to the elicitation.""" |
78 |
| - |
79 |
| - data: ElicitSchemaModelT | None = None |
80 |
| - """The validated data if action is 'accept', None otherwise.""" |
81 |
| - |
82 |
| - validation_error: str | None = None |
83 |
| - """Validation error message if data failed to validate.""" |
84 |
| - |
85 | 69 |
|
86 | 70 | class Settings(BaseSettings, Generic[LifespanResultT]):
|
87 | 71 | """FastMCP server settings.
|
@@ -893,43 +877,6 @@ def _convert_to_content(
|
893 | 877 | return [TextContent(type="text", text=result)]
|
894 | 878 |
|
895 | 879 |
|
896 |
| -# Primitive types allowed in elicitation schemas |
897 |
| -_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) |
898 |
| - |
899 |
| - |
900 |
| -def _validate_elicitation_schema(schema: type[BaseModel]) -> None: |
901 |
| - """Validate that a Pydantic model only contains primitive field types.""" |
902 |
| - for field_name, field_info in schema.model_fields.items(): |
903 |
| - if not _is_primitive_field(field_info): |
904 |
| - raise TypeError( |
905 |
| - f"Elicitation schema field '{field_name}' must be a primitive type " |
906 |
| - f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " |
907 |
| - f"Complex types like lists, dicts, or nested models are not allowed." |
908 |
| - ) |
909 |
| - |
910 |
| - |
911 |
| -def _is_primitive_field(field_info: FieldInfo) -> bool: |
912 |
| - """Check if a field is a primitive type allowed in elicitation schemas.""" |
913 |
| - annotation = field_info.annotation |
914 |
| - |
915 |
| - # Handle None type |
916 |
| - if annotation is types.NoneType: |
917 |
| - return True |
918 |
| - |
919 |
| - # Handle basic primitive types |
920 |
| - if annotation in _ELICITATION_PRIMITIVE_TYPES: |
921 |
| - return True |
922 |
| - |
923 |
| - # Handle Union types |
924 |
| - origin = get_origin(annotation) |
925 |
| - if origin is Union or origin is types.UnionType: |
926 |
| - args = get_args(annotation) |
927 |
| - # All args must be primitive types or None |
928 |
| - return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) |
929 |
| - |
930 |
| - return False |
931 |
| - |
932 |
| - |
933 | 880 | class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
934 | 881 | """Context object providing access to MCP capabilities.
|
935 | 882 |
|
@@ -1053,27 +1000,10 @@ async def elicit(
|
1053 | 1000 | The result.data will only be populated if action is "accept" and validation succeeded.
|
1054 | 1001 | """
|
1055 | 1002 |
|
1056 |
| - # Validate that schema only contains primitive types and fail loudly if not |
1057 |
| - _validate_elicitation_schema(schema) |
1058 |
| - |
1059 |
| - json_schema = schema.model_json_schema() |
1060 |
| - |
1061 |
| - result = await self.request_context.session.elicit( |
1062 |
| - message=message, |
1063 |
| - requestedSchema=json_schema, |
1064 |
| - related_request_id=self.request_id, |
| 1003 | + return await elicit_with_validation( |
| 1004 | + session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id |
1065 | 1005 | )
|
1066 | 1006 |
|
1067 |
| - if result.action == "accept" and result.content: |
1068 |
| - # Validate and parse the content using the schema |
1069 |
| - try: |
1070 |
| - validated_data = schema.model_validate(result.content) |
1071 |
| - return ElicitationResult(action="accept", data=validated_data) |
1072 |
| - except ValidationError as e: |
1073 |
| - return ElicitationResult(action="accept", validation_error=str(e)) |
1074 |
| - else: |
1075 |
| - return ElicitationResult(action=result.action) |
1076 |
| - |
1077 | 1007 | async def log(
|
1078 | 1008 | self,
|
1079 | 1009 | level: Literal["debug", "info", "warning", "error"],
|
|
0 commit comments