|
13 | 13 | OAuthClientInformationFull,
|
14 | 14 | OAuthClientMetadata,
|
15 | 15 | OAuthToken,
|
| 16 | + ProtectedResourceMetadata, |
16 | 17 | )
|
17 | 18 |
|
18 | 19 |
|
@@ -434,6 +435,87 @@ async def test_refresh_token_request(self, oauth_provider, valid_tokens):
|
434 | 435 | assert "client_secret=test_secret" in content
|
435 | 436 |
|
436 | 437 |
|
| 438 | +class TestProtectedResourceMetadata: |
| 439 | + """Test protected resource handling.""" |
| 440 | + |
| 441 | + @pytest.mark.anyio |
| 442 | + async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider): |
| 443 | + """Test resource parameter is included for protocol version >= 2025-06-18.""" |
| 444 | + # Set protocol version to 2025-06-18 |
| 445 | + oauth_provider.context.protocol_version = "2025-06-18" |
| 446 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 447 | + client_id="test_client", |
| 448 | + client_secret="test_secret", |
| 449 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 450 | + ) |
| 451 | + |
| 452 | + # Test in token exchange |
| 453 | + request = await oauth_provider._exchange_token("test_code", "test_verifier") |
| 454 | + content = request.content.decode() |
| 455 | + assert "resource=" in content |
| 456 | + # Check URL-encoded resource parameter |
| 457 | + from urllib.parse import quote |
| 458 | + |
| 459 | + expected_resource = quote(oauth_provider.context.get_resource_url(), safe="") |
| 460 | + assert f"resource={expected_resource}" in content |
| 461 | + |
| 462 | + # Test in refresh token |
| 463 | + oauth_provider.context.current_tokens = OAuthToken( |
| 464 | + access_token="test_access", |
| 465 | + token_type="Bearer", |
| 466 | + refresh_token="test_refresh", |
| 467 | + ) |
| 468 | + refresh_request = await oauth_provider._refresh_token() |
| 469 | + refresh_content = refresh_request.content.decode() |
| 470 | + assert "resource=" in refresh_content |
| 471 | + |
| 472 | + @pytest.mark.anyio |
| 473 | + async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider): |
| 474 | + """Test resource parameter is excluded for protocol version < 2025-06-18.""" |
| 475 | + # Set protocol version to older version |
| 476 | + oauth_provider.context.protocol_version = "2025-03-26" |
| 477 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 478 | + client_id="test_client", |
| 479 | + client_secret="test_secret", |
| 480 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 481 | + ) |
| 482 | + |
| 483 | + # Test in token exchange |
| 484 | + request = await oauth_provider._exchange_token("test_code", "test_verifier") |
| 485 | + content = request.content.decode() |
| 486 | + assert "resource=" not in content |
| 487 | + |
| 488 | + # Test in refresh token |
| 489 | + oauth_provider.context.current_tokens = OAuthToken( |
| 490 | + access_token="test_access", |
| 491 | + token_type="Bearer", |
| 492 | + refresh_token="test_refresh", |
| 493 | + ) |
| 494 | + refresh_request = await oauth_provider._refresh_token() |
| 495 | + refresh_content = refresh_request.content.decode() |
| 496 | + assert "resource=" not in refresh_content |
| 497 | + |
| 498 | + @pytest.mark.anyio |
| 499 | + async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider): |
| 500 | + """Test resource parameter is always included when protected resource metadata exists.""" |
| 501 | + # Set old protocol version but with protected resource metadata |
| 502 | + oauth_provider.context.protocol_version = "2025-03-26" |
| 503 | + oauth_provider.context.protected_resource_metadata = ProtectedResourceMetadata( |
| 504 | + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), |
| 505 | + authorization_servers=[AnyHttpUrl("https://api.example.com")], |
| 506 | + ) |
| 507 | + oauth_provider.context.client_info = OAuthClientInformationFull( |
| 508 | + client_id="test_client", |
| 509 | + client_secret="test_secret", |
| 510 | + redirect_uris=[AnyUrl("http://localhost:3030/callback")], |
| 511 | + ) |
| 512 | + |
| 513 | + # Test in token exchange |
| 514 | + request = await oauth_provider._exchange_token("test_code", "test_verifier") |
| 515 | + content = request.content.decode() |
| 516 | + assert "resource=" in content |
| 517 | + |
| 518 | + |
437 | 519 | class TestAuthFlow:
|
438 | 520 | """Test the auth flow in httpx."""
|
439 | 521 |
|
|
0 commit comments