From e10ab3d03111fe042b91f9088a2b754912d4124c Mon Sep 17 00:00:00 2001 From: ihrpr Date: Mon, 23 Jun 2025 21:41:05 +0100 Subject: [PATCH 1/3] remove github from auth examples --- examples/servers/simple-auth/README.md | 18 +- .../mcp_simple_auth/auth_server.py | 127 +++----- .../mcp_simple_auth/github_oauth_provider.py | 266 ----------------- .../mcp_simple_auth/legacy_as_server.py | 111 ++++--- .../simple-auth/mcp_simple_auth/server.py | 86 +----- .../mcp_simple_auth/simple_auth_provider.py | 270 ++++++++++++++++++ 6 files changed, 376 insertions(+), 502 deletions(-) delete mode 100644 examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py create mode 100644 examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py diff --git a/examples/servers/simple-auth/README.md b/examples/servers/simple-auth/README.md index 2c21143c8..a15ad1daf 100644 --- a/examples/servers/simple-auth/README.md +++ b/examples/servers/simple-auth/README.md @@ -4,20 +4,6 @@ This example demonstrates OAuth 2.0 authentication with the Model Context Protoc --- -## Setup Requirements - -**Create a GitHub OAuth App:** -- Go to GitHub Settings > Developer settings > OAuth Apps > New OAuth App -- **Authorization callback URL:** `http://localhost:9000/github/callback` -- Note down your **Client ID** and **Client Secret** - -**Set environment variables:** -```bash -export MCP_GITHUB_CLIENT_ID="your_client_id_here" -export MCP_GITHUB_CLIENT_SECRET="your_client_secret_here" -``` - ---- ## Running the Servers @@ -33,9 +19,8 @@ uv run mcp-simple-auth-as --port=9000 **What it provides:** - OAuth 2.0 flows (registration, authorization, token exchange) -- GitHub OAuth integration for user authentication +- Simple credential-based authentication (no external provider needed) - Token introspection endpoint for Resource Servers (`/introspect`) -- User data proxy endpoint (`/github/user`) --- @@ -90,6 +75,7 @@ curl http://localhost:9000/.well-known/oauth-authorization-server } ``` + ## Legacy MCP Server as Authorization Server (Backwards Compatibility) For backwards compatibility with older MCP implementations, a legacy server is provided that acts as an Authorization Server (following the old spec where MCP servers could optionally provide OAuth): diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py index 2594f81d6..80a2e8b8a 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -7,8 +7,6 @@ NOTE: this is a simplified example for demonstration purposes. This is not a production-ready implementation. -Usage: - python -m mcp_simple_auth.auth_server --port=9000 """ import asyncio @@ -20,14 +18,14 @@ from starlette.applications import Starlette from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.responses import JSONResponse, RedirectResponse, Response +from starlette.responses import JSONResponse, Response from starlette.routing import Route from uvicorn import Config, Server from mcp.server.auth.routes import cors_middleware, create_auth_routes from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions -from .github_oauth_provider import GitHubOAuthProvider, GitHubOAuthSettings +from .simple_auth_provider import SimpleAuthSettings, SimpleOAuthProvider logger = logging.getLogger(__name__) @@ -39,60 +37,64 @@ class AuthServerSettings(BaseModel): host: str = "localhost" port: int = 9000 server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:9000") - github_callback_path: str = "http://localhost:9000/github/callback" + auth_callback_path: str = "http://localhost:9000/login/callback" -class GitHubProxyAuthProvider(GitHubOAuthProvider): +class SimpleAuthProvider(SimpleOAuthProvider): """ - Authorization Server provider that proxies GitHub OAuth. + Authorization Server provider with simple demo authentication. This provider: - 1. Issues MCP tokens after GitHub authentication + 1. Issues MCP tokens after simple credential authentication 2. Stores token state for introspection by Resource Servers - 3. Maps MCP tokens to GitHub tokens for API access """ - def __init__(self, github_settings: GitHubOAuthSettings, github_callback_path: str): - super().__init__(github_settings, github_callback_path) + def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str): + super().__init__(auth_settings, auth_callback_path, server_url) -def create_authorization_server(server_settings: AuthServerSettings, github_settings: GitHubOAuthSettings) -> Starlette: +def create_authorization_server(server_settings: AuthServerSettings, auth_settings: SimpleAuthSettings) -> Starlette: """Create the Authorization Server application.""" - oauth_provider = GitHubProxyAuthProvider(github_settings, server_settings.github_callback_path) + oauth_provider = SimpleAuthProvider( + auth_settings, server_settings.auth_callback_path, str(server_settings.server_url) + ) - auth_settings = AuthSettings( + mcp_auth_settings = AuthSettings( issuer_url=server_settings.server_url, client_registration_options=ClientRegistrationOptions( enabled=True, - valid_scopes=[github_settings.mcp_scope], - default_scopes=[github_settings.mcp_scope], + valid_scopes=[auth_settings.mcp_scope], + default_scopes=[auth_settings.mcp_scope], ), - required_scopes=[github_settings.mcp_scope], + required_scopes=[auth_settings.mcp_scope], resource_server_url=None, ) # Create OAuth routes routes = create_auth_routes( provider=oauth_provider, - issuer_url=auth_settings.issuer_url, - service_documentation_url=auth_settings.service_documentation_url, - client_registration_options=auth_settings.client_registration_options, - revocation_options=auth_settings.revocation_options, + issuer_url=mcp_auth_settings.issuer_url, + service_documentation_url=mcp_auth_settings.service_documentation_url, + client_registration_options=mcp_auth_settings.client_registration_options, + revocation_options=mcp_auth_settings.revocation_options, ) - # Add GitHub callback route - async def github_callback_handler(request: Request) -> Response: - """Handle GitHub OAuth callback.""" - code = request.query_params.get("code") + # Add login page route (GET) + async def login_page_handler(request: Request) -> Response: + """Show login form.""" state = request.query_params.get("state") + if not state: + raise HTTPException(400, "Missing state parameter") + return await oauth_provider.get_login_page(state) - if not code or not state: - raise HTTPException(400, "Missing code or state parameter") + routes.append(Route("/login", endpoint=login_page_handler, methods=["GET"])) - redirect_uri = await oauth_provider.handle_github_callback(code, state) - return RedirectResponse(url=redirect_uri, status_code=302) + # Add login callback route (POST) + async def login_callback_handler(request: Request) -> Response: + """Handle simple authentication callback.""" + return await oauth_provider.handle_login_callback(request) - routes.append(Route("/github/callback", endpoint=github_callback_handler, methods=["GET"])) + routes.append(Route("/login/callback", endpoint=login_callback_handler, methods=["POST"])) # Add token introspection endpoint (RFC 7662) for Resource Servers async def introspect_handler(request: Request) -> Response: @@ -112,7 +114,6 @@ async def introspect_handler(request: Request) -> Response: if not access_token: return JSONResponse({"active": False}) - # Return token info for Resource Server return JSONResponse( { "active": True, @@ -133,39 +134,12 @@ async def introspect_handler(request: Request) -> Response: ) ) - # Add GitHub user info endpoint (for Resource Server to fetch user data) - async def github_user_handler(request: Request) -> Response: - """ - Proxy endpoint to get GitHub user info using stored GitHub tokens. - - Resource Servers call this with MCP tokens to get GitHub user data - without exposing GitHub tokens to clients. - """ - # Extract Bearer token - auth_header = request.headers.get("authorization", "") - if not auth_header.startswith("Bearer "): - return JSONResponse({"error": "unauthorized"}, status_code=401) - - mcp_token = auth_header[7:] - - # Get GitHub user info using the provider method - user_info = await oauth_provider.get_github_user_info(mcp_token) - return JSONResponse(user_info) - - routes.append( - Route( - "/github/user", - endpoint=cors_middleware(github_user_handler, ["GET", "OPTIONS"]), - methods=["GET", "OPTIONS"], - ) - ) - return Starlette(routes=routes) -async def run_server(server_settings: AuthServerSettings, github_settings: GitHubOAuthSettings): +async def run_server(server_settings: AuthServerSettings, auth_settings: SimpleAuthSettings): """Run the Authorization Server.""" - auth_server = create_authorization_server(server_settings, github_settings) + auth_server = create_authorization_server(server_settings, auth_settings) config = Config( auth_server, @@ -175,22 +149,7 @@ async def run_server(server_settings: AuthServerSettings, github_settings: GitHu ) server = Server(config) - logger.info("=" * 80) - logger.info("MCP AUTHORIZATION SERVER") - logger.info("=" * 80) - logger.info(f"Server URL: {server_settings.server_url}") - logger.info("Endpoints:") - logger.info(f" - OAuth Metadata: {server_settings.server_url}/.well-known/oauth-authorization-server") - logger.info(f" - Client Registration: {server_settings.server_url}/register") - logger.info(f" - Authorization: {server_settings.server_url}/authorize") - logger.info(f" - Token Exchange: {server_settings.server_url}/token") - logger.info(f" - Token Introspection: {server_settings.server_url}/introspect") - logger.info(f" - GitHub Callback: {server_settings.server_url}/github/callback") - logger.info(f" - GitHub User Proxy: {server_settings.server_url}/github/user") - logger.info("") - logger.info("Resource Servers should use /introspect to validate tokens") - logger.info("Configure GitHub App callback URL: " + server_settings.github_callback_path) - logger.info("=" * 80) + logger.info(f"🚀 MCP Authorization Server running on {server_settings.server_url}") await server.serve() @@ -203,18 +162,12 @@ def main(port: int) -> int: This server handles OAuth flows and can be used by multiple Resource Servers. - Environment variables needed: - - MCP_GITHUB_CLIENT_ID: GitHub OAuth Client ID - - MCP_GITHUB_CLIENT_SECRET: GitHub OAuth Client Secret + Uses simple hardcoded credentials for demo purposes. """ logging.basicConfig(level=logging.INFO) - # Load GitHub settings from environment variables - github_settings = GitHubOAuthSettings() - - # Validate required fields - if not github_settings.github_client_id or not github_settings.github_client_secret: - raise ValueError("GitHub credentials not provided") + # Load simple auth settings + auth_settings = SimpleAuthSettings() # Create server settings host = "localhost" @@ -223,10 +176,10 @@ def main(port: int) -> int: host=host, port=port, server_url=AnyHttpUrl(server_url), - github_callback_path=f"{server_url}/github/callback", + auth_callback_path=f"{server_url}/login", ) - asyncio.run(run_server(server_settings, github_settings)) + asyncio.run(run_server(server_settings, auth_settings)) return 0 diff --git a/examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py deleted file mode 100644 index c64db96b7..000000000 --- a/examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py +++ /dev/null @@ -1,266 +0,0 @@ -""" -Shared GitHub OAuth provider for MCP servers. - -This module contains the common GitHub OAuth functionality used by both -the standalone authorization server and the legacy combined server. - -NOTE: this is a simplified example for demonstration purposes. -This is not a production-ready implementation. - -""" - -import logging -import secrets -import time -from typing import Any - -from pydantic import AnyHttpUrl -from pydantic_settings import BaseSettings, SettingsConfigDict -from starlette.exceptions import HTTPException - -from mcp.server.auth.provider import ( - AccessToken, - AuthorizationCode, - AuthorizationParams, - OAuthAuthorizationServerProvider, - RefreshToken, - construct_redirect_uri, -) -from mcp.shared._httpx_utils import create_mcp_http_client -from mcp.shared.auth import OAuthClientInformationFull, OAuthToken - -logger = logging.getLogger(__name__) - - -class GitHubOAuthSettings(BaseSettings): - """Common GitHub OAuth settings.""" - - model_config = SettingsConfigDict(env_prefix="MCP_") - - # GitHub OAuth settings - MUST be provided via environment variables - github_client_id: str | None = None - github_client_secret: str | None = None - - # GitHub OAuth URLs - github_auth_url: str = "https://github.com/login/oauth/authorize" - github_token_url: str = "https://github.com/login/oauth/access_token" - - mcp_scope: str = "user" - github_scope: str = "read:user" - - -class GitHubOAuthProvider(OAuthAuthorizationServerProvider): - """ - OAuth provider that uses GitHub as the identity provider. - - This provider handles the OAuth flow by: - 1. Redirecting users to GitHub for authentication - 2. Exchanging GitHub tokens for MCP tokens - 3. Maintaining token mappings for API access - """ - - def __init__(self, settings: GitHubOAuthSettings, github_callback_url: str): - self.settings = settings - self.github_callback_url = github_callback_url - self.clients: dict[str, OAuthClientInformationFull] = {} - self.auth_codes: dict[str, AuthorizationCode] = {} - self.tokens: dict[str, AccessToken] = {} - self.state_mapping: dict[str, dict[str, str | None]] = {} - # Maps MCP tokens to GitHub tokens - self.token_mapping: dict[str, str] = {} - - async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: - """Get OAuth client information.""" - return self.clients.get(client_id) - - async def register_client(self, client_info: OAuthClientInformationFull): - """Register a new OAuth client.""" - self.clients[client_info.client_id] = client_info - - async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: - """Generate an authorization URL for GitHub OAuth flow.""" - state = params.state or secrets.token_hex(16) - - # Store state mapping for callback - self.state_mapping[state] = { - "redirect_uri": str(params.redirect_uri), - "code_challenge": params.code_challenge, - "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly), - "client_id": client.client_id, - "resource": params.resource, # RFC 8707 - } - - # Build GitHub authorization URL - auth_url = ( - f"{self.settings.github_auth_url}" - f"?client_id={self.settings.github_client_id}" - f"&redirect_uri={self.github_callback_url}" - f"&scope={self.settings.github_scope}" - f"&state={state}" - ) - - return auth_url - - async def handle_github_callback(self, code: str, state: str) -> str: - """Handle GitHub OAuth callback and return redirect URI.""" - state_data = self.state_mapping.get(state) - if not state_data: - raise HTTPException(400, "Invalid state parameter") - - redirect_uri = state_data["redirect_uri"] - code_challenge = state_data["code_challenge"] - redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True" - client_id = state_data["client_id"] - resource = state_data.get("resource") # RFC 8707 - - # These are required values from our own state mapping - assert redirect_uri is not None - assert code_challenge is not None - assert client_id is not None - - # Exchange code for token with GitHub - async with create_mcp_http_client() as client: - response = await client.post( - self.settings.github_token_url, - data={ - "client_id": self.settings.github_client_id, - "client_secret": self.settings.github_client_secret, - "code": code, - "redirect_uri": self.github_callback_url, - }, - headers={"Accept": "application/json"}, - ) - - if response.status_code != 200: - raise HTTPException(400, "Failed to exchange code for token") - - data = response.json() - - if "error" in data: - raise HTTPException(400, data.get("error_description", data["error"])) - - github_token = data["access_token"] - - # Create MCP authorization code - new_code = f"mcp_{secrets.token_hex(16)}" - auth_code = AuthorizationCode( - code=new_code, - client_id=client_id, - redirect_uri=AnyHttpUrl(redirect_uri), - redirect_uri_provided_explicitly=redirect_uri_provided_explicitly, - expires_at=time.time() + 300, - scopes=[self.settings.mcp_scope], - code_challenge=code_challenge, - resource=resource, # RFC 8707 - ) - self.auth_codes[new_code] = auth_code - - # Store GitHub token with MCP client_id - self.tokens[github_token] = AccessToken( - token=github_token, - client_id=client_id, - scopes=[self.settings.github_scope], - expires_at=None, - ) - - del self.state_mapping[state] - return construct_redirect_uri(redirect_uri, code=new_code, state=state) - - async def load_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: str - ) -> AuthorizationCode | None: - """Load an authorization code.""" - return self.auth_codes.get(authorization_code) - - async def exchange_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode - ) -> OAuthToken: - """Exchange authorization code for tokens.""" - if authorization_code.code not in self.auth_codes: - raise ValueError("Invalid authorization code") - - # Generate MCP access token - mcp_token = f"mcp_{secrets.token_hex(32)}" - - # Store MCP token - self.tokens[mcp_token] = AccessToken( - token=mcp_token, - client_id=client.client_id, - scopes=authorization_code.scopes, - expires_at=int(time.time()) + 3600, - resource=authorization_code.resource, # RFC 8707 - ) - - # Find GitHub token for this client - github_token = next( - ( - token - for token, data in self.tokens.items() - if (token.startswith("ghu_") or token.startswith("gho_")) and data.client_id == client.client_id - ), - None, - ) - - # Store mapping between MCP token and GitHub token - if github_token: - self.token_mapping[mcp_token] = github_token - - del self.auth_codes[authorization_code.code] - - return OAuthToken( - access_token=mcp_token, - token_type="Bearer", - expires_in=3600, - scope=" ".join(authorization_code.scopes), - ) - - async def load_access_token(self, token: str) -> AccessToken | None: - """Load and validate an access token.""" - access_token = self.tokens.get(token) - if not access_token: - return None - - # Check if expired - if access_token.expires_at and access_token.expires_at < time.time(): - del self.tokens[token] - return None - - return access_token - - async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: - """Load a refresh token - not supported in this example.""" - return None - - async def exchange_refresh_token( - self, - client: OAuthClientInformationFull, - refresh_token: RefreshToken, - scopes: list[str], - ) -> OAuthToken: - """Exchange refresh token - not supported in this example.""" - raise NotImplementedError("Refresh tokens not supported") - - async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: - """Revoke a token.""" - if token in self.tokens: - del self.tokens[token] - - async def get_github_user_info(self, mcp_token: str) -> dict[str, Any]: - """Get GitHub user info using MCP token.""" - github_token = self.token_mapping.get(mcp_token) - if not github_token: - raise ValueError("No GitHub token found for MCP token") - - async with create_mcp_http_client() as client: - response = await client.get( - "https://api.github.com/user", - headers={ - "Authorization": f"Bearer {github_token}", - "Accept": "application/vnd.github.v3+json", - }, - ) - - if response.status_code != 200: - raise ValueError(f"GitHub API error: {response.status_code}") - - return response.json() diff --git a/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py b/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py index 0725ef9ed..b0455c3e8 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py @@ -7,11 +7,9 @@ NOTE: this is a simplified example for demonstration purposes. This is not a production-ready implementation. - -Usage: - python -m mcp_simple_auth.legacy_as_server --port=8002 """ +import datetime import logging from typing import Any, Literal @@ -19,97 +17,91 @@ from pydantic import AnyHttpUrl, BaseModel from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.responses import RedirectResponse, Response +from starlette.responses import Response -from mcp.server.auth.middleware.auth_context import get_access_token from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions from mcp.server.fastmcp.server import FastMCP -from .github_oauth_provider import GitHubOAuthProvider, GitHubOAuthSettings +from .simple_auth_provider import SimpleAuthSettings, SimpleOAuthProvider logger = logging.getLogger(__name__) class ServerSettings(BaseModel): - """Settings for the simple GitHub MCP server.""" + """Settings for the simple auth MCP server.""" # Server settings host: str = "localhost" port: int = 8000 server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8000") - github_callback_path: str = "http://localhost:8000/github/callback" + auth_callback_path: str = "http://localhost:8000/login/callback" -class SimpleGitHubOAuthProvider(GitHubOAuthProvider): - """GitHub OAuth provider for legacy MCP server.""" +class LegacySimpleOAuthProvider(SimpleOAuthProvider): + """Simple OAuth provider for legacy MCP server.""" - def __init__(self, github_settings: GitHubOAuthSettings, github_callback_path: str): - super().__init__(github_settings, github_callback_path) + def __init__(self, auth_settings: SimpleAuthSettings, auth_callback_path: str, server_url: str): + super().__init__(auth_settings, auth_callback_path, server_url) -def create_simple_mcp_server(server_settings: ServerSettings, github_settings: GitHubOAuthSettings) -> FastMCP: - """Create a simple FastMCP server with GitHub OAuth.""" - oauth_provider = SimpleGitHubOAuthProvider(github_settings, server_settings.github_callback_path) +def create_simple_mcp_server(server_settings: ServerSettings, auth_settings: SimpleAuthSettings) -> FastMCP: + """Create a simple FastMCP server with simple authentication.""" + oauth_provider = LegacySimpleOAuthProvider( + auth_settings, server_settings.auth_callback_path, str(server_settings.server_url) + ) - auth_settings = AuthSettings( + mcp_auth_settings = AuthSettings( issuer_url=server_settings.server_url, client_registration_options=ClientRegistrationOptions( enabled=True, - valid_scopes=[github_settings.mcp_scope], - default_scopes=[github_settings.mcp_scope], + valid_scopes=[auth_settings.mcp_scope], + default_scopes=[auth_settings.mcp_scope], ), - required_scopes=[github_settings.mcp_scope], + required_scopes=[auth_settings.mcp_scope], # No resource_server_url parameter in legacy mode resource_server_url=None, ) app = FastMCP( - name="Simple GitHub MCP Server", - instructions="A simple MCP server with GitHub OAuth authentication", + name="Simple Auth MCP Server", + instructions="A simple MCP server with simple credential authentication", auth_server_provider=oauth_provider, host=server_settings.host, port=server_settings.port, debug=True, - auth=auth_settings, + auth=mcp_auth_settings, ) - @app.custom_route("/github/callback", methods=["GET"]) - async def github_callback_handler(request: Request) -> Response: - """Handle GitHub OAuth callback.""" - code = request.query_params.get("code") + @app.custom_route("/login", methods=["GET"]) + async def login_page_handler(request: Request) -> Response: + """Show login form.""" state = request.query_params.get("state") + if not state: + raise HTTPException(400, "Missing state parameter") + return await oauth_provider.get_login_page(state) - if not code or not state: - raise HTTPException(400, "Missing code or state parameter") - - redirect_uri = await oauth_provider.handle_github_callback(code, state) - return RedirectResponse(status_code=302, url=redirect_uri) - - def get_github_token() -> str: - """Get the GitHub token for the authenticated user.""" - access_token = get_access_token() - if not access_token: - raise ValueError("Not authenticated") - - # Get GitHub token from mapping - github_token = oauth_provider.token_mapping.get(access_token.token) - - if not github_token: - raise ValueError("No GitHub token found for user") - - return github_token + @app.custom_route("/login/callback", methods=["POST"]) + async def login_callback_handler(request: Request) -> Response: + """Handle simple authentication callback.""" + return await oauth_provider.handle_login_callback(request) @app.tool() - async def get_user_profile() -> dict[str, Any]: - """Get the authenticated user's GitHub profile information. + async def get_time() -> dict[str, Any]: + """ + Get the current server time. - This is the only tool in our simple example. It requires the 'user' scope. + This tool demonstrates that system information can be protected + by OAuth authentication. User must be authenticated to access it. """ - access_token = get_access_token() - if not access_token: - raise ValueError("Not authenticated") - return await oauth_provider.get_github_user_info(access_token.token) + now = datetime.datetime.now() + + return { + "current_time": now.isoformat(), + "timezone": "UTC", # Simplified for demo + "timestamp": now.timestamp(), + "formatted": now.strftime("%Y-%m-%d %H:%M:%S"), + } return app @@ -123,15 +115,10 @@ async def get_user_profile() -> dict[str, Any]: help="Transport protocol to use ('sse' or 'streamable-http')", ) def main(port: int, transport: Literal["sse", "streamable-http"]) -> int: - """Run the simple GitHub MCP server.""" + """Run the simple auth MCP server.""" logging.basicConfig(level=logging.INFO) - # Load GitHub settings from environment variables - github_settings = GitHubOAuthSettings() - - # Validate required fields - if not github_settings.github_client_id or not github_settings.github_client_secret: - raise ValueError("GitHub credentials not provided") + auth_settings = SimpleAuthSettings() # Create server settings host = "localhost" server_url = f"http://{host}:{port}" @@ -139,11 +126,11 @@ def main(port: int, transport: Literal["sse", "streamable-http"]) -> int: host=host, port=port, server_url=AnyHttpUrl(server_url), - github_callback_path=f"{server_url}/github/callback", + auth_callback_path=f"{server_url}/login", ) - mcp_server = create_simple_mcp_server(server_settings, github_settings) - logger.info(f"Starting server with {transport} transport") + mcp_server = create_simple_mcp_server(server_settings, auth_settings) + logger.info(f"🚀 MCP Legacy Server running on {server_url}") mcp_server.run(transport=transport) return 0 diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 898ee7837..a51b6015b 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -4,19 +4,18 @@ This server validates tokens via Authorization Server introspection and serves MCP resources. Demonstrates RFC 9728 Protected Resource Metadata for AS/RS separation. -Usage: - python -m mcp_simple_auth.server --port=8001 --auth-server=http://localhost:9000 +NOTE: this is a simplified example for demonstration purposes. +This is not a production-ready implementation. """ +import datetime import logging from typing import Any, Literal import click -import httpx from pydantic import AnyHttpUrl from pydantic_settings import BaseSettings, SettingsConfigDict -from mcp.server.auth.middleware.auth_context import get_access_token from mcp.server.auth.settings import AuthSettings from mcp.server.fastmcp.server import FastMCP @@ -38,7 +37,7 @@ class ResourceServerSettings(BaseSettings): # Authorization Server settings auth_server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:9000") auth_server_introspection_endpoint: str = "http://localhost:9000/introspect" - auth_server_github_user_endpoint: str = "http://localhost:9000/github/user" + # No user endpoint needed - we get user data from token introspection # MCP settings mcp_scope: str = "user" @@ -83,60 +82,22 @@ def create_resource_server(settings: ResourceServerSettings) -> FastMCP: ), ) - async def get_github_user_data() -> dict[str, Any]: - """ - Get GitHub user data via Authorization Server proxy endpoint. - - This avoids exposing GitHub tokens to the Resource Server. - The Authorization Server handles the GitHub API call and returns the data. - """ - access_token = get_access_token() - if not access_token: - raise ValueError("Not authenticated") - - # Call Authorization Server's GitHub proxy endpoint - async with httpx.AsyncClient() as client: - response = await client.get( - settings.auth_server_github_user_endpoint, - headers={ - "Authorization": f"Bearer {access_token.token}", - }, - ) - - if response.status_code != 200: - raise ValueError(f"GitHub user data fetch failed: {response.status_code} - {response.text}") - - return response.json() - @app.tool() - async def get_user_profile() -> dict[str, Any]: + async def get_time() -> dict[str, Any]: """ - Get the authenticated user's GitHub profile information. + Get the current server time. - This tool requires the 'user' scope and demonstrates how Resource Servers - can access user data without directly handling GitHub tokens. + This tool demonstrates that system information can be protected + by OAuth authentication. User must be authenticated to access it. """ - return await get_github_user_data() - @app.tool() - async def get_user_info() -> dict[str, Any]: - """ - Get information about the currently authenticated user. - - Returns token and scope information from the Resource Server's perspective. - """ - access_token = get_access_token() - if not access_token: - raise ValueError("Not authenticated") + now = datetime.datetime.now() return { - "authenticated": True, - "client_id": access_token.client_id, - "scopes": access_token.scopes, - "token_expires_at": access_token.expires_at, - "token_type": "Bearer", - "resource_server": str(settings.server_url), - "authorization_server": str(settings.auth_server_url), + "current_time": now.isoformat(), + "timezone": "UTC", # Simplified for demo + "timestamp": now.timestamp(), + "formatted": now.strftime("%Y-%m-%d %H:%M:%S"), } return app @@ -182,7 +143,6 @@ def main(port: int, auth_server: str, transport: Literal["sse", "streamable-http server_url=AnyHttpUrl(server_url), auth_server_url=auth_server_url, auth_server_introspection_endpoint=f"{auth_server}/introspect", - auth_server_github_user_endpoint=f"{auth_server}/github/user", oauth_strict=oauth_strict, ) except ValueError as e: @@ -193,24 +153,8 @@ def main(port: int, auth_server: str, transport: Literal["sse", "streamable-http try: mcp_server = create_resource_server(settings) - logger.info("=" * 80) - logger.info("📦 MCP RESOURCE SERVER") - logger.info("=" * 80) - logger.info(f"🌐 Server URL: {settings.server_url}") - logger.info(f"🔑 Authorization Server: {settings.auth_server_url}") - logger.info("📋 Endpoints:") - logger.info(f" ┌─ Protected Resource Metadata: {settings.server_url}/.well-known/oauth-protected-resource") - mcp_path = "sse" if transport == "sse" else "mcp" - logger.info(f" ├─ MCP Protocol: {settings.server_url}/{mcp_path}") - logger.info(f" └─ Token Introspection: {settings.auth_server_introspection_endpoint}") - logger.info("") - logger.info("🛠️ Available Tools:") - logger.info(" ├─ get_user_profile() - Get GitHub user profile") - logger.info(" └─ get_user_info() - Get authentication status") - logger.info("") - logger.info("🔍 Tokens validated via Authorization Server introspection") - logger.info("📱 Clients discover Authorization Server via Protected Resource Metadata") - logger.info("=" * 80) + logger.info(f"🚀 MCP Resource Server running on {settings.server_url}") + logger.info(f"🔑 Using Authorization Server: {settings.auth_server_url}") # Run the server - this should block and keep running mcp_server.run(transport=transport) diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py new file mode 100644 index 000000000..9ae189b84 --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -0,0 +1,270 @@ +""" +Simple OAuth provider for MCP servers. + +This module contains a basic OAuth implementation using hardcoded user credentials +for demonstration purposes. No external authentication provider is required. + +NOTE: this is a simplified example for demonstration purposes. +This is not a production-ready implementation. + +""" + +import logging +import secrets +import time +from typing import Any + +from pydantic import AnyHttpUrl +from pydantic_settings import BaseSettings, SettingsConfigDict +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import HTMLResponse, RedirectResponse, Response + +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthAuthorizationServerProvider, + RefreshToken, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +logger = logging.getLogger(__name__) + + +class SimpleAuthSettings(BaseSettings): + """Simple OAuth settings for demo purposes.""" + + model_config = SettingsConfigDict(env_prefix="MCP_") + + # Demo user credentials + demo_username: str = "demo_user" + demo_password: str = "demo_password" + + # MCP OAuth scope + mcp_scope: str = "user" + + +class SimpleOAuthProvider(OAuthAuthorizationServerProvider): + """ + Simple OAuth provider for demo purposes. + + This provider handles the OAuth flow by: + 1. Providing a simple login form for demo credentials + 2. Issuing MCP tokens after successful authentication + 3. Maintaining token state for introspection + """ + + def __init__(self, settings: SimpleAuthSettings, auth_callback_url: str, server_url: str): + self.settings = settings + self.auth_callback_url = auth_callback_url + self.server_url = server_url + self.clients: dict[str, OAuthClientInformationFull] = {} + self.auth_codes: dict[str, AuthorizationCode] = {} + self.tokens: dict[str, AccessToken] = {} + self.state_mapping: dict[str, dict[str, str | None]] = {} + # Store authenticated user information + self.user_data: dict[str, dict[str, Any]] = {} + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + """Get OAuth client information.""" + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull): + """Register a new OAuth client.""" + self.clients[client_info.client_id] = client_info + + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: + """Generate an authorization URL for simple login flow.""" + state = params.state or secrets.token_hex(16) + + # Store state mapping for callback + self.state_mapping[state] = { + "redirect_uri": str(params.redirect_uri), + "code_challenge": params.code_challenge, + "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly), + "client_id": client.client_id, + "resource": params.resource, # RFC 8707 + } + + # Build simple login URL that points to login page + auth_url = f"{self.auth_callback_url}" f"?state={state}" f"&client_id={client.client_id}" + + return auth_url + + async def get_login_page(self, state: str) -> HTMLResponse: + """Generate login page HTML for the given state.""" + if not state: + raise HTTPException(400, "Missing state parameter") + + # Create simple login form HTML + html_content = f""" + + + + MCP Demo Authentication + + + +

MCP Demo Authentication

+

This is a simplified authentication demo. Use the demo credentials below:

+

Username: demo_user
+ Password: demo_password

+ +
+ +
+ + +
+
+ + +
+ +
+ + + """ + + return HTMLResponse(content=html_content) + + async def handle_login_callback(self, request: Request) -> Response: + """Handle login form submission callback.""" + form = await request.form() + username = form.get("username") + password = form.get("password") + state = form.get("state") + + if not username or not password or not state: + raise HTTPException(400, "Missing username, password, or state parameter") + + # Ensure we have strings, not UploadFile objects + if not isinstance(username, str) or not isinstance(password, str) or not isinstance(state, str): + raise HTTPException(400, "Invalid parameter types") + + redirect_uri = await self.handle_simple_callback(username, password, state) + return RedirectResponse(url=redirect_uri, status_code=302) + + async def handle_simple_callback(self, username: str, password: str, state: str) -> str: + """Handle simple authentication callback and return redirect URI.""" + state_data = self.state_mapping.get(state) + if not state_data: + raise HTTPException(400, "Invalid state parameter") + + redirect_uri = state_data["redirect_uri"] + code_challenge = state_data["code_challenge"] + redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True" + client_id = state_data["client_id"] + resource = state_data.get("resource") # RFC 8707 + + # These are required values from our own state mapping + assert redirect_uri is not None + assert code_challenge is not None + assert client_id is not None + + # Validate demo credentials + if username != self.settings.demo_username or password != self.settings.demo_password: + raise HTTPException(401, "Invalid credentials") + + # Create MCP authorization code + new_code = f"mcp_{secrets.token_hex(16)}" + auth_code = AuthorizationCode( + code=new_code, + client_id=client_id, + redirect_uri=AnyHttpUrl(redirect_uri), + redirect_uri_provided_explicitly=redirect_uri_provided_explicitly, + expires_at=time.time() + 300, + scopes=[self.settings.mcp_scope], + code_challenge=code_challenge, + resource=resource, # RFC 8707 + ) + self.auth_codes[new_code] = auth_code + + # Store user data + self.user_data[username] = { + "username": username, + "user_id": f"user_{secrets.token_hex(8)}", + "authenticated_at": time.time(), + } + + del self.state_mapping[state] + return construct_redirect_uri(redirect_uri, code=new_code, state=state) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + """Load an authorization code.""" + return self.auth_codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Exchange authorization code for tokens.""" + if authorization_code.code not in self.auth_codes: + raise ValueError("Invalid authorization code") + + # Generate MCP access token + mcp_token = f"mcp_{secrets.token_hex(32)}" + + # Store MCP token + self.tokens[mcp_token] = AccessToken( + token=mcp_token, + client_id=client.client_id, + scopes=authorization_code.scopes, + expires_at=int(time.time()) + 3600, + resource=authorization_code.resource, # RFC 8707 + ) + + # Store user data mapping for this token + self.user_data[mcp_token] = { + "username": self.settings.demo_username, + "user_id": f"user_{secrets.token_hex(8)}", + "authenticated_at": time.time(), + } + + del self.auth_codes[authorization_code.code] + + return OAuthToken( + access_token=mcp_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(authorization_code.scopes), + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + """Load and validate an access token.""" + access_token = self.tokens.get(token) + if not access_token: + return None + + # Check if expired + if access_token.expires_at and access_token.expires_at < time.time(): + del self.tokens[token] + return None + + return access_token + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + """Load a refresh token - not supported in this example.""" + return None + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + """Exchange refresh token - not supported in this example.""" + raise NotImplementedError("Refresh tokens not supported") + + async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: + """Revoke a token.""" + if token in self.tokens: + del self.tokens[token] From 08797d10d3e4c56fa69c42a8e9f542e19ff1f1ac Mon Sep 17 00:00:00 2001 From: ihrpr Date: Tue, 24 Jun 2025 11:03:13 +0100 Subject: [PATCH 2/3] Fix /.well-known/oauth-authorization-server dropping path --- src/mcp/client/auth.py | 14 +++++++++--- tests/client/test_auth.py | 48 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index c174385ea..5f313f84b 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -200,11 +200,19 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> async def _discover_oauth_metadata(self) -> httpx.Request: """Build OAuth metadata discovery request.""" if self.context.auth_server_url: - base_url = self.context.get_authorization_base_url(self.context.auth_server_url) + auth_server_url = self.context.auth_server_url else: - base_url = self.context.get_authorization_base_url(self.context.server_url) + auth_server_url = self.context.server_url - url = urljoin(base_url, "/.well-known/oauth-authorization-server") + # Per RFC 8414, preserve the path component when constructing discovery URL + parsed = urlparse(auth_server_url) + well_known_path = f"/.well-known/oauth-authorization-server{parsed.path}" + if parsed.path.endswith("/"): + # Strip trailing slash from pathname + well_known_path = well_known_path[:-1] + + base_url = f"{parsed.scheme}://{parsed.netloc}" + url = urljoin(base_url, well_known_path) return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 8dee687a9..3825eb595 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -208,10 +208,58 @@ async def test_discover_oauth_metadata_request(self, oauth_provider): """Test OAuth metadata discovery request building.""" request = await oauth_provider._discover_oauth_metadata() + assert request.method == "GET" + assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp" + assert "mcp-protocol-version" in request.headers + + @pytest.mark.anyio + async def test_discover_oauth_metadata_request_no_path(self, client_metadata, mock_storage): + """Test OAuth metadata discovery request building when server has no path.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + request = await provider._discover_oauth_metadata() + assert request.method == "GET" assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server" assert "mcp-protocol-version" in request.headers + @pytest.mark.anyio + async def test_discover_oauth_metadata_request_trailing_slash(self, client_metadata, mock_storage): + """Test OAuth metadata discovery request building when server path has trailing slash.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp/", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + request = await provider._discover_oauth_metadata() + + assert request.method == "GET" + assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp" + assert "mcp-protocol-version" in request.headers + @pytest.mark.anyio async def test_register_client_request(self, oauth_provider): """Test client registration request building.""" From bdc8bd7aca94ea5823a1c681e198630ccc315857 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Tue, 24 Jun 2025 11:56:26 +0100 Subject: [PATCH 3/3] add fallback --- src/mcp/client/auth.py | 80 ++++++++++++++++++++++++------ tests/client/test_auth.py | 101 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 14 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 5f313f84b..359e0585f 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -106,6 +106,10 @@ class OAuthContext: # State lock: anyio.Lock = field(default_factory=anyio.Lock) + # Discovery state for fallback support + discovery_base_url: str | None = None + discovery_pathname: str | None = None + def get_authorization_base_url(self, server_url: str) -> str: """Extract base URL by removing path component.""" parsed = urlparse(server_url) @@ -197,26 +201,53 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> except ValidationError: pass + def _build_well_known_path(self, pathname: str) -> str: + """Construct well-known path for OAuth metadata discovery.""" + well_known_path = f"/.well-known/oauth-authorization-server{pathname}" + if pathname.endswith("/"): + # Strip trailing slash from pathname to avoid double slashes + well_known_path = well_known_path[:-1] + return well_known_path + + def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool: + """Determine if fallback to root discovery should be attempted.""" + return response_status == 404 and pathname != "/" + + async def _try_metadata_discovery(self, url: str) -> httpx.Request: + """Build metadata discovery request for a specific URL.""" + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + async def _discover_oauth_metadata(self) -> httpx.Request: - """Build OAuth metadata discovery request.""" + """Build OAuth metadata discovery request with fallback support.""" if self.context.auth_server_url: auth_server_url = self.context.auth_server_url else: auth_server_url = self.context.server_url - # Per RFC 8414, preserve the path component when constructing discovery URL + # Per RFC 8414, try path-aware discovery first parsed = urlparse(auth_server_url) - well_known_path = f"/.well-known/oauth-authorization-server{parsed.path}" - if parsed.path.endswith("/"): - # Strip trailing slash from pathname - well_known_path = well_known_path[:-1] - + well_known_path = self._build_well_known_path(parsed.path) base_url = f"{parsed.scheme}://{parsed.netloc}" url = urljoin(base_url, well_known_path) - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: - """Handle OAuth metadata response.""" + # Store fallback info for use in response handler + self.context.discovery_base_url = base_url + self.context.discovery_pathname = parsed.path + + return await self._try_metadata_discovery(url) + + async def _discover_oauth_metadata_fallback(self) -> httpx.Request: + """Build fallback OAuth metadata discovery request for legacy servers.""" + base_url = getattr(self.context, "discovery_base_url", "") + if not base_url: + raise OAuthFlowError("No base URL available for fallback discovery") + + # Fallback to root discovery for legacy servers + url = urljoin(base_url, "/.well-known/oauth-authorization-server") + return await self._try_metadata_discovery(url) + + async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool: + """Handle OAuth metadata response. Returns True if handled successfully.""" if response.status_code == 200: try: content = await response.aread() @@ -225,9 +256,18 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non # Apply default scope if none specified if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: self.context.client_metadata.scope = " ".join(metadata.scopes_supported) + return True except ValidationError: pass + # Check if we should attempt fallback (404 on path-aware discovery) + if not is_fallback and self._should_attempt_fallback( + response.status_code, getattr(self.context, "discovery_pathname", "/") + ): + return False # Signal that fallback should be attempted + + return True # Signal no fallback needed (either success or non-404 error) + async def _register_client(self) -> httpx.Request | None: """Build registration request or skip if already registered.""" if self.context.client_info: @@ -426,10 +466,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. discovery_response = yield discovery_request await self._handle_protected_resource_response(discovery_response) - # Step 2: Discover OAuth metadata + # Step 2: Discover OAuth metadata (with fallback for legacy servers) oauth_request = await self._discover_oauth_metadata() oauth_response = yield oauth_request - await self._handle_oauth_metadata_response(oauth_response) + handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) + + # If path-aware discovery failed with 404, try fallback to root + if not handled: + fallback_request = await self._discover_oauth_metadata_fallback() + fallback_response = yield fallback_request + await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) # Step 3: Register client if needed registration_request = await self._register_client() @@ -472,10 +518,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. discovery_response = yield discovery_request await self._handle_protected_resource_response(discovery_response) - # Step 2: Discover OAuth metadata + # Step 2: Discover OAuth metadata (with fallback for legacy servers) oauth_request = await self._discover_oauth_metadata() oauth_response = yield oauth_request - await self._handle_oauth_metadata_response(oauth_response) + handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) + + # If path-aware discovery failed with 404, try fallback to root + if not handled: + fallback_request = await self._discover_oauth_metadata_fallback() + fallback_response = yield fallback_request + await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) # Step 3: Register client if needed registration_request = await self._register_client() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 3825eb595..d87410d00 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -260,6 +260,107 @@ async def callback_handler() -> tuple[str, str | None]: assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp" assert "mcp-protocol-version" in request.headers + +class TestOAuthFallback: + """Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers.""" + + @pytest.mark.anyio + async def test_fallback_discovery_request(self, client_metadata, mock_storage): + """Test fallback discovery request building.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Set up discovery state manually as if path-aware discovery was attempted + provider.context.discovery_base_url = "https://api.example.com" + provider.context.discovery_pathname = "/v1/mcp" + + # Test fallback request building + request = await provider._discover_oauth_metadata_fallback() + + assert request.method == "GET" + assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server" + assert "mcp-protocol-version" in request.headers + + @pytest.mark.anyio + async def test_should_attempt_fallback(self, oauth_provider): + """Test fallback decision logic.""" + # Should attempt fallback on 404 with non-root path + assert oauth_provider._should_attempt_fallback(404, "/v1/mcp") + + # Should NOT attempt fallback on 404 with root path + assert not oauth_provider._should_attempt_fallback(404, "/") + + # Should NOT attempt fallback on other status codes + assert not oauth_provider._should_attempt_fallback(200, "/v1/mcp") + assert not oauth_provider._should_attempt_fallback(500, "/v1/mcp") + + @pytest.mark.anyio + async def test_handle_metadata_response_success(self, oauth_provider): + """Test successful metadata response handling.""" + # Create minimal valid OAuth metadata + content = b"""{ + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token" + }""" + response = httpx.Response(200, content=content) + + # Should return True (success) and set metadata + result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False) + assert result is True + assert oauth_provider.context.oauth_metadata is not None + assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/" + + @pytest.mark.anyio + async def test_handle_metadata_response_404_needs_fallback(self, oauth_provider): + """Test 404 response handling that should trigger fallback.""" + # Set up discovery state for non-root path + oauth_provider.context.discovery_base_url = "https://api.example.com" + oauth_provider.context.discovery_pathname = "/v1/mcp" + + # Mock 404 response + response = httpx.Response(404) + + # Should return False (needs fallback) + result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False) + assert result is False + + @pytest.mark.anyio + async def test_handle_metadata_response_404_no_fallback_needed(self, oauth_provider): + """Test 404 response handling when no fallback is needed.""" + # Set up discovery state for root path + oauth_provider.context.discovery_base_url = "https://api.example.com" + oauth_provider.context.discovery_pathname = "/" + + # Mock 404 response + response = httpx.Response(404) + + # Should return True (no fallback needed) + result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False) + assert result is True + + @pytest.mark.anyio + async def test_handle_metadata_response_404_fallback_attempt(self, oauth_provider): + """Test 404 response handling during fallback attempt.""" + # Mock 404 response during fallback + response = httpx.Response(404) + + # Should return True (fallback attempt complete, no further action needed) + result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=True) + assert result is True + @pytest.mark.anyio async def test_register_client_request(self, oauth_provider): """Test client registration request building."""