Skip to content

Commit 018a745

Browse files
committed
Remove TODO and update tests
1 parent b6db184 commit 018a745

File tree

2 files changed

+129
-21
lines changed

2 files changed

+129
-21
lines changed

server/session.go

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package server
22

33
import (
44
"context"
5+
"fmt"
56

67
"github.com/mark3labs/mcp-go/mcp"
78
)
@@ -87,8 +88,20 @@ func (s *MCPServer) SendNotificationToAllClients(
8788
if session, ok := v.(ClientSession); ok && session.Initialized() {
8889
select {
8990
case session.NotificationChannel() <- notification:
91+
// Successfully sent notification
9092
default:
91-
// TODO: log blocked channel in the future versions
93+
// Channel is blocked, if there's an error hook, use it
94+
if s.hooks != nil && len(s.hooks.OnError) > 0 {
95+
err := ErrNotificationChannelBlocked
96+
go func(sessionID string) {
97+
ctx := context.Background()
98+
// Use the error hook to report the blocked channel
99+
s.hooks.onError(ctx, nil, "notification", map[string]interface{}{
100+
"method": method,
101+
"sessionID": sessionID,
102+
}, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
103+
}(session.SessionID())
104+
}
92105
}
93106
}
94107
return true
@@ -120,6 +133,17 @@ func (s *MCPServer) SendNotificationToClient(
120133
case session.NotificationChannel() <- notification:
121134
return nil
122135
default:
136+
// Channel is blocked, if there's an error hook, use it
137+
if s.hooks != nil && len(s.hooks.OnError) > 0 {
138+
err := ErrNotificationChannelBlocked
139+
go func(sessionID string) {
140+
// Use the error hook to report the blocked channel
141+
s.hooks.onError(ctx, nil, "notification", map[string]interface{}{
142+
"method": method,
143+
"sessionID": sessionID,
144+
}, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
145+
}(session.SessionID())
146+
}
123147
return ErrNotificationChannelBlocked
124148
}
125149
}
@@ -154,6 +178,18 @@ func (s *MCPServer) SendNotificationToSpecificClient(
154178
case session.NotificationChannel() <- notification:
155179
return nil
156180
default:
181+
// Channel is blocked, if there's an error hook, use it
182+
if s.hooks != nil && len(s.hooks.OnError) > 0 {
183+
err := ErrNotificationChannelBlocked
184+
ctx := context.Background()
185+
go func(sID string) {
186+
// Use the error hook to report the blocked channel
187+
s.hooks.onError(ctx, nil, "notification", map[string]interface{}{
188+
"method": method,
189+
"sessionID": sID,
190+
}, fmt.Errorf("notification channel blocked for session %s: %w", sID, err))
191+
}(sessionID)
192+
}
157193
return ErrNotificationChannelBlocked
158194
}
159195
}
@@ -183,7 +219,7 @@ func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error
183219

184220
// Send notification only to this session
185221
s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil)
186-
222+
187223
return nil
188224
}
189225

@@ -212,6 +248,6 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error
212248

213249
// Send notification only to this session
214250
s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil)
215-
251+
216252
return nil
217-
}
253+
}

server/session_test.go

Lines changed: 89 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package server
22

33
import (
44
"context"
5+
"errors"
56
"testing"
67
"time"
78

@@ -32,41 +33,41 @@ func (f sessionTestClient) Initialized() bool {
3233
return f.initialized
3334
}
3435

35-
// fakeSessionWithTools implements the SessionWithTools interface for testing
36-
type fakeSessionWithTools struct {
36+
// sessionTestClientWithTools implements the SessionWithTools interface for testing
37+
type sessionTestClientWithTools struct {
3738
sessionID string
3839
notificationChannel chan mcp.JSONRPCNotification
3940
initialized bool
4041
sessionTools map[string]ServerTool
4142
}
4243

43-
func (f *fakeSessionWithTools) SessionID() string {
44+
func (f *sessionTestClientWithTools) SessionID() string {
4445
return f.sessionID
4546
}
4647

47-
func (f *fakeSessionWithTools) NotificationChannel() chan<- mcp.JSONRPCNotification {
48+
func (f *sessionTestClientWithTools) NotificationChannel() chan<- mcp.JSONRPCNotification {
4849
return f.notificationChannel
4950
}
5051

51-
func (f *fakeSessionWithTools) Initialize() {
52+
func (f *sessionTestClientWithTools) Initialize() {
5253
f.initialized = true
5354
}
5455

55-
func (f *fakeSessionWithTools) Initialized() bool {
56+
func (f *sessionTestClientWithTools) Initialized() bool {
5657
return f.initialized
5758
}
5859

59-
func (f *fakeSessionWithTools) GetSessionTools() map[string]ServerTool {
60+
func (f *sessionTestClientWithTools) GetSessionTools() map[string]ServerTool {
6061
return f.sessionTools
6162
}
6263

63-
func (f *fakeSessionWithTools) SetSessionTools(tools map[string]ServerTool) {
64+
func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool) {
6465
f.sessionTools = tools
6566
}
6667

6768
// Verify that both implementations satisfy their respective interfaces
6869
var _ ClientSession = sessionTestClient{}
69-
var _ SessionWithTools = &fakeSessionWithTools{}
70+
var _ SessionWithTools = &sessionTestClientWithTools{}
7071

7172
func TestSessionWithTools_Integration(t *testing.T) {
7273
server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true))
@@ -80,7 +81,7 @@ func TestSessionWithTools_Integration(t *testing.T) {
8081
}
8182

8283
// Create a session with tools
83-
session := &fakeSessionWithTools{
84+
session := &sessionTestClientWithTools{
8485
sessionID: "session-1",
8586
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
8687
initialized: true,
@@ -145,7 +146,7 @@ func TestMCPServer_ToolsWithSessionTools(t *testing.T) {
145146
)
146147

147148
// Create a session with tools
148-
session := &fakeSessionWithTools{
149+
session := &sessionTestClientWithTools{
149150
sessionID: "session-1",
150151
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
151152
initialized: true,
@@ -194,7 +195,7 @@ func TestMCPServer_AddSessionTools(t *testing.T) {
194195

195196
// Create a session
196197
sessionChan := make(chan mcp.JSONRPCNotification, 10)
197-
session := &fakeSessionWithTools{
198+
session := &sessionTestClientWithTools{
198199
sessionID: "session-1",
199200
notificationChannel: sessionChan,
200201
initialized: true,
@@ -229,7 +230,7 @@ func TestMCPServer_DeleteSessionTools(t *testing.T) {
229230

230231
// Create a session with tools
231232
sessionChan := make(chan mcp.JSONRPCNotification, 10)
232-
session := &fakeSessionWithTools{
233+
session := &sessionTestClientWithTools{
233234
sessionID: "session-1",
234235
notificationChannel: sessionChan,
235236
initialized: true,
@@ -294,7 +295,7 @@ func TestMCPServer_ToolFiltering(t *testing.T) {
294295
)
295296

296297
// Create a session with tools
297-
session := &fakeSessionWithTools{
298+
session := &sessionTestClientWithTools{
298299
sessionID: "session-1",
299300
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
300301
initialized: true,
@@ -339,20 +340,20 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) {
339340
server := NewMCPServer("test-server", "1.0.0")
340341

341342
session1Chan := make(chan mcp.JSONRPCNotification, 10)
342-
session1 := &fakeSession{
343+
session1 := &sessionTestClient{
343344
sessionID: "session-1",
344345
notificationChannel: session1Chan,
345346
initialized: true,
346347
}
347348

348349
session2Chan := make(chan mcp.JSONRPCNotification, 10)
349-
session2 := &fakeSession{
350+
session2 := &sessionTestClient{
350351
sessionID: "session-2",
351352
notificationChannel: session2Chan,
352353
initialized: true,
353354
}
354355

355-
session3 := &fakeSession{
356+
session3 := &sessionTestClient{
356357
sessionID: "session-3",
357358
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
358359
initialized: false, // Not initialized
@@ -399,3 +400,74 @@ func TestMCPServer_SendNotificationToSpecificClient(t *testing.T) {
399400
assert.Error(t, err)
400401
assert.Contains(t, err.Error(), "not properly initialized")
401402
}
403+
404+
func TestMCPServer_NotificationChannelBlocked(t *testing.T) {
405+
// Set up a hooks object to capture error notifications
406+
errorCaptured := false
407+
errorSessionID := ""
408+
errorMethod := ""
409+
410+
hooks := &Hooks{}
411+
hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
412+
errorCaptured = true
413+
// Extract session ID and method from the error message metadata
414+
if msgMap, ok := message.(map[string]interface{}); ok {
415+
if sid, ok := msgMap["sessionID"].(string); ok {
416+
errorSessionID = sid
417+
}
418+
if m, ok := msgMap["method"].(string); ok {
419+
errorMethod = m
420+
}
421+
}
422+
// Verify the error is a notification channel blocked error
423+
assert.True(t, errors.Is(err, ErrNotificationChannelBlocked))
424+
})
425+
426+
// Create a server with hooks
427+
server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks))
428+
429+
// Create a session with a very small buffer that will get blocked
430+
smallBufferChan := make(chan mcp.JSONRPCNotification, 1)
431+
session := &sessionTestClient{
432+
sessionID: "blocked-session",
433+
notificationChannel: smallBufferChan,
434+
initialized: true,
435+
}
436+
437+
// Register the session
438+
err := server.RegisterSession(context.Background(), session)
439+
require.NoError(t, err)
440+
441+
// Fill the buffer first to ensure it gets blocked
442+
server.SendNotificationToSpecificClient(session.SessionID(), "first-message", nil)
443+
444+
// This will cause the buffer to block
445+
err = server.SendNotificationToSpecificClient(session.SessionID(), "blocked-message", nil)
446+
assert.Error(t, err)
447+
assert.Equal(t, ErrNotificationChannelBlocked, err)
448+
449+
// Wait a bit for the goroutine to execute
450+
time.Sleep(10 * time.Millisecond)
451+
452+
// Verify the error was logged via hooks
453+
assert.True(t, errorCaptured, "Error hook should have been called")
454+
assert.Equal(t, "blocked-session", errorSessionID, "Session ID should be captured in the error hook")
455+
assert.Equal(t, "blocked-message", errorMethod, "Method should be captured in the error hook")
456+
457+
// Also test SendNotificationToAllClients with a blocked channel
458+
// Reset the captured data
459+
errorCaptured = false
460+
errorSessionID = ""
461+
errorMethod = ""
462+
463+
// Send to all clients (which includes our blocked one)
464+
server.SendNotificationToAllClients("broadcast-message", nil)
465+
466+
// Wait a bit for the goroutine to execute
467+
time.Sleep(10 * time.Millisecond)
468+
469+
// Verify the error was logged via hooks
470+
assert.True(t, errorCaptured, "Error hook should have been called for broadcast")
471+
assert.Equal(t, "blocked-session", errorSessionID, "Session ID should be captured in the error hook")
472+
assert.Equal(t, "broadcast-message", errorMethod, "Method should be captured in the error hook")
473+
}

0 commit comments

Comments
 (0)