diff --git a/.gitignore b/.gitignore index 5430d3b0..b575ab67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ .aider* .env -.idea \ No newline at end of file +.idea +.opencode +.claude diff --git a/README.md b/README.md index 332ab3dd..594d49ca 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,10 @@ MCP Go handles all the complex protocol details and server management, so you ca - [Tools](#tools) - [Prompts](#prompts) - [Examples](#examples) +- [Extras](#extras) + - [Session Management](#session-management) + - [Request Hooks](#request-hooks) + - [Tool Handler Middleware](#tool-handler-middleware) - [Contributing](#contributing) - [Prerequisites](#prerequisites) - [Installation](#installation-1) @@ -516,6 +520,214 @@ For examples, see the `examples/` directory. ## Extras +### Session Management + +MCP-Go provides a robust session management system that allows you to: +- Maintain separate state for each connected client +- Register and track client sessions +- Send notifications to specific clients +- Provide per-session tool customization + +
+Show Session Management Examples + +#### Basic Session Handling + +```go +// Create a server with session capabilities +s := server.NewMCPServer( + "Session Demo", + "1.0.0", + server.WithToolCapabilities(true), +) + +// Implement your own ClientSession +type MySession struct { + id string + notifChannel chan mcp.JSONRPCNotification + isInitialized bool + // Add custom fields for your application +} + +// Implement the ClientSession interface +func (s *MySession) SessionID() string { + return s.id +} + +func (s *MySession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notifChannel +} + +func (s *MySession) Initialize() { + s.isInitialized = true +} + +func (s *MySession) Initialized() bool { + return s.isInitialized +} + +// Register a session +session := &MySession{ + id: "user-123", + notifChannel: make(chan mcp.JSONRPCNotification, 10), +} +if err := s.RegisterSession(context.Background(), session); err != nil { + log.Printf("Failed to register session: %v", err) +} + +// Send notification to a specific client +err := s.SendNotificationToSpecificClient( + session.SessionID(), + "notification/update", + map[string]any{"message": "New data available!"}, +) +if err != nil { + log.Printf("Failed to send notification: %v", err) +} + +// Unregister session when done +s.UnregisterSession(context.Background(), session.SessionID()) +``` + +#### Per-Session Tools + +For more advanced use cases, you can implement the `SessionWithTools` interface to support per-session tool customization: + +```go +// Implement SessionWithTools interface for per-session tools +type MyAdvancedSession struct { + MySession // Embed the basic session + sessionTools map[string]server.ServerTool +} + +// Implement additional methods for SessionWithTools +func (s *MyAdvancedSession) GetSessionTools() map[string]server.ServerTool { + return s.sessionTools +} + +func (s *MyAdvancedSession) SetSessionTools(tools map[string]server.ServerTool) { + s.sessionTools = tools +} + +// Create and register a session with tools support +advSession := &MyAdvancedSession{ + MySession: MySession{ + id: "user-456", + notifChannel: make(chan mcp.JSONRPCNotification, 10), + }, + sessionTools: make(map[string]server.ServerTool), +} +if err := s.RegisterSession(context.Background(), advSession); err != nil { + log.Printf("Failed to register session: %v", err) +} + +// Add session-specific tools +userSpecificTool := mcp.NewTool( + "user_data", + mcp.WithDescription("Access user-specific data"), +) +// You can use AddSessionTool (similar to AddTool) +err := s.AddSessionTool( + advSession.SessionID(), + userSpecificTool, + func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // This handler is only available to this specific session + return mcp.NewToolResultText("User-specific data for " + advSession.SessionID()), nil + }, +) +if err != nil { + log.Printf("Failed to add session tool: %v", err) +} + +// Or use AddSessionTools directly with ServerTool +/* +err := s.AddSessionTools( + advSession.SessionID(), + server.ServerTool{ + Tool: userSpecificTool, + Handler: func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // This handler is only available to this specific session + return mcp.NewToolResultText("User-specific data for " + advSession.SessionID()), nil + }, + }, +) +if err != nil { + log.Printf("Failed to add session tool: %v", err) +} +*/ + +// Delete session-specific tools when no longer needed +err = s.DeleteSessionTools(advSession.SessionID(), "user_data") +if err != nil { + log.Printf("Failed to delete session tool: %v", err) +} +``` + +#### Tool Filtering + +You can also apply filters to control which tools are available to certain sessions: + +```go +// Add a tool filter that only shows tools with certain prefixes +s := server.NewMCPServer( + "Tool Filtering Demo", + "1.0.0", + server.WithToolCapabilities(true), + server.WithToolFilter(func(ctx context.Context, tools []mcp.Tool) []mcp.Tool { + // Get session from context + session := server.ClientSessionFromContext(ctx) + if session == nil { + return tools // Return all tools if no session + } + + // Example: filter tools based on session ID prefix + if strings.HasPrefix(session.SessionID(), "admin-") { + // Admin users get all tools + return tools + } else { + // Regular users only get tools with "public-" prefix + var filteredTools []mcp.Tool + for _, tool := range tools { + if strings.HasPrefix(tool.Name, "public-") { + filteredTools = append(filteredTools, tool) + } + } + return filteredTools + } + }), +) +``` + +#### Working with Context + +The session context is automatically passed to tool and resource handlers: + +```go +s.AddTool(mcp.NewTool("session_aware"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Get the current session from context + session := server.ClientSessionFromContext(ctx) + if session == nil { + return mcp.NewToolResultError("No active session"), nil + } + + return mcp.NewToolResultText("Hello, session " + session.SessionID()), nil +}) + +// When using handlers in HTTP/SSE servers, you need to pass the context with the session +httpHandler := func(w http.ResponseWriter, r *http.Request) { + // Get session from somewhere (like a cookie or header) + session := getSessionFromRequest(r) + + // Add session to context + ctx := s.WithContext(r.Context(), session) + + // Use this context when handling requests + // ... +} +``` + +
+ ### Request Hooks Hook into the request lifecycle by creating a `Hooks` object with your diff --git a/server/errors.go b/server/errors.go new file mode 100644 index 00000000..7ced5cf7 --- /dev/null +++ b/server/errors.go @@ -0,0 +1,23 @@ +package server + +import ( + "errors" +) + +var ( + // Common server errors + ErrUnsupported = errors.New("not supported") + ErrResourceNotFound = errors.New("resource not found") + ErrPromptNotFound = errors.New("prompt not found") + ErrToolNotFound = errors.New("tool not found") + + // Session-related errors + ErrSessionNotFound = errors.New("session not found") + ErrSessionExists = errors.New("session already exists") + ErrSessionNotInitialized = errors.New("session not properly initialized") + ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") + + // Notification-related errors + ErrNotificationNotInitialized = errors.New("notification channel not initialized") + ErrNotificationChannelBlocked = errors.New("notification channel full or blocked") +) diff --git a/server/server.go b/server/server.go index 430f8d53..c6f80b68 100644 --- a/server/server.go +++ b/server/server.go @@ -5,7 +5,6 @@ import ( "context" "encoding/base64" "encoding/json" - "errors" "fmt" "reflect" "sort" @@ -44,31 +43,22 @@ type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mc // ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc +// ToolFilterFunc is a function that filters tools based on context, typically using session information. +type ToolFilterFunc func(ctx context.Context, tools []mcp.Tool) []mcp.Tool + // ServerTool combines a Tool with its ToolHandlerFunc. type ServerTool struct { Tool mcp.Tool Handler ToolHandlerFunc } -// ClientSession represents an active session that can be used by MCPServer to interact with client. -type ClientSession interface { - // Initialize marks session as fully initialized and ready for notifications - Initialize() - // Initialized returns if session is ready to accept notifications - Initialized() bool - // NotificationChannel provides a channel suitable for sending notifications to client. - NotificationChannel() chan<- mcp.JSONRPCNotification - // SessionID is a unique identifier used to track user session. - SessionID() string -} - -// clientSessionKey is the context key for storing current client notification channel. -type clientSessionKey struct{} +// serverKey is the context key for storing the server instance +type serverKey struct{} -// ClientSessionFromContext retrieves current client notification context from context. -func ClientSessionFromContext(ctx context.Context) ClientSession { - if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { - return session +// ServerFromContext retrieves the MCPServer instance from a context +func ServerFromContext(ctx context.Context) *MCPServer { + if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { + return srv } return nil } @@ -128,13 +118,6 @@ func (e *requestError) Unwrap() error { return e.err } -var ( - ErrUnsupported = errors.New("not supported") - ErrResourceNotFound = errors.New("resource not found") - ErrPromptNotFound = errors.New("prompt not found") - ErrToolNotFound = errors.New("tool not found") -) - // NotificationHandlerFunc handles incoming notifications. type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCNotification) @@ -148,6 +131,7 @@ type MCPServer struct { middlewareMu sync.RWMutex notificationHandlersMu sync.RWMutex capabilitiesMu sync.RWMutex + toolFiltersMu sync.RWMutex name string version string @@ -158,6 +142,7 @@ type MCPServer struct { promptHandlers map[string]PromptHandlerFunc tools map[string]ServerTool toolHandlerMiddlewares []ToolHandlerMiddleware + toolFilters []ToolFilterFunc notificationHandlers map[string]NotificationHandlerFunc capabilities serverCapabilities paginationLimit *int @@ -165,17 +150,6 @@ type MCPServer struct { hooks *Hooks } -// serverKey is the context key for storing the server instance -type serverKey struct{} - -// ServerFromContext retrieves the MCPServer instance from a context -func ServerFromContext(ctx context.Context) *MCPServer { - if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { - return srv - } - return nil -} - // WithPaginationLimit sets the pagination limit for the server. func WithPaginationLimit(limit int) ServerOption { return func(s *MCPServer) { @@ -183,92 +157,6 @@ func WithPaginationLimit(limit int) ServerOption { } } -// WithContext sets the current client session and returns the provided context -func (s *MCPServer) WithContext( - ctx context.Context, - session ClientSession, -) context.Context { - return context.WithValue(ctx, clientSessionKey{}, session) -} - -// RegisterSession saves session that should be notified in case if some server attributes changed. -func (s *MCPServer) RegisterSession( - ctx context.Context, - session ClientSession, -) error { - sessionID := session.SessionID() - if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { - return fmt.Errorf("session %s is already registered", sessionID) - } - s.hooks.RegisterSession(ctx, session) - return nil -} - -// UnregisterSession removes from storage session that is shut down. -func (s *MCPServer) UnregisterSession( - ctx context.Context, - sessionID string, -) { - session, _ := s.sessions.LoadAndDelete(sessionID) - s.hooks.UnregisterSession(ctx, session.(ClientSession)) -} - -// SendNotificationToAllClients sends a notification to all the currently active clients. -func (s *MCPServer) SendNotificationToAllClients( - method string, - params map[string]any, -) { - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: method, - Params: mcp.NotificationParams{ - AdditionalFields: params, - }, - }, - } - - s.sessions.Range(func(k, v any) bool { - if session, ok := v.(ClientSession); ok && session.Initialized() { - select { - case session.NotificationChannel() <- notification: - default: - // TODO: log blocked channel in the future versions - } - } - return true - }) -} - -// SendNotificationToClient sends a notification to the current client -func (s *MCPServer) SendNotificationToClient( - ctx context.Context, - method string, - params map[string]any, -) error { - session := ClientSessionFromContext(ctx) - if session == nil || !session.Initialized() { - return fmt.Errorf("notification channel not initialized") - } - - notification := mcp.JSONRPCNotification{ - JSONRPC: mcp.JSONRPC_VERSION, - Notification: mcp.Notification{ - Method: method, - Params: mcp.NotificationParams{ - AdditionalFields: params, - }, - }, - } - - select { - case session.NotificationChannel() <- notification: - return nil - default: - return fmt.Errorf("notification channel full or blocked") - } -} - // serverCapabilities defines the supported features of the MCP server type serverCapabilities struct { tools *toolCapabilities @@ -316,6 +204,17 @@ func WithToolHandlerMiddleware( } } +// WithToolFilter adds a filter function that will be applied to tools before they are returned in list_tools +func WithToolFilter( + toolFilter ToolFilterFunc, +) ServerOption { + return func(s *MCPServer) { + s.toolFiltersMu.Lock() + s.toolFilters = append(s.toolFilters, toolFilter) + s.toolFiltersMu.Unlock() + } +} + // WithRecovery adds a middleware that recovers from panics in tool handlers. func WithRecovery() ServerOption { return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc { @@ -838,6 +737,7 @@ func (s *MCPServer) handleListTools( id interface{}, request mcp.ListToolsRequest, ) (*mcp.ListToolsResult, *requestError) { + // Get the base tools from the server s.toolsMu.RLock() tools := make([]mcp.Tool, 0, len(s.tools)) @@ -856,6 +756,49 @@ func (s *MCPServer) handleListTools( } s.toolsMu.RUnlock() + // Check if there are session-specific tools + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTools, ok := session.(SessionWithTools); ok { + if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { + // Override or add session-specific tools + // We need to create a map first to merge the tools properly + toolMap := make(map[string]mcp.Tool) + + // Add global tools first + for _, tool := range tools { + toolMap[tool.Name] = tool + } + + // Then override with session-specific tools + for name, serverTool := range sessionTools { + toolMap[name] = serverTool.Tool + } + + // Convert back to slice + tools = make([]mcp.Tool, 0, len(toolMap)) + for _, tool := range toolMap { + tools = append(tools, tool) + } + + // Sort again to maintain consistent ordering + sort.Slice(tools, func(i, j int) bool { + return tools[i].Name < tools[j].Name + }) + } + } + } + + // Apply tool filters if any are defined + s.toolFiltersMu.RLock() + if len(s.toolFilters) > 0 { + for _, filter := range s.toolFilters { + tools = filter(ctx, tools) + } + } + s.toolFiltersMu.RUnlock() + + // Apply pagination toolsToReturn, nextCursor, err := listByPagination[mcp.Tool](ctx, s, request.Params.Cursor, tools) if err != nil { return nil, &requestError{ @@ -864,6 +807,7 @@ func (s *MCPServer) handleListTools( err: err, } } + result := mcp.ListToolsResult{ Tools: toolsToReturn, PaginatedResult: mcp.PaginatedResult{ @@ -878,9 +822,29 @@ func (s *MCPServer) handleToolCall( id interface{}, request mcp.CallToolRequest, ) (*mcp.CallToolResult, *requestError) { - s.toolsMu.RLock() - tool, ok := s.tools[request.Params.Name] - s.toolsMu.RUnlock() + // First check session-specific tools + var tool ServerTool + var ok bool + + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTools, ok := session.(SessionWithTools); ok { + if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { + var sessionOk bool + tool, sessionOk = sessionTools[request.Params.Name] + if sessionOk { + ok = true + } + } + } + } + + // If not found in session tools, check global tools + if !ok { + s.toolsMu.RLock() + tool, ok = s.tools[request.Params.Name] + s.toolsMu.RUnlock() + } if !ok { return nil, &requestError{ diff --git a/server/session.go b/server/session.go new file mode 100644 index 00000000..b051878d --- /dev/null +++ b/server/session.go @@ -0,0 +1,318 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// ClientSession represents an active session that can be used by MCPServer to interact with client. +type ClientSession interface { + // Initialize marks session as fully initialized and ready for notifications + Initialize() + // Initialized returns if session is ready to accept notifications + Initialized() bool + // NotificationChannel provides a channel suitable for sending notifications to client. + NotificationChannel() chan<- mcp.JSONRPCNotification + // SessionID is a unique identifier used to track user session. + SessionID() string +} + +// SessionWithTools is an extension of ClientSession that can store session-specific tool data +type SessionWithTools interface { + ClientSession + // GetSessionTools returns the tools specific to this session, if any + // This method must be thread-safe for concurrent access + GetSessionTools() map[string]ServerTool + // SetSessionTools sets tools specific to this session + // This method must be thread-safe for concurrent access + SetSessionTools(tools map[string]ServerTool) +} + +// clientSessionKey is the context key for storing current client notification channel. +type clientSessionKey struct{} + +// ClientSessionFromContext retrieves current client notification context from context. +func ClientSessionFromContext(ctx context.Context) ClientSession { + if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { + return session + } + return nil +} + +// WithContext sets the current client session and returns the provided context +func (s *MCPServer) WithContext( + ctx context.Context, + session ClientSession, +) context.Context { + return context.WithValue(ctx, clientSessionKey{}, session) +} + +// RegisterSession saves session that should be notified in case if some server attributes changed. +func (s *MCPServer) RegisterSession( + ctx context.Context, + session ClientSession, +) error { + sessionID := session.SessionID() + if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { + return ErrSessionExists + } + s.hooks.RegisterSession(ctx, session) + return nil +} + +// UnregisterSession removes from storage session that is shut down. +func (s *MCPServer) UnregisterSession( + ctx context.Context, + sessionID string, +) { + sessionValue, ok := s.sessions.LoadAndDelete(sessionID) + if !ok { + return + } + if session, ok := sessionValue.(ClientSession); ok { + s.hooks.UnregisterSession(ctx, session) + } +} + +// SendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) SendNotificationToAllClients( + method string, + params map[string]any, +) { + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + s.sessions.Range(func(k, v any) bool { + if session, ok := v.(ClientSession); ok && session.Initialized() { + select { + case session.NotificationChannel() <- notification: + // Successfully sent notification + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sessionID string, hooks *Hooks) { + ctx := context.Background() + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]interface{}{ + "method": method, + "sessionID": sessionID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) + }(session.SessionID(), hooks) + } + } + } + return true + }) +} + +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) SendNotificationToClient( + ctx context.Context, + method string, + params map[string]any, +) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + select { + case session.NotificationChannel() <- notification: + return nil + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sessionID string, hooks *Hooks) { + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]interface{}{ + "method": method, + "sessionID": sessionID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) + }(session.SessionID(), hooks) + } + return ErrNotificationChannelBlocked + } +} + +// SendNotificationToSpecificClient sends a notification to a specific client by session ID +func (s *MCPServer) SendNotificationToSpecificClient( + sessionID string, + method string, + params map[string]any, +) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(ClientSession) + if !ok || !session.Initialized() { + return ErrSessionNotInitialized + } + + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + select { + case session.NotificationChannel() <- notification: + return nil + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + ctx := context.Background() + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sID string, hooks *Hooks) { + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]interface{}{ + "method": method, + "sessionID": sID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) + }(sessionID, hooks) + } + return ErrNotificationChannelBlocked + } +} + +// AddSessionTool adds a tool for a specific session +func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error { + return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler}) +} + +// AddSessionTools adds tools for a specific session +func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithTools) + if !ok { + return ErrSessionDoesNotSupportTools + } + + // Get existing tools (this should return a thread-safe copy) + sessionTools := session.GetSessionTools() + + // Create a new map to avoid concurrent modification issues + newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools)) + + // Copy existing tools + if sessionTools != nil { + for k, v := range sessionTools { + newSessionTools[k] = v + } + } + + // Add new tools + for _, tool := range tools { + newSessionTools[tool.Tool.Name] = tool + } + + // Set the tools (this should be thread-safe) + session.SetSessionTools(newSessionTools) + + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The tools were successfully added, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]interface{}{ + "method": "notifications/tools/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after adding tools: %w", err)) + }(sessionID, hooks) + } + } + + return nil +} + +// DeleteSessionTools removes tools from a specific session +func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithTools) + if !ok { + return ErrSessionDoesNotSupportTools + } + + // Get existing tools (this should return a thread-safe copy) + sessionTools := session.GetSessionTools() + if sessionTools == nil { + return nil + } + + // Create a new map to avoid concurrent modification issues + newSessionTools := make(map[string]ServerTool, len(sessionTools)) + + // Copy existing tools except those being deleted + for k, v := range sessionTools { + newSessionTools[k] = v + } + + // Remove specified tools + for _, name := range names { + delete(newSessionTools, name) + } + + // Set the tools (this should be thread-safe) + session.SetSessionTools(newSessionTools) + + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The tools were successfully deleted, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]interface{}{ + "method": "notifications/tools/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after deleting tools: %w", err)) + }(sessionID, hooks) + } + } + + return nil +} diff --git a/server/session_test.go b/server/session_test.go new file mode 100644 index 00000000..8a67c78f --- /dev/null +++ b/server/session_test.go @@ -0,0 +1,562 @@ +package server + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// sessionTestClient implements the basic ClientSession interface for testing +type sessionTestClient struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized bool +} + +func (f sessionTestClient) SessionID() string { + return f.sessionID +} + +func (f sessionTestClient) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +// Initialize marks the session as initialized +// This implementation properly sets the initialized flag to true +// as required by the interface contract +func (f *sessionTestClient) Initialize() { + f.initialized = true +} + +// Initialized returns whether the session has been initialized +func (f sessionTestClient) Initialized() bool { + return f.initialized +} + +// sessionTestClientWithTools implements the SessionWithTools interface for testing +type sessionTestClientWithTools struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized bool + sessionTools map[string]ServerTool + mu sync.RWMutex // Mutex to protect concurrent access to sessionTools +} + +func (f *sessionTestClientWithTools) SessionID() string { + return f.sessionID +} + +func (f *sessionTestClientWithTools) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +func (f *sessionTestClientWithTools) Initialize() { + f.initialized = true +} + +func (f *sessionTestClientWithTools) Initialized() bool { + return f.initialized +} + +func (f *sessionTestClientWithTools) GetSessionTools() map[string]ServerTool { + f.mu.RLock() + defer f.mu.RUnlock() + + // Return a copy of the map to prevent concurrent modification + if f.sessionTools == nil { + return nil + } + + toolsCopy := make(map[string]ServerTool, len(f.sessionTools)) + for k, v := range f.sessionTools { + toolsCopy[k] = v + } + return toolsCopy +} + +func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool) { + f.mu.Lock() + defer f.mu.Unlock() + + // Create a copy of the map to prevent concurrent modification + if tools == nil { + f.sessionTools = nil + return + } + + toolsCopy := make(map[string]ServerTool, len(tools)) + for k, v := range tools { + toolsCopy[k] = v + } + f.sessionTools = toolsCopy +} + +// Verify that both implementations satisfy their respective interfaces +var _ ClientSession = &sessionTestClient{} +var _ SessionWithTools = &sessionTestClientWithTools{} + +func TestSessionWithTools_Integration(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + + // Create session-specific tools + sessionTool := ServerTool{ + Tool: mcp.NewTool("session-tool"), + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("session-tool result"), nil + }, + } + + // Create a session with tools + session := &sessionTestClientWithTools{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + sessionTools: map[string]ServerTool{ + "session-tool": sessionTool, + }, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Test that we can access the session-specific tool + testReq := mcp.CallToolRequest{} + testReq.Params.Name = "session-tool" + testReq.Params.Arguments = map[string]interface{}{} + + // Call using session context + sessionCtx := server.WithContext(context.Background(), session) + + // Check if the session was stored in the context correctly + s := ClientSessionFromContext(sessionCtx) + require.NotNil(t, s, "Session should be available from context") + assert.Equal(t, session.SessionID(), s.SessionID(), "Session ID should match") + + // Check if the session can be cast to SessionWithTools + swt, ok := s.(SessionWithTools) + require.True(t, ok, "Session should implement SessionWithTools") + + // Check if the tools are accessible + tools := swt.GetSessionTools() + require.NotNil(t, tools, "Session tools should be available") + require.Contains(t, tools, "session-tool", "Session should have session-tool") + + // Test session tool access with session context + t.Run("test session tool access", func(t *testing.T) { + // First test directly getting the tool from session tools + tool, exists := tools["session-tool"] + require.True(t, exists, "Session tool should exist in the map") + require.NotNil(t, tool, "Session tool should not be nil") + + // Now test calling directly with the handler + result, err := tool.Handler(sessionCtx, testReq) + require.NoError(t, err, "No error calling session tool handler directly") + require.NotNil(t, result, "Result should not be nil") + require.Len(t, result.Content, 1, "Result should have one content item") + + textContent, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok, "Content should be TextContent") + assert.Equal(t, "session-tool result", textContent.Text, "Result text should match") + }) +} + +func TestMCPServer_ToolsWithSessionTools(t *testing.T) { + // Basic test to verify that session-specific tools are returned correctly in a tools list + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + + // Add global tools + server.AddTools( + ServerTool{Tool: mcp.NewTool("global-tool-1")}, + ServerTool{Tool: mcp.NewTool("global-tool-2")}, + ) + + // Create a session with tools + session := &sessionTestClientWithTools{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + sessionTools: map[string]ServerTool{ + "session-tool-1": {Tool: mcp.NewTool("session-tool-1")}, + "global-tool-1": {Tool: mcp.NewTool("global-tool-1", mcp.WithDescription("Overridden"))}, + }, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // List tools with session context + sessionCtx := server.WithContext(context.Background(), session) + resp := server.HandleMessage(sessionCtx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list" + }`)) + + jsonResp, ok := resp.(mcp.JSONRPCResponse) + require.True(t, ok, "Response should be a JSONRPCResponse") + + result, ok := jsonResp.Result.(mcp.ListToolsResult) + require.True(t, ok, "Result should be a ListToolsResult") + + // Should have 3 tools - 2 global tools (one overridden) and 1 session-specific tool + assert.Len(t, result.Tools, 3, "Should have 3 tools") + + // Find the overridden tool and verify its description + var found bool + for _, tool := range result.Tools { + if tool.Name == "global-tool-1" { + assert.Equal(t, "Overridden", tool.Description, "Global tool should be overridden") + found = true + break + } + } + assert.True(t, found, "Should find the overridden global tool") +} + +func TestMCPServer_AddSessionTools(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + ctx := context.Background() + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithTools{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add session-specific tools + err = server.AddSessionTools(session.SessionID(), + ServerTool{Tool: mcp.NewTool("session-tool")}, + ) + require.NoError(t, err) + + // Check that notification was sent + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/tools/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + // Verify tool was added to session + assert.Len(t, session.GetSessionTools(), 1) + assert.Contains(t, session.GetSessionTools(), "session-tool") +} + +func TestMCPServer_AddSessionTool(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + ctx := context.Background() + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithTools{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add session-specific tool using the new helper method + err = server.AddSessionTool( + session.SessionID(), + mcp.NewTool("session-tool-helper"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("helper result"), nil + }, + ) + require.NoError(t, err) + + // Check that notification was sent + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/tools/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + // Verify tool was added to session + assert.Len(t, session.GetSessionTools(), 1) + assert.Contains(t, session.GetSessionTools(), "session-tool-helper") +} + +func TestMCPServer_DeleteSessionTools(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + ctx := context.Background() + + // Create a session with tools + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithTools{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + sessionTools: map[string]ServerTool{ + "session-tool-1": { + Tool: mcp.NewTool("session-tool-1"), + }, + "session-tool-2": { + Tool: mcp.NewTool("session-tool-2"), + }, + }, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Delete one of the session tools + err = server.DeleteSessionTools(session.SessionID(), "session-tool-1") + require.NoError(t, err) + + // Check that notification was sent + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/tools/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + // Verify tool was removed from session + assert.Len(t, session.GetSessionTools(), 1) + assert.NotContains(t, session.GetSessionTools(), "session-tool-1") + assert.Contains(t, session.GetSessionTools(), "session-tool-2") +} + +func TestMCPServer_ToolFiltering(t *testing.T) { + // Create a filter that filters tools by prefix + filterByPrefix := func(prefix string) ToolFilterFunc { + return func(ctx context.Context, tools []mcp.Tool) []mcp.Tool { + var filtered []mcp.Tool + for _, tool := range tools { + if len(tool.Name) >= len(prefix) && tool.Name[:len(prefix)] == prefix { + filtered = append(filtered, tool) + } + } + return filtered + } + } + + // Create a server with a tool filter + server := NewMCPServer("test-server", "1.0.0", + WithToolCapabilities(true), + WithToolFilter(filterByPrefix("allow-")), + ) + + // Add tools with different prefixes + server.AddTools( + ServerTool{Tool: mcp.NewTool("allow-tool-1")}, + ServerTool{Tool: mcp.NewTool("allow-tool-2")}, + ServerTool{Tool: mcp.NewTool("deny-tool-1")}, + ServerTool{Tool: mcp.NewTool("deny-tool-2")}, + ) + + // Create a session with tools + session := &sessionTestClientWithTools{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + sessionTools: map[string]ServerTool{ + "allow-session-tool": { + Tool: mcp.NewTool("allow-session-tool"), + }, + "deny-session-tool": { + Tool: mcp.NewTool("deny-session-tool"), + }, + }, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // List tools with session context + sessionCtx := server.WithContext(context.Background(), session) + response := server.HandleMessage(sessionCtx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list" + }`)) + resp, ok := response.(mcp.JSONRPCResponse) + require.True(t, ok) + + result, ok := resp.Result.(mcp.ListToolsResult) + require.True(t, ok) + + // Should only include tools with the "allow-" prefix + assert.Len(t, result.Tools, 3) + + // Verify all tools start with "allow-" + for _, tool := range result.Tools { + assert.True(t, len(tool.Name) >= 6 && tool.Name[:6] == "allow-", + "Tool should start with 'allow-', got: %s", tool.Name) + } +} + +func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + session1Chan := make(chan mcp.JSONRPCNotification, 10) + session1 := &sessionTestClient{ + sessionID: "session-1", + notificationChannel: session1Chan, + } + session1.Initialize() + + session2Chan := make(chan mcp.JSONRPCNotification, 10) + session2 := &sessionTestClient{ + sessionID: "session-2", + notificationChannel: session2Chan, + } + session2.Initialize() + + session3 := &sessionTestClient{ + sessionID: "session-3", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: false, // Not initialized - deliberately not calling Initialize() + } + + // Register sessions + err := server.RegisterSession(context.Background(), session1) + require.NoError(t, err) + err = server.RegisterSession(context.Background(), session2) + require.NoError(t, err) + err = server.RegisterSession(context.Background(), session3) + require.NoError(t, err) + + // Send notification to session 1 + err = server.SendNotificationToSpecificClient(session1.SessionID(), "test-method", map[string]any{ + "data": "test-data", + }) + require.NoError(t, err) + + // Check that only session 1 received the notification + select { + case notification := <-session1Chan: + assert.Equal(t, "test-method", notification.Method) + assert.Equal(t, "test-data", notification.Params.AdditionalFields["data"]) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received by session 1") + } + + // Verify session 2 did not receive notification + select { + case notification := <-session2Chan: + t.Errorf("Unexpected notification received by session 2: %v", notification) + case <-time.After(100 * time.Millisecond): + // Expected, no notification for session 2 + } + + // Test sending to non-existent session + err = server.SendNotificationToSpecificClient("non-existent", "test-method", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + + // Test sending to uninitialized session + err = server.SendNotificationToSpecificClient(session3.SessionID(), "test-method", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not properly initialized") +} + +func TestMCPServer_NotificationChannelBlocked(t *testing.T) { + // Set up a hooks object to capture error notifications + var mu sync.Mutex + errorCaptured := false + errorSessionID := "" + errorMethod := "" + + hooks := &Hooks{} + hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + mu.Lock() + defer mu.Unlock() + + errorCaptured = true + // Extract session ID and method from the error message metadata + if msgMap, ok := message.(map[string]interface{}); ok { + if sid, ok := msgMap["sessionID"].(string); ok { + errorSessionID = sid + } + if m, ok := msgMap["method"].(string); ok { + errorMethod = m + } + } + // Verify the error is a notification channel blocked error + assert.True(t, errors.Is(err, ErrNotificationChannelBlocked)) + }) + + // Create a server with hooks + server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) + + // Create a session with a very small buffer that will get blocked + smallBufferChan := make(chan mcp.JSONRPCNotification, 1) + session := &sessionTestClient{ + sessionID: "blocked-session", + notificationChannel: smallBufferChan, + } + session.Initialize() + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Fill the buffer first to ensure it gets blocked + server.SendNotificationToSpecificClient(session.SessionID(), "first-message", nil) + + // This will cause the buffer to block + err = server.SendNotificationToSpecificClient(session.SessionID(), "blocked-message", nil) + assert.Error(t, err) + assert.Equal(t, ErrNotificationChannelBlocked, err) + + // Wait a bit for the goroutine to execute + time.Sleep(10 * time.Millisecond) + + // Verify the error was logged via hooks + mu.Lock() + localErrorCaptured := errorCaptured + localErrorSessionID := errorSessionID + localErrorMethod := errorMethod + mu.Unlock() + + assert.True(t, localErrorCaptured, "Error hook should have been called") + assert.Equal(t, "blocked-session", localErrorSessionID, "Session ID should be captured in the error hook") + assert.Equal(t, "blocked-message", localErrorMethod, "Method should be captured in the error hook") + + // Also test SendNotificationToAllClients with a blocked channel + // Reset the captured data + mu.Lock() + errorCaptured = false + errorSessionID = "" + errorMethod = "" + mu.Unlock() + + // Send to all clients (which includes our blocked one) + server.SendNotificationToAllClients("broadcast-message", nil) + + // Wait a bit for the goroutine to execute + time.Sleep(10 * time.Millisecond) + + // Verify the error was logged via hooks + mu.Lock() + localErrorCaptured = errorCaptured + localErrorSessionID = errorSessionID + localErrorMethod = errorMethod + mu.Unlock() + + assert.True(t, localErrorCaptured, "Error hook should have been called for broadcast") + assert.Equal(t, "blocked-session", localErrorSessionID, "Session ID should be captured in the error hook") + assert.Equal(t, "broadcast-message", localErrorMethod, "Method should be captured in the error hook") +}