From 9281b4a6ec7b704554bc9ae26d5c5ba1d9657194 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sun, 20 Apr 2025 15:06:18 +0300 Subject: [PATCH 01/13] Add session specific methods for tools --- server/errors.go | 23 +++ server/server.go | 214 +++++++++------------- server/session.go | 217 ++++++++++++++++++++++ server/session_test.go | 401 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 728 insertions(+), 127 deletions(-) create mode 100644 server/errors.go create mode 100644 server/session.go create mode 100644 server/session_test.go diff --git a/server/errors.go b/server/errors.go new file mode 100644 index 00000000..8f892ac2 --- /dev/null +++ b/server/errors.go @@ -0,0 +1,23 @@ +package server + +import ( + "errors" +) + +var ( + // Common server errors + ErrUnsupported = errors.New("not supported") + ErrResourceNotFound = errors.New("resource not found") + ErrPromptNotFound = errors.New("prompt not found") + ErrToolNotFound = errors.New("tool not found") + + // Session-related errors + ErrSessionNotFound = errors.New("session not found") + ErrSessionExists = errors.New("session already exists") + ErrSessionNotInitialized = errors.New("session not properly initialized") + ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") + + // Notification-related errors + ErrNotificationNotInitialized = errors.New("notification channel not initialized") + ErrNotificationChannelBlocked = errors.New("notification channel full or blocked") +) \ No newline at end of file diff --git a/server/server.go b/server/server.go index d121f9b3..3838f810 100644 --- a/server/server.go +++ b/server/server.go @@ -5,7 +5,6 @@ import ( "context" "encoding/base64" "encoding/json" - "errors" "fmt" "reflect" "sort" @@ -44,31 +43,22 @@ type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mc // ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc +// ToolFilterFunc is a function that filters tools based on context, typically using session information. +type ToolFilterFunc func(ctx context.Context, tools []mcp.Tool) []mcp.Tool + // ServerTool combines a Tool with its ToolHandlerFunc. type ServerTool struct { Tool mcp.Tool Handler ToolHandlerFunc } -// ClientSession represents an active session that can be used by MCPServer to interact with client. -type ClientSession interface { - // Initialize marks session as fully initialized and ready for notifications - Initialize() - // Initialized returns if session is ready to accept notifications - Initialized() bool - // NotificationChannel provides a channel suitable for sending notifications to client. - NotificationChannel() chan<- mcp.JSONRPCNotification - // SessionID is a unique identifier used to track user session. - SessionID() string -} - -// clientSessionKey is the context key for storing current client notification channel. -type clientSessionKey struct{} +// serverKey is the context key for storing the server instance +type serverKey struct{} -// ClientSessionFromContext retrieves current client notification context from context. -func ClientSessionFromContext(ctx context.Context) ClientSession { - if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { - return session +// ServerFromContext retrieves the MCPServer instance from a context +func ServerFromContext(ctx context.Context) *MCPServer { + if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { + return srv } return nil } @@ -128,13 +118,6 @@ func (e *requestError) Unwrap() error { return e.err } -var ( - ErrUnsupported = errors.New("not supported") - ErrResourceNotFound = errors.New("resource not found") - ErrPromptNotFound = errors.New("prompt not found") - ErrToolNotFound = errors.New("tool not found") -) - // NotificationHandlerFunc handles incoming notifications. type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCNotification) @@ -148,6 +131,7 @@ type MCPServer struct { middlewareMu sync.RWMutex notificationHandlersMu sync.RWMutex capabilitiesMu sync.RWMutex + toolFiltersMu sync.RWMutex name string version string @@ -158,6 +142,7 @@ type MCPServer struct { promptHandlers map[string]PromptHandlerFunc tools map[string]ServerTool toolHandlerMiddlewares []ToolHandlerMiddleware + toolFilters []ToolFilterFunc notificationHandlers map[string]NotificationHandlerFunc capabilities serverCapabilities paginationLimit *int @@ -165,17 +150,6 @@ type MCPServer struct { hooks *Hooks } -// serverKey is the context key for storing the server instance -type serverKey struct{} - -// ServerFromContext retrieves the MCPServer instance from a context -func ServerFromContext(ctx context.Context) *MCPServer { - if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { - return srv - } - return nil -} - // WithPaginationLimit sets the pagination limit for the server. func WithPaginationLimit(limit int) ServerOption { return func(s *MCPServer) { @@ -183,92 +157,6 @@ func WithPaginationLimit(limit int) ServerOption { } } -// WithContext sets the current client session and returns the provided context -func (s *MCPServer) WithContext( - ctx context.Context, - session ClientSession, -) context.Context { - return context.WithValue(ctx, clientSessionKey{}, session) -} - -// RegisterSession saves session that should be notified in case if some server attributes changed. -func (s *MCPServer) RegisterSession( - ctx context.Context, - session ClientSession, -) error { - sessionID := session.SessionID() - if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { - return fmt.Errorf("session %s is already registered", sessionID) - } - s.hooks.RegisterSession(ctx, session) - return nil -} - -// UnregisterSession removes from storage session that is shut down. -func (s *MCPServer) UnregisterSession( - ctx context.Context, - sessionID string, -) { - session, _ := s.sessions.LoadAndDelete(sessionID) - s.hooks.UnregisterSession(ctx, session.(ClientSession)) -} - -// SendNotificationToAllClients sends a notification to all the currently active clients. -func (s *MCPServer) SendNotificationToAllClients( - method string, - params map[string]any, -) { - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: method, - Params: mcp.NotificationParams{ - AdditionalFields: params, - }, - }, - } - - s.sessions.Range(func(k, v any) bool { - if session, ok := v.(ClientSession); ok && session.Initialized() { - select { - case session.NotificationChannel() <- notification: - default: - // TODO: log blocked channel in the future versions - } - } - return true - }) -} - -// SendNotificationToClient sends a notification to the current client -func (s *MCPServer) SendNotificationToClient( - ctx context.Context, - method string, - params map[string]any, -) error { - session := ClientSessionFromContext(ctx) - if session == nil || !session.Initialized() { - return fmt.Errorf("notification channel not initialized") - } - - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: method, - Params: mcp.NotificationParams{ - AdditionalFields: params, - }, - }, - } - - select { - case session.NotificationChannel() <- notification: - return nil - default: - return fmt.Errorf("notification channel full or blocked") - } -} - // serverCapabilities defines the supported features of the MCP server type serverCapabilities struct { tools *toolCapabilities @@ -316,6 +204,17 @@ func WithToolHandlerMiddleware( } } +// WithToolFilter adds a filter function that will be applied to tools before they are returned in list_tools +func WithToolFilter( + toolFilter ToolFilterFunc, +) ServerOption { + return func(s *MCPServer) { + s.toolFiltersMu.Lock() + s.toolFilters = append(s.toolFilters, toolFilter) + s.toolFiltersMu.Unlock() + } +} + // WithRecovery adds a middleware that recovers from panics in tool handlers. func WithRecovery() ServerOption { return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc { @@ -814,6 +713,7 @@ func (s *MCPServer) handleListTools( id interface{}, request mcp.ListToolsRequest, ) (*mcp.ListToolsResult, *requestError) { + // Get the base tools from the server s.toolsMu.RLock() tools := make([]mcp.Tool, 0, len(s.tools)) @@ -832,6 +732,49 @@ func (s *MCPServer) handleListTools( } s.toolsMu.RUnlock() + // Check if there are session-specific tools + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTools, ok := session.(SessionWithTools); ok { + if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { + // Override or add session-specific tools + // We need to create a map first to merge the tools properly + toolMap := make(map[string]mcp.Tool) + + // Add global tools first + for _, tool := range tools { + toolMap[tool.Name] = tool + } + + // Then override with session-specific tools + for name, serverTool := range sessionTools { + toolMap[name] = serverTool.Tool + } + + // Convert back to slice + tools = make([]mcp.Tool, 0, len(toolMap)) + for _, tool := range toolMap { + tools = append(tools, tool) + } + + // Sort again to maintain consistent ordering + sort.Slice(tools, func(i, j int) bool { + return tools[i].Name < tools[j].Name + }) + } + } + } + + // Apply tool filters if any are defined + s.toolFiltersMu.RLock() + if len(s.toolFilters) > 0 { + for _, filter := range s.toolFilters { + tools = filter(ctx, tools) + } + } + s.toolFiltersMu.RUnlock() + + // Apply pagination toolsToReturn, nextCursor, err := listByPagination[mcp.Tool](ctx, s, request.Params.Cursor, tools) if err != nil { return nil, &requestError{ @@ -840,6 +783,7 @@ func (s *MCPServer) handleListTools( err: err, } } + result := mcp.ListToolsResult{ Tools: toolsToReturn, PaginatedResult: mcp.PaginatedResult{ @@ -854,9 +798,25 @@ func (s *MCPServer) handleToolCall( id interface{}, request mcp.CallToolRequest, ) (*mcp.CallToolResult, *requestError) { - s.toolsMu.RLock() - tool, ok := s.tools[request.Params.Name] - s.toolsMu.RUnlock() + // First check session-specific tools + var tool ServerTool + var ok bool + + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTools, ok := session.(SessionWithTools); ok { + if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { + tool, ok = sessionTools[request.Params.Name] + } + } + } + + // If not found in session tools, check global tools + if !ok { + s.toolsMu.RLock() + tool, ok = s.tools[request.Params.Name] + s.toolsMu.RUnlock() + } if !ok { return nil, &requestError{ @@ -928,4 +888,4 @@ func createErrorResponse( Message: message, }, } -} +} \ No newline at end of file diff --git a/server/session.go b/server/session.go new file mode 100644 index 00000000..993e1319 --- /dev/null +++ b/server/session.go @@ -0,0 +1,217 @@ +package server + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// ClientSession represents an active session that can be used by MCPServer to interact with client. +type ClientSession interface { + // Initialize marks session as fully initialized and ready for notifications + Initialize() + // Initialized returns if session is ready to accept notifications + Initialized() bool + // NotificationChannel provides a channel suitable for sending notifications to client. + NotificationChannel() chan<- mcp.JSONRPCNotification + // SessionID is a unique identifier used to track user session. + SessionID() string +} + +// SessionWithTools is an extension of ClientSession that can store session-specific tool data +type SessionWithTools interface { + ClientSession + // GetSessionTools returns the tools specific to this session, if any + GetSessionTools() map[string]ServerTool + // SetSessionTools sets tools specific to this session + SetSessionTools(tools map[string]ServerTool) +} + +// clientSessionKey is the context key for storing current client notification channel. +type clientSessionKey struct{} + +// ClientSessionFromContext retrieves current client notification context from context. +func ClientSessionFromContext(ctx context.Context) ClientSession { + if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { + return session + } + return nil +} + +// WithContext sets the current client session and returns the provided context +func (s *MCPServer) WithContext( + ctx context.Context, + session ClientSession, +) context.Context { + return context.WithValue(ctx, clientSessionKey{}, session) +} + +// RegisterSession saves session that should be notified in case if some server attributes changed. +func (s *MCPServer) RegisterSession( + ctx context.Context, + session ClientSession, +) error { + sessionID := session.SessionID() + if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { + return ErrSessionExists + } + s.hooks.RegisterSession(ctx, session) + return nil +} + +// UnregisterSession removes from storage session that is shut down. +func (s *MCPServer) UnregisterSession( + ctx context.Context, + sessionID string, +) { + session, _ := s.sessions.LoadAndDelete(sessionID) + s.hooks.UnregisterSession(ctx, session.(ClientSession)) +} + +// SendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) SendNotificationToAllClients( + method string, + params map[string]any, +) { + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + s.sessions.Range(func(k, v any) bool { + if session, ok := v.(ClientSession); ok && session.Initialized() { + select { + case session.NotificationChannel() <- notification: + default: + // TODO: log blocked channel in the future versions + } + } + return true + }) +} + +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) SendNotificationToClient( + ctx context.Context, + method string, + params map[string]any, +) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + select { + case session.NotificationChannel() <- notification: + return nil + default: + return ErrNotificationChannelBlocked + } +} + +// SendNotificationToSpecificClient sends a notification to a specific client by session ID +func (s *MCPServer) SendNotificationToSpecificClient( + sessionID string, + method string, + params map[string]any, +) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(ClientSession) + if !ok || !session.Initialized() { + return ErrSessionNotInitialized + } + + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + select { + case session.NotificationChannel() <- notification: + return nil + default: + return ErrNotificationChannelBlocked + } +} + +// AddSessionTools adds tools for a specific session +func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithTools) + if !ok { + return ErrSessionDoesNotSupportTools + } + + sessionTools := session.GetSessionTools() + if sessionTools == nil { + sessionTools = make(map[string]ServerTool) + } + + for _, tool := range tools { + sessionTools[tool.Tool.Name] = tool + } + + session.SetSessionTools(sessionTools) + + // Send notification only to this session + s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) + + return nil +} + +// DeleteSessionTools removes tools from a specific session +func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithTools) + if !ok { + return ErrSessionDoesNotSupportTools + } + + sessionTools := session.GetSessionTools() + if sessionTools == nil { + return nil + } + + for _, name := range names { + delete(sessionTools, name) + } + + session.SetSessionTools(sessionTools) + + // Send notification only to this session + s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) + + return nil +} \ No newline at end of file diff --git a/server/session_test.go b/server/session_test.go new file mode 100644 index 00000000..e79ce995 --- /dev/null +++ b/server/session_test.go @@ -0,0 +1,401 @@ +package server + +import ( + "context" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// sessionTestClient implements the basic ClientSession interface for testing +type sessionTestClient struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized bool +} + +func (f sessionTestClient) SessionID() string { + return f.sessionID +} + +func (f sessionTestClient) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +func (f sessionTestClient) Initialize() { +} + +func (f sessionTestClient) Initialized() bool { + return f.initialized +} + +// fakeSessionWithTools implements the SessionWithTools interface for testing +type fakeSessionWithTools struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized bool + sessionTools map[string]ServerTool +} + +func (f *fakeSessionWithTools) SessionID() string { + return f.sessionID +} + +func (f *fakeSessionWithTools) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +func (f *fakeSessionWithTools) Initialize() { + f.initialized = true +} + +func (f *fakeSessionWithTools) Initialized() bool { + return f.initialized +} + +func (f *fakeSessionWithTools) GetSessionTools() map[string]ServerTool { + return f.sessionTools +} + +func (f *fakeSessionWithTools) SetSessionTools(tools map[string]ServerTool) { + f.sessionTools = tools +} + +// Verify that both implementations satisfy their respective interfaces +var _ ClientSession = sessionTestClient{} +var _ SessionWithTools = &fakeSessionWithTools{} + +func TestSessionWithTools_Integration(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + + // Create session-specific tools + sessionTool := ServerTool{ + Tool: mcp.NewTool("session-tool"), + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("session-tool result"), nil + }, + } + + // Create a session with tools + session := &fakeSessionWithTools{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + sessionTools: map[string]ServerTool{ + "session-tool": sessionTool, + }, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Test that we can access the session-specific tool + testReq := mcp.CallToolRequest{} + testReq.Params.Name = "session-tool" + testReq.Params.Arguments = map[string]interface{}{} + + // Call using session context + sessionCtx := server.WithContext(context.Background(), session) + + // Check if the session was stored in the context correctly + s := ClientSessionFromContext(sessionCtx) + require.NotNil(t, s, "Session should be available from context") + assert.Equal(t, session.SessionID(), s.SessionID(), "Session ID should match") + + // Check if the session can be cast to SessionWithTools + swt, ok := s.(SessionWithTools) + require.True(t, ok, "Session should implement SessionWithTools") + + // Check if the tools are accessible + tools := swt.GetSessionTools() + require.NotNil(t, tools, "Session tools should be available") + require.Contains(t, tools, "session-tool", "Session should have session-tool") + + // Test session tool access with session context + t.Run("test session tool access", func(t *testing.T) { + // First test directly getting the tool from session tools + tool, exists := tools["session-tool"] + require.True(t, exists, "Session tool should exist in the map") + require.NotNil(t, tool, "Session tool should not be nil") + + // Now test calling directly with the handler + result, err := tool.Handler(sessionCtx, testReq) + require.NoError(t, err, "No error calling session tool handler directly") + require.NotNil(t, result, "Result should not be nil") + require.Len(t, result.Content, 1, "Result should have one content item") + + textContent, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok, "Content should be TextContent") + assert.Equal(t, "session-tool result", textContent.Text, "Result text should match") + }) +} + +func TestMCPServer_ToolsWithSessionTools(t *testing.T) { + // Basic test to verify that session-specific tools are returned correctly in a tools list + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + + // Add global tools + server.AddTools( + ServerTool{Tool: mcp.NewTool("global-tool-1")}, + ServerTool{Tool: mcp.NewTool("global-tool-2")}, + ) + + // Create a session with tools + session := &fakeSessionWithTools{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + sessionTools: map[string]ServerTool{ + "session-tool-1": {Tool: mcp.NewTool("session-tool-1")}, + "global-tool-1": {Tool: mcp.NewTool("global-tool-1", mcp.WithDescription("Overridden"))}, + }, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // List tools with session context + sessionCtx := server.WithContext(context.Background(), session) + resp := server.HandleMessage(sessionCtx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list" + }`)) + + jsonResp, ok := resp.(mcp.JSONRPCResponse) + require.True(t, ok, "Response should be a JSONRPCResponse") + + result, ok := jsonResp.Result.(mcp.ListToolsResult) + require.True(t, ok, "Result should be a ListToolsResult") + + // Should have 3 tools - 2 global tools (one overridden) and 1 session-specific tool + assert.Len(t, result.Tools, 3, "Should have 3 tools") + + // Find the overridden tool and verify its description + var found bool + for _, tool := range result.Tools { + if tool.Name == "global-tool-1" { + assert.Equal(t, "Overridden", tool.Description, "Global tool should be overridden") + found = true + break + } + } + assert.True(t, found, "Should find the overridden global tool") +} + +func TestMCPServer_AddSessionTools(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + ctx := context.Background() + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &fakeSessionWithTools{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add session-specific tools + err = server.AddSessionTools(session.SessionID(), + ServerTool{Tool: mcp.NewTool("session-tool")}, + ) + require.NoError(t, err) + + // Check that notification was sent + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/tools/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + // Verify tool was added to session + assert.Len(t, session.GetSessionTools(), 1) + assert.Contains(t, session.GetSessionTools(), "session-tool") +} + +func TestMCPServer_DeleteSessionTools(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + ctx := context.Background() + + // Create a session with tools + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &fakeSessionWithTools{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + sessionTools: map[string]ServerTool{ + "session-tool-1": { + Tool: mcp.NewTool("session-tool-1"), + }, + "session-tool-2": { + Tool: mcp.NewTool("session-tool-2"), + }, + }, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Delete one of the session tools + err = server.DeleteSessionTools(session.SessionID(), "session-tool-1") + require.NoError(t, err) + + // Check that notification was sent + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/tools/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + // Verify tool was removed from session + assert.Len(t, session.GetSessionTools(), 1) + assert.NotContains(t, session.GetSessionTools(), "session-tool-1") + assert.Contains(t, session.GetSessionTools(), "session-tool-2") +} + +func TestMCPServer_ToolFiltering(t *testing.T) { + // Create a filter that filters tools by prefix + filterByPrefix := func(prefix string) ToolFilterFunc { + return func(ctx context.Context, tools []mcp.Tool) []mcp.Tool { + var filtered []mcp.Tool + for _, tool := range tools { + if len(tool.Name) >= len(prefix) && tool.Name[:len(prefix)] == prefix { + filtered = append(filtered, tool) + } + } + return filtered + } + } + + // Create a server with a tool filter + server := NewMCPServer("test-server", "1.0.0", + WithToolCapabilities(true), + WithToolFilter(filterByPrefix("allow-")), + ) + + // Add tools with different prefixes + server.AddTools( + ServerTool{Tool: mcp.NewTool("allow-tool-1")}, + ServerTool{Tool: mcp.NewTool("allow-tool-2")}, + ServerTool{Tool: mcp.NewTool("deny-tool-1")}, + ServerTool{Tool: mcp.NewTool("deny-tool-2")}, + ) + + // Create a session with tools + session := &fakeSessionWithTools{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + sessionTools: map[string]ServerTool{ + "allow-session-tool": { + Tool: mcp.NewTool("allow-session-tool"), + }, + "deny-session-tool": { + Tool: mcp.NewTool("deny-session-tool"), + }, + }, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // List tools with session context + sessionCtx := server.WithContext(context.Background(), session) + response := server.HandleMessage(sessionCtx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list" + }`)) + resp, ok := response.(mcp.JSONRPCResponse) + require.True(t, ok) + + result, ok := resp.Result.(mcp.ListToolsResult) + require.True(t, ok) + + // Should only include tools with the "allow-" prefix + assert.Len(t, result.Tools, 3) + + // Verify all tools start with "allow-" + for _, tool := range result.Tools { + assert.True(t, len(tool.Name) >= 6 && tool.Name[:6] == "allow-", + "Tool should start with 'allow-', got: %s", tool.Name) + } +} + +func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + session1Chan := make(chan mcp.JSONRPCNotification, 10) + session1 := &fakeSession{ + sessionID: "session-1", + notificationChannel: session1Chan, + initialized: true, + } + + session2Chan := make(chan mcp.JSONRPCNotification, 10) + session2 := &fakeSession{ + sessionID: "session-2", + notificationChannel: session2Chan, + initialized: true, + } + + session3 := &fakeSession{ + sessionID: "session-3", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: false, // Not initialized + } + + // Register sessions + err := server.RegisterSession(context.Background(), session1) + require.NoError(t, err) + err = server.RegisterSession(context.Background(), session2) + require.NoError(t, err) + err = server.RegisterSession(context.Background(), session3) + require.NoError(t, err) + + // Send notification to session 1 + err = server.SendNotificationToSpecificClient(session1.SessionID(), "test-method", map[string]any{ + "data": "test-data", + }) + require.NoError(t, err) + + // Check that only session 1 received the notification + select { + case notification := <-session1Chan: + assert.Equal(t, "test-method", notification.Method) + assert.Equal(t, "test-data", notification.Params.AdditionalFields["data"]) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received by session 1") + } + + // Verify session 2 did not receive notification + select { + case notification := <-session2Chan: + t.Errorf("Unexpected notification received by session 2: %v", notification) + case <-time.After(100 * time.Millisecond): + // Expected, no notification for session 2 + } + + // Test sending to non-existent session + err = server.SendNotificationToSpecificClient("non-existent", "test-method", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + + // Test sending to uninitialized session + err = server.SendNotificationToSpecificClient(session3.SessionID(), "test-method", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not properly initialized") +} \ No newline at end of file From 08ff4145fdc82671d9ad4d0f03402f0c0119d1da Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sun, 20 Apr 2025 15:42:37 +0300 Subject: [PATCH 02/13] Add session managment docs to README --- README.md | 196 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 196 insertions(+) diff --git a/README.md b/README.md index 332ab3dd..2429ee72 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,10 @@ MCP Go handles all the complex protocol details and server management, so you ca - [Tools](#tools) - [Prompts](#prompts) - [Examples](#examples) +- [Extras](#extras) + - [Session Management](#session-management) + - [Request Hooks](#request-hooks) + - [Tool Handler Middleware](#tool-handler-middleware) - [Contributing](#contributing) - [Prerequisites](#prerequisites) - [Installation](#installation-1) @@ -516,6 +520,198 @@ For examples, see the `examples/` directory. ## Extras +### Session Management + +MCP-Go provides a robust session management system that allows you to: +- Maintain separate state for each connected client +- Register and track client sessions +- Send notifications to specific clients +- Provide per-session tool customization + +
+Show Session Management Examples + +#### Basic Session Handling + +```go +// Create a server with session capabilities +s := server.NewMCPServer( + "Session Demo", + "1.0.0", + server.WithToolCapabilities(true), +) + +// Implement your own ClientSession +type MySession struct { + id string + notifChannel chan mcp.JSONRPCNotification + isInitialized bool + // Add custom fields for your application +} + +// Implement the ClientSession interface +func (s *MySession) SessionID() string { + return s.id +} + +func (s *MySession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notifChannel +} + +func (s *MySession) Initialize() { + s.isInitialized = true +} + +func (s *MySession) Initialized() bool { + return s.isInitialized +} + +// Register a session +session := &MySession{ + id: "user-123", + notifChannel: make(chan mcp.JSONRPCNotification, 10), +} +if err := s.RegisterSession(context.Background(), session); err != nil { + log.Printf("Failed to register session: %v", err) +} + +// Send notification to a specific client +err := s.SendNotificationToSpecificClient( + session.SessionID(), + "notification/update", + map[string]any{"message": "New data available!"}, +) +if err != nil { + log.Printf("Failed to send notification: %v", err) +} + +// Unregister session when done +s.UnregisterSession(context.Background(), session.SessionID()) +``` + +#### Per-Session Tools + +For more advanced use cases, you can implement the `SessionWithTools` interface to support per-session tool customization: + +```go +// Implement SessionWithTools interface for per-session tools +type MyAdvancedSession struct { + MySession // Embed the basic session + sessionTools map[string]server.ServerTool +} + +// Implement additional methods for SessionWithTools +func (s *MyAdvancedSession) GetSessionTools() map[string]server.ServerTool { + return s.sessionTools +} + +func (s *MyAdvancedSession) SetSessionTools(tools map[string]server.ServerTool) { + s.sessionTools = tools +} + +// Create and register a session with tools support +advSession := &MyAdvancedSession{ + MySession: MySession{ + id: "user-456", + notifChannel: make(chan mcp.JSONRPCNotification, 10), + }, + sessionTools: make(map[string]server.ServerTool), +} +if err := s.RegisterSession(context.Background(), advSession); err != nil { + log.Printf("Failed to register session: %v", err) +} + +// Add session-specific tools +userSpecificTool := mcp.NewTool( + "user_data", + mcp.WithDescription("Access user-specific data"), +) +err := s.AddSessionTools( + advSession.SessionID(), + server.ServerTool{ + Tool: userSpecificTool, + Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // This handler is only available to this specific session + return mcp.NewToolResultText("User-specific data for " + advSession.SessionID()), nil + }, + }, +) +if err != nil { + log.Printf("Failed to add session tool: %v", err) +} + +// Delete session-specific tools when no longer needed +err = s.DeleteSessionTools(advSession.SessionID(), "user_data") +if err != nil { + log.Printf("Failed to delete session tool: %v", err) +} +``` + +#### Tool Filtering + +You can also apply filters to control which tools are available to certain sessions: + +```go +// Add a tool filter that only shows tools with certain prefixes +s := server.NewMCPServer( + "Tool Filtering Demo", + "1.0.0", + server.WithToolCapabilities(true), + server.WithToolFilter(func(ctx context.Context, tools []mcp.Tool) []mcp.Tool { + // Get session from context + session := server.ClientSessionFromContext(ctx) + if session == nil { + return tools // Return all tools if no session + } + + // Example: filter tools based on session ID prefix + if strings.HasPrefix(session.SessionID(), "admin-") { + // Admin users get all tools + return tools + } else { + // Regular users only get tools with "public-" prefix + var filteredTools []mcp.Tool + for _, tool := range tools { + if strings.HasPrefix(tool.Name, "public-") { + filteredTools = append(filteredTools, tool) + } + } + return filteredTools + } + }), +) +``` + +#### Working with Context + +The session context is automatically passed to tool and resource handlers: + +```go +s.AddTool(mcp.NewTool("session_aware"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Get the current session from context + session := server.ClientSessionFromContext(ctx) + if session == nil { + return mcp.NewToolResultError("No active session"), nil + } + + return mcp.NewToolResultText("Hello, session " + session.SessionID()), nil +}) + +// When using handlers in HTTP/SSE servers, you need to pass the context with the session +httpHandler := func(w http.ResponseWriter, r *http.Request) { + // Get session from somewhere (like a cookie or header) + session := getSessionFromRequest(r) + + // Add session to context + ctx := s.WithContext(r.Context(), session) + + // Use this context when handling requests + // ... +} +``` + +
+ ### Request Hooks Hook into the request lifecycle by creating a `Hooks` object with your From b6db1849cd86e2eeaaa123e34bb57b31c89c7f44 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sun, 20 Apr 2025 15:50:05 +0300 Subject: [PATCH 03/13] Formatting --- server/errors.go | 14 ++++++------ server/server.go | 16 ++++++------- server/session.go | 6 ++--- server/session_test.go | 52 +++++++++++++++++++++--------------------- 4 files changed, 44 insertions(+), 44 deletions(-) diff --git a/server/errors.go b/server/errors.go index 8f892ac2..7ced5cf7 100644 --- a/server/errors.go +++ b/server/errors.go @@ -6,18 +6,18 @@ import ( var ( // Common server errors - ErrUnsupported = errors.New("not supported") - ErrResourceNotFound = errors.New("resource not found") - ErrPromptNotFound = errors.New("prompt not found") - ErrToolNotFound = errors.New("tool not found") - + ErrUnsupported = errors.New("not supported") + ErrResourceNotFound = errors.New("resource not found") + ErrPromptNotFound = errors.New("prompt not found") + ErrToolNotFound = errors.New("tool not found") + // Session-related errors ErrSessionNotFound = errors.New("session not found") ErrSessionExists = errors.New("session already exists") ErrSessionNotInitialized = errors.New("session not properly initialized") ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") - + // Notification-related errors ErrNotificationNotInitialized = errors.New("notification channel not initialized") ErrNotificationChannelBlocked = errors.New("notification channel full or blocked") -) \ No newline at end of file +) diff --git a/server/server.go b/server/server.go index 3838f810..c9e9a006 100644 --- a/server/server.go +++ b/server/server.go @@ -740,23 +740,23 @@ func (s *MCPServer) handleListTools( // Override or add session-specific tools // We need to create a map first to merge the tools properly toolMap := make(map[string]mcp.Tool) - + // Add global tools first for _, tool := range tools { toolMap[tool.Name] = tool } - + // Then override with session-specific tools for name, serverTool := range sessionTools { toolMap[name] = serverTool.Tool } - + // Convert back to slice tools = make([]mcp.Tool, 0, len(toolMap)) for _, tool := range toolMap { tools = append(tools, tool) } - + // Sort again to maintain consistent ordering sort.Slice(tools, func(i, j int) bool { return tools[i].Name < tools[j].Name @@ -783,7 +783,7 @@ func (s *MCPServer) handleListTools( err: err, } } - + result := mcp.ListToolsResult{ Tools: toolsToReturn, PaginatedResult: mcp.PaginatedResult{ @@ -801,7 +801,7 @@ func (s *MCPServer) handleToolCall( // First check session-specific tools var tool ServerTool var ok bool - + session := ClientSessionFromContext(ctx) if session != nil { if sessionWithTools, ok := session.(SessionWithTools); ok { @@ -810,7 +810,7 @@ func (s *MCPServer) handleToolCall( } } } - + // If not found in session tools, check global tools if !ok { s.toolsMu.RLock() @@ -888,4 +888,4 @@ func createErrorResponse( Message: message, }, } -} \ No newline at end of file +} diff --git a/server/session.go b/server/session.go index 993e1319..4e8545d7 100644 --- a/server/session.go +++ b/server/session.go @@ -183,7 +183,7 @@ func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error // Send notification only to this session s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) - + return nil } @@ -212,6 +212,6 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error // Send notification only to this session s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) - + return nil -} \ No newline at end of file +} diff --git a/server/session_test.go b/server/session_test.go index e79ce995..4181fa11 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -109,7 +109,7 @@ func TestSessionWithTools_Integration(t *testing.T) { // Check if the session can be cast to SessionWithTools swt, ok := s.(SessionWithTools) require.True(t, ok, "Session should implement SessionWithTools") - + // Check if the tools are accessible tools := swt.GetSessionTools() require.NotNil(t, tools, "Session tools should be available") @@ -121,13 +121,13 @@ func TestSessionWithTools_Integration(t *testing.T) { tool, exists := tools["session-tool"] require.True(t, exists, "Session tool should exist in the map") require.NotNil(t, tool, "Session tool should not be nil") - + // Now test calling directly with the handler result, err := tool.Handler(sessionCtx, testReq) require.NoError(t, err, "No error calling session tool handler directly") require.NotNil(t, result, "Result should not be nil") require.Len(t, result.Content, 1, "Result should have one content item") - + textContent, ok := result.Content[0].(mcp.TextContent) require.True(t, ok, "Content should be TextContent") assert.Equal(t, "session-tool result", textContent.Text, "Result text should match") @@ -137,13 +137,13 @@ func TestSessionWithTools_Integration(t *testing.T) { func TestMCPServer_ToolsWithSessionTools(t *testing.T) { // Basic test to verify that session-specific tools are returned correctly in a tools list server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) - + // Add global tools server.AddTools( ServerTool{Tool: mcp.NewTool("global-tool-1")}, ServerTool{Tool: mcp.NewTool("global-tool-2")}, ) - + // Create a session with tools session := &fakeSessionWithTools{ sessionID: "session-1", @@ -154,11 +154,11 @@ func TestMCPServer_ToolsWithSessionTools(t *testing.T) { "global-tool-1": {Tool: mcp.NewTool("global-tool-1", mcp.WithDescription("Overridden"))}, }, } - + // Register the session err := server.RegisterSession(context.Background(), session) require.NoError(t, err) - + // List tools with session context sessionCtx := server.WithContext(context.Background(), session) resp := server.HandleMessage(sessionCtx, []byte(`{ @@ -166,16 +166,16 @@ func TestMCPServer_ToolsWithSessionTools(t *testing.T) { "id": 1, "method": "tools/list" }`)) - + jsonResp, ok := resp.(mcp.JSONRPCResponse) require.True(t, ok, "Response should be a JSONRPCResponse") - + result, ok := jsonResp.Result.(mcp.ListToolsResult) require.True(t, ok, "Result should be a ListToolsResult") - + // Should have 3 tools - 2 global tools (one overridden) and 1 session-specific tool assert.Len(t, result.Tools, 3, "Should have 3 tools") - + // Find the overridden tool and verify its description var found bool for _, tool := range result.Tools { @@ -205,7 +205,7 @@ func TestMCPServer_AddSessionTools(t *testing.T) { require.NoError(t, err) // Add session-specific tools - err = server.AddSessionTools(session.SessionID(), + err = server.AddSessionTools(session.SessionID(), ServerTool{Tool: mcp.NewTool("session-tool")}, ) require.NoError(t, err) @@ -321,43 +321,43 @@ func TestMCPServer_ToolFiltering(t *testing.T) { }`)) resp, ok := response.(mcp.JSONRPCResponse) require.True(t, ok) - + result, ok := resp.Result.(mcp.ListToolsResult) require.True(t, ok) - + // Should only include tools with the "allow-" prefix assert.Len(t, result.Tools, 3) - + // Verify all tools start with "allow-" for _, tool := range result.Tools { - assert.True(t, len(tool.Name) >= 6 && tool.Name[:6] == "allow-", + assert.True(t, len(tool.Name) >= 6 && tool.Name[:6] == "allow-", "Tool should start with 'allow-', got: %s", tool.Name) } } func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { server := NewMCPServer("test-server", "1.0.0") - + session1Chan := make(chan mcp.JSONRPCNotification, 10) session1 := &fakeSession{ sessionID: "session-1", notificationChannel: session1Chan, initialized: true, } - + session2Chan := make(chan mcp.JSONRPCNotification, 10) session2 := &fakeSession{ sessionID: "session-2", notificationChannel: session2Chan, initialized: true, } - + session3 := &fakeSession{ sessionID: "session-3", notificationChannel: make(chan mcp.JSONRPCNotification, 10), initialized: false, // Not initialized } - + // Register sessions err := server.RegisterSession(context.Background(), session1) require.NoError(t, err) @@ -365,13 +365,13 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { require.NoError(t, err) err = server.RegisterSession(context.Background(), session3) require.NoError(t, err) - + // Send notification to session 1 err = server.SendNotificationToSpecificClient(session1.SessionID(), "test-method", map[string]any{ "data": "test-data", }) require.NoError(t, err) - + // Check that only session 1 received the notification select { case notification := <-session1Chan: @@ -380,7 +380,7 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { case <-time.After(100 * time.Millisecond): t.Error("Expected notification not received by session 1") } - + // Verify session 2 did not receive notification select { case notification := <-session2Chan: @@ -388,14 +388,14 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { case <-time.After(100 * time.Millisecond): // Expected, no notification for session 2 } - + // Test sending to non-existent session err = server.SendNotificationToSpecificClient("non-existent", "test-method", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "not found") - + // Test sending to uninitialized session err = server.SendNotificationToSpecificClient(session3.SessionID(), "test-method", nil) assert.Error(t, err) assert.Contains(t, err.Error(), "not properly initialized") -} \ No newline at end of file +} From 018a7459f66882aabd7cac5b16529b71f9509a73 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sun, 20 Apr 2025 16:11:42 +0300 Subject: [PATCH 04/13] Remove TODO and update tests --- server/session.go | 44 +++++++++++++++-- server/session_test.go | 106 ++++++++++++++++++++++++++++++++++------- 2 files changed, 129 insertions(+), 21 deletions(-) diff --git a/server/session.go b/server/session.go index 4e8545d7..2224d883 100644 --- a/server/session.go +++ b/server/session.go @@ -2,6 +2,7 @@ package server import ( "context" + "fmt" "github.com/mark3labs/mcp-go/mcp" ) @@ -87,8 +88,20 @@ func (s *MCPServer) SendNotificationToAllClients( if session, ok := v.(ClientSession); ok && session.Initialized() { select { case session.NotificationChannel() <- notification: + // Successfully sent notification default: - // TODO: log blocked channel in the future versions + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + go func(sessionID string) { + ctx := context.Background() + // Use the error hook to report the blocked channel + s.hooks.onError(ctx, nil, "notification", map[string]interface{}{ + "method": method, + "sessionID": sessionID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) + }(session.SessionID()) + } } } return true @@ -120,6 +133,17 @@ func (s *MCPServer) SendNotificationToClient( case session.NotificationChannel() <- notification: return nil default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + go func(sessionID string) { + // Use the error hook to report the blocked channel + s.hooks.onError(ctx, nil, "notification", map[string]interface{}{ + "method": method, + "sessionID": sessionID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) + }(session.SessionID()) + } return ErrNotificationChannelBlocked } } @@ -154,6 +178,18 @@ func (s *MCPServer) SendNotificationToSpecificClient( case session.NotificationChannel() <- notification: return nil default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + ctx := context.Background() + go func(sID string) { + // Use the error hook to report the blocked channel + s.hooks.onError(ctx, nil, "notification", map[string]interface{}{ + "method": method, + "sessionID": sID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) + }(sessionID) + } return ErrNotificationChannelBlocked } } @@ -183,7 +219,7 @@ func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error // Send notification only to this session s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) - + return nil } @@ -212,6 +248,6 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error // Send notification only to this session s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) - + return nil -} +} \ No newline at end of file diff --git a/server/session_test.go b/server/session_test.go index 4181fa11..9da317e3 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "errors" "testing" "time" @@ -32,41 +33,41 @@ func (f sessionTestClient) Initialized() bool { return f.initialized } -// fakeSessionWithTools implements the SessionWithTools interface for testing -type fakeSessionWithTools struct { +// sessionTestClientWithTools implements the SessionWithTools interface for testing +type sessionTestClientWithTools struct { sessionID string notificationChannel chan mcp.JSONRPCNotification initialized bool sessionTools map[string]ServerTool } -func (f *fakeSessionWithTools) SessionID() string { +func (f *sessionTestClientWithTools) SessionID() string { return f.sessionID } -func (f *fakeSessionWithTools) NotificationChannel() chan<- mcp.JSONRPCNotification { +func (f *sessionTestClientWithTools) NotificationChannel() chan<- mcp.JSONRPCNotification { return f.notificationChannel } -func (f *fakeSessionWithTools) Initialize() { +func (f *sessionTestClientWithTools) Initialize() { f.initialized = true } -func (f *fakeSessionWithTools) Initialized() bool { +func (f *sessionTestClientWithTools) Initialized() bool { return f.initialized } -func (f *fakeSessionWithTools) GetSessionTools() map[string]ServerTool { +func (f *sessionTestClientWithTools) GetSessionTools() map[string]ServerTool { return f.sessionTools } -func (f *fakeSessionWithTools) SetSessionTools(tools map[string]ServerTool) { +func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool) { f.sessionTools = tools } // Verify that both implementations satisfy their respective interfaces var _ ClientSession = sessionTestClient{} -var _ SessionWithTools = &fakeSessionWithTools{} +var _ SessionWithTools = &sessionTestClientWithTools{} func TestSessionWithTools_Integration(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) @@ -80,7 +81,7 @@ func TestSessionWithTools_Integration(t *testing.T) { } // Create a session with tools - session := &fakeSessionWithTools{ + session := &sessionTestClientWithTools{ sessionID: "session-1", notificationChannel: make(chan mcp.JSONRPCNotification, 10), initialized: true, @@ -145,7 +146,7 @@ func TestMCPServer_ToolsWithSessionTools(t *testing.T) { ) // Create a session with tools - session := &fakeSessionWithTools{ + session := &sessionTestClientWithTools{ sessionID: "session-1", notificationChannel: make(chan mcp.JSONRPCNotification, 10), initialized: true, @@ -194,7 +195,7 @@ func TestMCPServer_AddSessionTools(t *testing.T) { // Create a session sessionChan := make(chan mcp.JSONRPCNotification, 10) - session := &fakeSessionWithTools{ + session := &sessionTestClientWithTools{ sessionID: "session-1", notificationChannel: sessionChan, initialized: true, @@ -229,7 +230,7 @@ func TestMCPServer_DeleteSessionTools(t *testing.T) { // Create a session with tools sessionChan := make(chan mcp.JSONRPCNotification, 10) - session := &fakeSessionWithTools{ + session := &sessionTestClientWithTools{ sessionID: "session-1", notificationChannel: sessionChan, initialized: true, @@ -294,7 +295,7 @@ func TestMCPServer_ToolFiltering(t *testing.T) { ) // Create a session with tools - session := &fakeSessionWithTools{ + session := &sessionTestClientWithTools{ sessionID: "session-1", notificationChannel: make(chan mcp.JSONRPCNotification, 10), initialized: true, @@ -339,20 +340,20 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { server := NewMCPServer("test-server", "1.0.0") session1Chan := make(chan mcp.JSONRPCNotification, 10) - session1 := &fakeSession{ + session1 := &sessionTestClient{ sessionID: "session-1", notificationChannel: session1Chan, initialized: true, } session2Chan := make(chan mcp.JSONRPCNotification, 10) - session2 := &fakeSession{ + session2 := &sessionTestClient{ sessionID: "session-2", notificationChannel: session2Chan, initialized: true, } - session3 := &fakeSession{ + session3 := &sessionTestClient{ sessionID: "session-3", notificationChannel: make(chan mcp.JSONRPCNotification, 10), initialized: false, // Not initialized @@ -399,3 +400,74 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "not properly initialized") } + +func TestMCPServer_NotificationChannelBlocked(t *testing.T) { + // Set up a hooks object to capture error notifications + errorCaptured := false + errorSessionID := "" + errorMethod := "" + + hooks := &Hooks{} + hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + errorCaptured = true + // Extract session ID and method from the error message metadata + if msgMap, ok := message.(map[string]interface{}); ok { + if sid, ok := msgMap["sessionID"].(string); ok { + errorSessionID = sid + } + if m, ok := msgMap["method"].(string); ok { + errorMethod = m + } + } + // Verify the error is a notification channel blocked error + assert.True(t, errors.Is(err, ErrNotificationChannelBlocked)) + }) + + // Create a server with hooks + server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) + + // Create a session with a very small buffer that will get blocked + smallBufferChan := make(chan mcp.JSONRPCNotification, 1) + session := &sessionTestClient{ + sessionID: "blocked-session", + notificationChannel: smallBufferChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Fill the buffer first to ensure it gets blocked + server.SendNotificationToSpecificClient(session.SessionID(), "first-message", nil) + + // This will cause the buffer to block + err = server.SendNotificationToSpecificClient(session.SessionID(), "blocked-message", nil) + assert.Error(t, err) + assert.Equal(t, ErrNotificationChannelBlocked, err) + + // Wait a bit for the goroutine to execute + time.Sleep(10 * time.Millisecond) + + // Verify the error was logged via hooks + assert.True(t, errorCaptured, "Error hook should have been called") + assert.Equal(t, "blocked-session", errorSessionID, "Session ID should be captured in the error hook") + assert.Equal(t, "blocked-message", errorMethod, "Method should be captured in the error hook") + + // Also test SendNotificationToAllClients with a blocked channel + // Reset the captured data + errorCaptured = false + errorSessionID = "" + errorMethod = "" + + // Send to all clients (which includes our blocked one) + server.SendNotificationToAllClients("broadcast-message", nil) + + // Wait a bit for the goroutine to execute + time.Sleep(10 * time.Millisecond) + + // Verify the error was logged via hooks + assert.True(t, errorCaptured, "Error hook should have been called for broadcast") + assert.Equal(t, "blocked-session", errorSessionID, "Session ID should be captured in the error hook") + assert.Equal(t, "broadcast-message", errorMethod, "Method should be captured in the error hook") +} From c8f040ffe654b0982dfe64a20f95f5fb312afe20 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sun, 20 Apr 2025 16:16:09 +0300 Subject: [PATCH 05/13] Apply suggestions from code review Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- server/session.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/server/session.go b/server/session.go index 2224d883..4cd09a43 100644 --- a/server/session.go +++ b/server/session.go @@ -65,8 +65,13 @@ func (s *MCPServer) UnregisterSession( ctx context.Context, sessionID string, ) { - session, _ := s.sessions.LoadAndDelete(sessionID) - s.hooks.UnregisterSession(ctx, session.(ClientSession)) + sessionValue, ok := s.sessions.LoadAndDelete(sessionID) + if !ok { + return + } + if session, ok := sessionValue.(ClientSession); ok { + s.hooks.UnregisterSession(ctx, session) + } } // SendNotificationToAllClients sends a notification to all the currently active clients. From fbecaf461a26bcd8821dce8c6caac1c3883e8566 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sun, 20 Apr 2025 16:37:01 +0300 Subject: [PATCH 06/13] Update tests --- server/session.go | 6 +++--- server/session_test.go | 39 ++++++++++++++++++++++----------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/server/session.go b/server/session.go index 4cd09a43..ce58cb8a 100644 --- a/server/session.go +++ b/server/session.go @@ -224,7 +224,7 @@ func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error // Send notification only to this session s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) - + return nil } @@ -253,6 +253,6 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error // Send notification only to this session s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) - + return nil -} \ No newline at end of file +} diff --git a/server/session_test.go b/server/session_test.go index 9da317e3..5b207b8b 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -26,9 +26,14 @@ func (f sessionTestClient) NotificationChannel() chan<- mcp.JSONRPCNotification return f.notificationChannel } -func (f sessionTestClient) Initialize() { +// Initialize marks the session as initialized +// This implementation properly sets the initialized flag to true +// as required by the interface contract +func (f *sessionTestClient) Initialize() { + f.initialized = true } +// Initialized returns whether the session has been initialized func (f sessionTestClient) Initialized() bool { return f.initialized } @@ -66,7 +71,7 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool } // Verify that both implementations satisfy their respective interfaces -var _ ClientSession = sessionTestClient{} +var _ ClientSession = &sessionTestClient{} var _ SessionWithTools = &sessionTestClientWithTools{} func TestSessionWithTools_Integration(t *testing.T) { @@ -343,15 +348,15 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { session1 := &sessionTestClient{ sessionID: "session-1", notificationChannel: session1Chan, - initialized: true, } + session1.Initialize() session2Chan := make(chan mcp.JSONRPCNotification, 10) session2 := &sessionTestClient{ sessionID: "session-2", notificationChannel: session2Chan, - initialized: true, } + session2.Initialize() session3 := &sessionTestClient{ sessionID: "session-3", @@ -406,7 +411,7 @@ func TestMCPServer_NotificationChannelBlocked(t *testing.T) { errorCaptured := false errorSessionID := "" errorMethod := "" - + hooks := &Hooks{} hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { errorCaptured = true @@ -422,50 +427,50 @@ func TestMCPServer_NotificationChannelBlocked(t *testing.T) { // Verify the error is a notification channel blocked error assert.True(t, errors.Is(err, ErrNotificationChannelBlocked)) }) - + // Create a server with hooks server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) - + // Create a session with a very small buffer that will get blocked smallBufferChan := make(chan mcp.JSONRPCNotification, 1) session := &sessionTestClient{ sessionID: "blocked-session", notificationChannel: smallBufferChan, - initialized: true, } - + session.Initialize() + // Register the session err := server.RegisterSession(context.Background(), session) require.NoError(t, err) - + // Fill the buffer first to ensure it gets blocked server.SendNotificationToSpecificClient(session.SessionID(), "first-message", nil) - + // This will cause the buffer to block err = server.SendNotificationToSpecificClient(session.SessionID(), "blocked-message", nil) assert.Error(t, err) assert.Equal(t, ErrNotificationChannelBlocked, err) - + // Wait a bit for the goroutine to execute time.Sleep(10 * time.Millisecond) - + // Verify the error was logged via hooks assert.True(t, errorCaptured, "Error hook should have been called") assert.Equal(t, "blocked-session", errorSessionID, "Session ID should be captured in the error hook") assert.Equal(t, "blocked-message", errorMethod, "Method should be captured in the error hook") - + // Also test SendNotificationToAllClients with a blocked channel // Reset the captured data errorCaptured = false errorSessionID = "" errorMethod = "" - + // Send to all clients (which includes our blocked one) server.SendNotificationToAllClients("broadcast-message", nil) - + // Wait a bit for the goroutine to execute time.Sleep(10 * time.Millisecond) - + // Verify the error was logged via hooks assert.True(t, errorCaptured, "Error hook should have been called for broadcast") assert.Equal(t, "blocked-session", errorSessionID, "Session ID should be captured in the error hook") From 80a7a1e9ed8581846b9b71f3eb43179d214d4cf0 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sun, 20 Apr 2025 16:48:12 +0300 Subject: [PATCH 07/13] Update tests --- server/session_test.go | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/server/session_test.go b/server/session_test.go index 5b207b8b..51a89432 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -3,6 +3,7 @@ package server import ( "context" "errors" + "sync" "testing" "time" @@ -361,7 +362,7 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { session3 := &sessionTestClient{ sessionID: "session-3", notificationChannel: make(chan mcp.JSONRPCNotification, 10), - initialized: false, // Not initialized + initialized: false, // Not initialized - deliberately not calling Initialize() } // Register sessions @@ -408,12 +409,16 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { func TestMCPServer_NotificationChannelBlocked(t *testing.T) { // Set up a hooks object to capture error notifications + var mu sync.Mutex errorCaptured := false errorSessionID := "" errorMethod := "" hooks := &Hooks{} hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + mu.Lock() + defer mu.Unlock() + errorCaptured = true // Extract session ID and method from the error message metadata if msgMap, ok := message.(map[string]interface{}); ok { @@ -455,15 +460,23 @@ func TestMCPServer_NotificationChannelBlocked(t *testing.T) { time.Sleep(10 * time.Millisecond) // Verify the error was logged via hooks - assert.True(t, errorCaptured, "Error hook should have been called") - assert.Equal(t, "blocked-session", errorSessionID, "Session ID should be captured in the error hook") - assert.Equal(t, "blocked-message", errorMethod, "Method should be captured in the error hook") + mu.Lock() + localErrorCaptured := errorCaptured + localErrorSessionID := errorSessionID + localErrorMethod := errorMethod + mu.Unlock() + + assert.True(t, localErrorCaptured, "Error hook should have been called") + assert.Equal(t, "blocked-session", localErrorSessionID, "Session ID should be captured in the error hook") + assert.Equal(t, "blocked-message", localErrorMethod, "Method should be captured in the error hook") // Also test SendNotificationToAllClients with a blocked channel // Reset the captured data + mu.Lock() errorCaptured = false errorSessionID = "" errorMethod = "" + mu.Unlock() // Send to all clients (which includes our blocked one) server.SendNotificationToAllClients("broadcast-message", nil) @@ -472,7 +485,13 @@ func TestMCPServer_NotificationChannelBlocked(t *testing.T) { time.Sleep(10 * time.Millisecond) // Verify the error was logged via hooks - assert.True(t, errorCaptured, "Error hook should have been called for broadcast") - assert.Equal(t, "blocked-session", errorSessionID, "Session ID should be captured in the error hook") - assert.Equal(t, "broadcast-message", errorMethod, "Method should be captured in the error hook") + mu.Lock() + localErrorCaptured = errorCaptured + localErrorSessionID = errorSessionID + localErrorMethod = errorMethod + mu.Unlock() + + assert.True(t, localErrorCaptured, "Error hook should have been called for broadcast") + assert.Equal(t, "blocked-session", localErrorSessionID, "Session ID should be captured in the error hook") + assert.Equal(t, "broadcast-message", localErrorMethod, "Method should be captured in the error hook") } From 0814e560f45c3f388da802c3e9a4590aeaf1c50f Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 22 Apr 2025 20:20:20 +0300 Subject: [PATCH 08/13] Fix inaffassign --- server/server.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index 2e39b69d..c6f80b68 100644 --- a/server/server.go +++ b/server/server.go @@ -830,7 +830,11 @@ func (s *MCPServer) handleToolCall( if session != nil { if sessionWithTools, ok := session.(SessionWithTools); ok { if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { - tool, ok = sessionTools[request.Params.Name] + var sessionOk bool + tool, sessionOk = sessionTools[request.Params.Name] + if sessionOk { + ok = true + } } } } From 36b261dd8eb1d2863991732f2eb09531bebb3490 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 29 Apr 2025 19:33:19 +0300 Subject: [PATCH 09/13] gitignore --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 5430d3b0..b575ab67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ .aider* .env -.idea \ No newline at end of file +.idea +.opencode +.claude From 739219c0567ff221f799705a5a6b9cca71c49b99 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 29 Apr 2025 19:57:21 +0300 Subject: [PATCH 10/13] Impl AddSessionTool --- README.md | 16 ++++++++++++++++ server/session.go | 5 +++++ server/session_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+) diff --git a/README.md b/README.md index 2429ee72..594d49ca 100644 --- a/README.md +++ b/README.md @@ -626,6 +626,21 @@ userSpecificTool := mcp.NewTool( "user_data", mcp.WithDescription("Access user-specific data"), ) +// You can use AddSessionTool (similar to AddTool) +err := s.AddSessionTool( + advSession.SessionID(), + userSpecificTool, + func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // This handler is only available to this specific session + return mcp.NewToolResultText("User-specific data for " + advSession.SessionID()), nil + }, +) +if err != nil { + log.Printf("Failed to add session tool: %v", err) +} + +// Or use AddSessionTools directly with ServerTool +/* err := s.AddSessionTools( advSession.SessionID(), server.ServerTool{ @@ -639,6 +654,7 @@ err := s.AddSessionTools( if err != nil { log.Printf("Failed to add session tool: %v", err) } +*/ // Delete session-specific tools when no longer needed err = s.DeleteSessionTools(advSession.SessionID(), "user_data") diff --git a/server/session.go b/server/session.go index ce58cb8a..ead196ee 100644 --- a/server/session.go +++ b/server/session.go @@ -199,6 +199,11 @@ func (s *MCPServer) SendNotificationToSpecificClient( } } +// AddSessionTool adds a tool for a specific session +func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error { + return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler}) +} + // AddSessionTools adds tools for a specific session func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error { sessionValue, ok := s.sessions.Load(sessionID) diff --git a/server/session_test.go b/server/session_test.go index 51a89432..36f2a283 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -230,6 +230,45 @@ func TestMCPServer_AddSessionTools(t *testing.T) { assert.Contains(t, session.GetSessionTools(), "session-tool") } +func TestMCPServer_AddSessionTool(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + ctx := context.Background() + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithTools{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add session-specific tool using the new helper method + err = server.AddSessionTool( + session.SessionID(), + mcp.NewTool("session-tool-helper"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("helper result"), nil + }, + ) + require.NoError(t, err) + + // Check that notification was sent + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/tools/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + // Verify tool was added to session + assert.Len(t, session.GetSessionTools(), 1) + assert.Contains(t, session.GetSessionTools(), "session-tool-helper") +} + func TestMCPServer_DeleteSessionTools(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) ctx := context.Background() From 1d95aa06b57c8929c267b4e5106a0e5aecc90534 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 29 Apr 2025 20:05:53 +0300 Subject: [PATCH 11/13] Fix potential race condition --- server/session.go | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/server/session.go b/server/session.go index ead196ee..b57e9665 100644 --- a/server/session.go +++ b/server/session.go @@ -217,15 +217,23 @@ func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error } sessionTools := session.GetSessionTools() - if sessionTools == nil { - sessionTools = make(map[string]ServerTool) + + // Create a new map to avoid concurrent modification issues + newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools)) + + // Copy existing tools + if sessionTools != nil { + for k, v := range sessionTools { + newSessionTools[k] = v + } } + // Add new tools for _, tool := range tools { - sessionTools[tool.Tool.Name] = tool + newSessionTools[tool.Tool.Name] = tool } - session.SetSessionTools(sessionTools) + session.SetSessionTools(newSessionTools) // Send notification only to this session s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) @@ -250,11 +258,20 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error return nil } + // Create a new map to avoid concurrent modification issues + newSessionTools := make(map[string]ServerTool, len(sessionTools)) + + // Copy existing tools except those being deleted + for k, v := range sessionTools { + newSessionTools[k] = v + } + + // Remove specified tools for _, name := range names { - delete(sessionTools, name) + delete(newSessionTools, name) } - session.SetSessionTools(sessionTools) + session.SetSessionTools(newSessionTools) // Send notification only to this session s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) From ccb5cdb1532ef9a33d367887b046b20f6053cbb6 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 29 Apr 2025 20:20:35 +0300 Subject: [PATCH 12/13] Fixes --- server/session.go | 54 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/server/session.go b/server/session.go index b57e9665..81ced7c7 100644 --- a/server/session.go +++ b/server/session.go @@ -98,14 +98,16 @@ func (s *MCPServer) SendNotificationToAllClients( // Channel is blocked, if there's an error hook, use it if s.hooks != nil && len(s.hooks.OnError) > 0 { err := ErrNotificationChannelBlocked - go func(sessionID string) { + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sessionID string, hooks *Hooks) { ctx := context.Background() // Use the error hook to report the blocked channel - s.hooks.onError(ctx, nil, "notification", map[string]interface{}{ + hooks.onError(ctx, nil, "notification", map[string]interface{}{ "method": method, "sessionID": sessionID, }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) - }(session.SessionID()) + }(session.SessionID(), hooks) } } } @@ -141,13 +143,15 @@ func (s *MCPServer) SendNotificationToClient( // Channel is blocked, if there's an error hook, use it if s.hooks != nil && len(s.hooks.OnError) > 0 { err := ErrNotificationChannelBlocked - go func(sessionID string) { + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sessionID string, hooks *Hooks) { // Use the error hook to report the blocked channel - s.hooks.onError(ctx, nil, "notification", map[string]interface{}{ + hooks.onError(ctx, nil, "notification", map[string]interface{}{ "method": method, "sessionID": sessionID, }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) - }(session.SessionID()) + }(session.SessionID(), hooks) } return ErrNotificationChannelBlocked } @@ -187,13 +191,15 @@ func (s *MCPServer) SendNotificationToSpecificClient( if s.hooks != nil && len(s.hooks.OnError) > 0 { err := ErrNotificationChannelBlocked ctx := context.Background() - go func(sID string) { + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sID string, hooks *Hooks) { // Use the error hook to report the blocked channel - s.hooks.onError(ctx, nil, "notification", map[string]interface{}{ + hooks.onError(ctx, nil, "notification", map[string]interface{}{ "method": method, "sessionID": sID, }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) - }(sessionID) + }(sessionID, hooks) } return ErrNotificationChannelBlocked } @@ -236,7 +242,20 @@ func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error session.SetSessionTools(newSessionTools) // Send notification only to this session - s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The tools were successfully added, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]interface{}{ + "method": "notifications/tools/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after adding tools: %w", err)) + }(sessionID, hooks) + } + } return nil } @@ -274,7 +293,20 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error session.SetSessionTools(newSessionTools) // Send notification only to this session - s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil) + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The tools were successfully deleted, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]interface{}{ + "method": "notifications/tools/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after deleting tools: %w", err)) + }(sessionID, hooks) + } + } return nil } From f7ef6bc1925f3b601a92bcab348b72c9e4a71526 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 29 Apr 2025 20:36:02 +0300 Subject: [PATCH 13/13] Fixes --- server/session.go | 6 ++++++ server/session_test.go | 30 ++++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/server/session.go b/server/session.go index 81ced7c7..b051878d 100644 --- a/server/session.go +++ b/server/session.go @@ -23,8 +23,10 @@ type ClientSession interface { type SessionWithTools interface { ClientSession // GetSessionTools returns the tools specific to this session, if any + // This method must be thread-safe for concurrent access GetSessionTools() map[string]ServerTool // SetSessionTools sets tools specific to this session + // This method must be thread-safe for concurrent access SetSessionTools(tools map[string]ServerTool) } @@ -222,6 +224,7 @@ func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error return ErrSessionDoesNotSupportTools } + // Get existing tools (this should return a thread-safe copy) sessionTools := session.GetSessionTools() // Create a new map to avoid concurrent modification issues @@ -239,6 +242,7 @@ func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error newSessionTools[tool.Tool.Name] = tool } + // Set the tools (this should be thread-safe) session.SetSessionTools(newSessionTools) // Send notification only to this session @@ -272,6 +276,7 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error return ErrSessionDoesNotSupportTools } + // Get existing tools (this should return a thread-safe copy) sessionTools := session.GetSessionTools() if sessionTools == nil { return nil @@ -290,6 +295,7 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error delete(newSessionTools, name) } + // Set the tools (this should be thread-safe) session.SetSessionTools(newSessionTools) // Send notification only to this session diff --git a/server/session_test.go b/server/session_test.go index 36f2a283..8a67c78f 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -45,6 +45,7 @@ type sessionTestClientWithTools struct { notificationChannel chan mcp.JSONRPCNotification initialized bool sessionTools map[string]ServerTool + mu sync.RWMutex // Mutex to protect concurrent access to sessionTools } func (f *sessionTestClientWithTools) SessionID() string { @@ -64,11 +65,36 @@ func (f *sessionTestClientWithTools) Initialized() bool { } func (f *sessionTestClientWithTools) GetSessionTools() map[string]ServerTool { - return f.sessionTools + f.mu.RLock() + defer f.mu.RUnlock() + + // Return a copy of the map to prevent concurrent modification + if f.sessionTools == nil { + return nil + } + + toolsCopy := make(map[string]ServerTool, len(f.sessionTools)) + for k, v := range f.sessionTools { + toolsCopy[k] = v + } + return toolsCopy } func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool) { - f.sessionTools = tools + f.mu.Lock() + defer f.mu.Unlock() + + // Create a copy of the map to prevent concurrent modification + if tools == nil { + f.sessionTools = nil + return + } + + toolsCopy := make(map[string]ServerTool, len(tools)) + for k, v := range tools { + toolsCopy[k] = v + } + f.sessionTools = toolsCopy } // Verify that both implementations satisfy their respective interfaces