diff --git a/client/client.go b/client/client.go index dd0e31a0..b519fcd5 100644 --- a/client/client.go +++ b/client/client.go @@ -108,7 +108,6 @@ func (c *Client) sendRequest( Method: method, Params: params, } - response, err := c.transport.SendRequest(ctx, request) if err != nil { return nil, fmt.Errorf("transport error: %w", err) diff --git a/client/transport/inprocess.go b/client/transport/inprocess.go index 90fc2fae..f5d46953 100644 --- a/client/transport/inprocess.go +++ b/client/transport/inprocess.go @@ -34,7 +34,7 @@ func (c *InProcessTransport) SendRequest(ctx context.Context, request JSONRPCReq } requestBytes = append(requestBytes, '\n') - respMessage := c.server.HandleMessage(ctx, requestBytes) + respMessage := c.server.HandleMessage(ctx, map[string]string{}, requestBytes) respByte, err := json.Marshal(respMessage) if err != nil { return nil, fmt.Errorf("failed to marshal response message: %w", err) @@ -54,7 +54,7 @@ func (c *InProcessTransport) SendNotification(ctx context.Context, notification return fmt.Errorf("failed to marshal notification: %w", err) } notificationBytes = append(notificationBytes, '\n') - c.server.HandleMessage(ctx, notificationBytes) + c.server.HandleMessage(ctx, map[string]string{}, notificationBytes) return nil } diff --git a/mcp/types.go b/mcp/types.go index 0091d2e4..80013362 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -151,8 +151,9 @@ func (m *Meta) UnmarshalJSON(data []byte) error { } type Request struct { - Method string `json:"method"` - Params RequestParams `json:"params,omitempty"` + Header map[string]string `json:"header"` + Method string `json:"method"` + Params RequestParams `json:"params,omitempty"` } type RequestParams struct { diff --git a/server/internal/gen/request_handler.go.tmpl b/server/internal/gen/request_handler.go.tmpl index 7e4a68a0..2b65748e 100644 --- a/server/internal/gen/request_handler.go.tmpl +++ b/server/internal/gen/request_handler.go.tmpl @@ -14,6 +14,7 @@ import ( // HandleMessage processes an incoming JSON-RPC message and returns an appropriate response func (s *MCPServer) HandleMessage( ctx context.Context, + header map[string]string, message json.RawMessage, ) mcp.JSONRPCMessage { // Add server to context @@ -90,6 +91,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.before{{.HookName}}(ctx, baseMessage.ID, &request) result, err = s.{{.HandlerFunc}}(ctx, baseMessage.ID, request) } diff --git a/server/request_handler.go b/server/request_handler.go index 25f6ef14..7984ad97 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -13,6 +13,7 @@ import ( // HandleMessage processes an incoming JSON-RPC message and returns an appropriate response func (s *MCPServer) HandleMessage( ctx context.Context, + header map[string]string, message json.RawMessage, ) mcp.JSONRPCMessage { // Add server to context @@ -82,6 +83,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) result, err = s.handleInitialize(ctx, baseMessage.ID, request) } @@ -101,6 +103,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.beforePing(ctx, baseMessage.ID, &request) result, err = s.handlePing(ctx, baseMessage.ID, request) } @@ -126,6 +129,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.beforeSetLevel(ctx, baseMessage.ID, &request) result, err = s.handleSetLevel(ctx, baseMessage.ID, request) } @@ -151,6 +155,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.beforeListResources(ctx, baseMessage.ID, &request) result, err = s.handleListResources(ctx, baseMessage.ID, request) } @@ -176,6 +181,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request) } @@ -201,6 +207,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) result, err = s.handleReadResource(ctx, baseMessage.ID, request) } @@ -226,6 +233,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) result, err = s.handleListPrompts(ctx, baseMessage.ID, request) } @@ -251,6 +259,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) result, err = s.handleGetPrompt(ctx, baseMessage.ID, request) } @@ -276,6 +285,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.beforeListTools(ctx, baseMessage.ID, &request) result, err = s.handleListTools(ctx, baseMessage.ID, request) } @@ -301,6 +311,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = header s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) result, err = s.handleToolCall(ctx, baseMessage.ID, request) } diff --git a/server/resource_test.go b/server/resource_test.go index 05a3b279..8fa96107 100644 --- a/server/resource_test.go +++ b/server/resource_test.go @@ -59,7 +59,8 @@ func TestMCPServer_RemoveResource(t *testing.T) { ) // First, verify we have two resources - response := server.HandleMessage(context.Background(), []byte(`{ + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(context.Background(), header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "resources/list" @@ -205,9 +206,9 @@ func TestMCPServer_RemoveResource(t *testing.T) { "1.0.0", WithResourceCapabilities(true, true), ) - + header := map[string]string{"Authorization": "Bearer test"} // Initialize the server - _ = server.HandleMessage(ctx, []byte(`{ + _ = server.HandleMessage(ctx, header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "initialize" @@ -244,7 +245,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { "id": 1, "method": "resources/list" }` - resourcesList := server.HandleMessage(ctx, []byte(listMessage)) + resourcesList := server.HandleMessage(ctx, header, []byte(listMessage)) // Validate the results tt.validate(t, notifications, resourcesList) diff --git a/server/server_test.go b/server/server_test.go index 1c81d18d..cad5533f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -8,6 +8,7 @@ import ( "fmt" "reflect" "sort" + "strings" "testing" "time" @@ -141,8 +142,8 @@ func TestMCPServer_Capabilities(t *testing.T) { } messageBytes, err := json.Marshal(message) assert.NoError(t, err) - - response := server.HandleMessage(context.Background(), messageBytes) + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(context.Background(), header, messageBytes) tt.validate(t, response) }) } @@ -347,7 +348,8 @@ func TestMCPServer_Tools(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) - _ = server.HandleMessage(ctx, []byte(`{ + header := map[string]string{"Authorization": "Bearer test"} + _ = server.HandleMessage(ctx, header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "initialize" @@ -367,7 +369,7 @@ func TestMCPServer_Tools(t *testing.T) { } } assert.Len(t, notifications, tt.expectedNotifications) - toolsList := server.HandleMessage(ctx, []byte(`{ + toolsList := server.HandleMessage(ctx, header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "tools/list" @@ -454,8 +456,8 @@ func TestMCPServer_HandleValidMessages(t *testing.T) { t.Run(tt.name, func(t *testing.T) { messageBytes, err := json.Marshal(tt.message) assert.NoError(t, err) - - response := server.HandleMessage(context.Background(), messageBytes) + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(context.Background(), header, messageBytes) assert.NotNil(t, response) tt.validate(t, response) }) @@ -494,8 +496,10 @@ func TestMCPServer_HandlePagination(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + header := map[string]string{"Authorization": "Bearer test"} response := server.HandleMessage( context.Background(), + header, []byte(tt.message), ) tt.validate(t, response) @@ -518,8 +522,8 @@ func TestMCPServer_HandleNotifications(t *testing.T) { "jsonrpc": "2.0", "method": "notifications/initialized" }` - - response := server.HandleMessage(context.Background(), []byte(message)) + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(context.Background(), header, []byte(message)) assert.Nil(t, response) assert.True(t, notificationReceived) } @@ -598,7 +602,8 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) { t.Run(tt.name, func(t *testing.T) { server := NewMCPServer("test-server", "1.0.0") ctx := tt.contextPrepare(context.Background(), server) - _ = server.HandleMessage(ctx, []byte(`{ + header := map[string]string{"Authorization": "Bearer test"} + _ = server.HandleMessage(ctx, header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "initialize" @@ -678,7 +683,8 @@ func TestMCPServer_SendNotificationToAllClients(t *testing.T) { t.Run("all sessions", func(t *testing.T) { server := NewMCPServer("test-server", "1.0.0") ctx := contextPrepare(context.Background(), server) - _ = server.HandleMessage(ctx, []byte(`{ + header := map[string]string{"Authorization": "Bearer test"} + _ = server.HandleMessage(ctx, header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "initialize" @@ -800,8 +806,10 @@ func TestMCPServer_PromptHandling(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + header := map[string]string{"Authorization": "Bearer test"} response := server.HandleMessage( context.Background(), + header, []byte(tt.message), ) tt.validate(t, response) @@ -963,11 +971,12 @@ func TestMCPServer_Prompts(t *testing.T) { }, }, } + header := map[string]string{"Authorization": "Bearer test"} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) - _ = server.HandleMessage(ctx, []byte(`{ + _ = server.HandleMessage(ctx, header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "initialize" @@ -987,7 +996,7 @@ func TestMCPServer_Prompts(t *testing.T) { } } assert.Len(t, notifications, tt.expectedNotifications) - promptsList := server.HandleMessage(ctx, []byte(`{ + promptsList := server.HandleMessage(ctx, header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "prompts/list" @@ -1052,9 +1061,10 @@ func TestMCPServer_HandleInvalidMessages(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { errs = nil // Reset errors for each test case - + header := map[string]string{"Authorization": "Bearer test"} response := server.HandleMessage( context.Background(), + header, []byte(tt.message), ) assert.NotNil(t, response) @@ -1187,8 +1197,10 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) { t.Run(tt.name, func(t *testing.T) { errs = nil // Reset errors for each test case beforeResults = nil + header := map[string]string{"Authorization": "Bearer test"} response := server.HandleMessage( context.Background(), + header, []byte(tt.message), ) assert.NotNil(t, response) @@ -1278,8 +1290,10 @@ func TestMCPServer_HandleMethodsWithoutCapabilities(t *testing.T) { errs = nil // Reset errors for each test case server := NewMCPServer("test-server", "1.0.0", tt.options...) + header := map[string]string{"Authorization": "Bearer test"} response := server.HandleMessage( context.Background(), + header, []byte(tt.message), ) assert.NotNil(t, response) @@ -1366,8 +1380,8 @@ func TestMCPServer_Instructions(t *testing.T) { } messageBytes, err := json.Marshal(message) assert.NoError(t, err) - - response := server.HandleMessage(context.Background(), messageBytes) + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(context.Background(), header, messageBytes) tt.validate(t, response) }) } @@ -1415,8 +1429,10 @@ func TestMCPServer_ResourceTemplates(t *testing.T) { }` t.Run("Get resource template", func(t *testing.T) { + header := map[string]string{"Authorization": "Bearer test"} response := server.HandleMessage( context.Background(), + header, []byte(listMessage), ) assert.NotNil(t, response) @@ -1436,9 +1452,9 @@ func TestMCPServer_ResourceTemplates(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "test://{a}/test-resource{/b*}", resourceTemplate["uriTemplate"]) - response = server.HandleMessage( context.Background(), + header, []byte(message), ) @@ -1626,16 +1642,16 @@ func TestMCPServer_WithHooks(t *testing.T) { return &mcp.CallToolResult{}, nil }, ) - + header := map[string]string{"Authorization": "Bearer test"} // Initialize the server - _ = server.HandleMessage(context.Background(), []byte(`{ + _ = server.HandleMessage(context.Background(), header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "initialize" }`)) // Test 1: Verify ping method hooks - pingResponse := server.HandleMessage(context.Background(), []byte(`{ + pingResponse := server.HandleMessage(context.Background(), header, []byte(`{ "jsonrpc": "2.0", "id": 2, "method": "ping" @@ -1645,7 +1661,7 @@ func TestMCPServer_WithHooks(t *testing.T) { assert.IsType(t, mcp.JSONRPCResponse{}, pingResponse) // Test 2: Verify tools/list method hooks - toolsListResponse := server.HandleMessage(context.Background(), []byte(`{ + toolsListResponse := server.HandleMessage(context.Background(), header, []byte(`{ "jsonrpc": "2.0", "id": 3, "method": "tools/list" @@ -1655,7 +1671,7 @@ func TestMCPServer_WithHooks(t *testing.T) { assert.IsType(t, mcp.JSONRPCResponse{}, toolsListResponse) // Test 3: Verify error hooks with invalid tool - errorResponse := server.HandleMessage(context.Background(), []byte(`{ + errorResponse := server.HandleMessage(context.Background(), header, []byte(`{ "jsonrpc": "2.0", "id": 4, "method": "tools/call", @@ -1807,8 +1823,8 @@ func TestMCPServer_WithRecover(t *testing.T) { mcp.NewTool("panic-tool"), panicToolHandler, ) - - response := server.HandleMessage(context.Background(), []byte(`{ + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(context.Background(), header, []byte(`{ "jsonrpc": "2.0", "id": 4, "method": "tools/call", @@ -2004,7 +2020,7 @@ func TestMCPServer_ProtocolNegotiation(t *testing.T) { messageBytes, err := json.Marshal(initRequest) assert.NoError(t, err) - response := server.HandleMessage(context.Background(), messageBytes) + response := server.HandleMessage(context.Background(), map[string]string{"Authorization": "Bearer test"}, messageBytes) assert.NotNil(t, response) resp, ok := response.(mcp.JSONRPCResponse) @@ -2022,3 +2038,65 @@ func TestMCPServer_ProtocolNegotiation(t *testing.T) { }) } } + +func TestMCPServer_HandleWithHeader(t *testing.T) { + allowedToolNamesCache := map[string]map[string]struct{}{ + "test": { + "tool1": struct{}{}, + "tool2": struct{}{}, + }, + "test2": { + "tool3": struct{}{}, + "tool4": struct{}{}, + }, + } + myOnAfterListToolsFunc := func(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) { + token := strings.TrimPrefix(message.Header["Authorization"], "Bearer ") + allowedToolNames := allowedToolNamesCache[token] + allowedTools := []mcp.Tool{} + for _, tool := range result.Tools { + if _, ok := allowedToolNames[tool.Name]; ok { + allowedTools = append(allowedTools, tool) + } + } + result.Tools = allowedTools + } + hooks := &Hooks{} + hooks.AddAfterListTools(myOnAfterListToolsFunc) + server := NewMCPServer( + "test-server", + "1.0.0", + WithHooks(hooks), + ) + server.AddTool(mcp.NewTool("tool1"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }) + server.AddTool(mcp.NewTool("tool2"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }) + server.AddTool(mcp.NewTool("tool3"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }) + server.AddTool(mcp.NewTool("tool4"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }) + server.AddTool(mcp.NewTool("tool5"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }) + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(context.Background(), header, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list" + }`)) + assert.IsType(t, mcp.JSONRPCResponse{}, response) + toolsListResponse := response.(mcp.JSONRPCResponse) + assert.IsType(t, mcp.ListToolsResult{}, toolsListResponse.Result) + listTools := toolsListResponse.Result.(mcp.ListToolsResult) + assert.Equal(t, 2, len(listTools.Tools)) +} diff --git a/server/session_test.go b/server/session_test.go index 3067f4e9..9a62deb1 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -273,7 +273,8 @@ func TestMCPServer_ToolsWithSessionTools(t *testing.T) { // List tools with session context sessionCtx := server.WithContext(context.Background(), session) - resp := server.HandleMessage(sessionCtx, []byte(`{ + header := map[string]string{"Authorization": "Bearer test"} + resp := server.HandleMessage(sessionCtx, header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "tools/list" @@ -600,8 +601,8 @@ func TestMCPServer_CallSessionTool(t *testing.T) { if err != nil { t.Fatalf("Failed to marshal tool request: %v", err) } - - response := server.HandleMessage(sessionCtx, requestBytes) + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(sessionCtx, header, requestBytes) resp, ok := response.(mcp.JSONRPCResponse) assert.True(t, ok) @@ -705,7 +706,8 @@ func TestMCPServer_ToolFiltering(t *testing.T) { // List tools with session context sessionCtx := server.WithContext(context.Background(), session) - response := server.HandleMessage(sessionCtx, []byte(`{ + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(sessionCtx, header, []byte(`{ "jsonrpc": "2.0", "id": 1, "method": "tools/list" @@ -1024,8 +1026,8 @@ func TestMCPServer_SetLevelNotEnabled(t *testing.T) { } requestBytes, err := json.Marshal(setRequest) require.NoError(t, err) - - response := server.HandleMessage(sessionCtx, requestBytes) + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(sessionCtx, header, requestBytes) errorResponse, ok := response.(mcp.JSONRPCError) assert.True(t, ok) @@ -1068,8 +1070,8 @@ func TestMCPServer_SetLevel(t *testing.T) { if err != nil { t.Fatalf("Failed to marshal tool request: %v", err) } - - response := server.HandleMessage(sessionCtx, requestBytes) + header := map[string]string{"Authorization": "Bearer test"} + response := server.HandleMessage(sessionCtx, header, requestBytes) resp, ok := response.(mcp.JSONRPCResponse) assert.True(t, ok) diff --git a/server/sse.go b/server/sse.go index 41699573..2bc62b85 100644 --- a/server/sse.go +++ b/server/sse.go @@ -495,7 +495,10 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error") return } - + header := make(map[string]string) + for k, v := range r.Header { + header[k] = v[0] + } // Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled. // this is required because the http ctx will be canceled when the client disconnects detachedCtx := context.WithoutCancel(ctx) @@ -510,7 +513,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { defer cancel() // Use the context that will be canceled when session is done // Process message through MCPServer - response := s.server.HandleMessage(ctx, rawMessage) + response := s.server.HandleMessage(ctx, header, rawMessage) // Only send response if there is one (not for notifications) if response != nil { var message string diff --git a/server/stdio.go b/server/stdio.go index 746a7d96..9764756d 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -257,7 +257,7 @@ func (s *StdioServer) processMessage( } // Handle the message using the wrapped server - response := s.server.HandleMessage(ctx, rawMessage) + response := s.server.HandleMessage(ctx, map[string]string{}, rawMessage) // Only write response if there is one (not for notifications) if response != nil { diff --git a/server/streamable_http.go b/server/streamable_http.go index 1312c975..b0fe3d0b 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -309,9 +309,12 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } } }() - + header := make(map[string]string) + for k, v := range r.Header { + header[k] = v[0] + } // Process message through MCPServer - response := s.server.HandleMessage(ctx, rawData) + response := s.server.HandleMessage(ctx, header, rawData) if response == nil { // For notifications, just send 202 Accepted with no body w.WriteHeader(http.StatusAccepted)