Skip to content

Commit e744c19

Browse files
authored
fix: panic when streamable HTTP server sends notification (#348)
fix: panic when streamable HTTP server sends notification
1 parent 991b31c commit e744c19

File tree

3 files changed

+166
-6
lines changed

3 files changed

+166
-6
lines changed

client/http_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"github.com/mark3labs/mcp-go/mcp"
7+
"github.com/mark3labs/mcp-go/server"
8+
"testing"
9+
"time"
10+
)
11+
12+
func TestHTTPClient(t *testing.T) {
13+
hooks := &server.Hooks{}
14+
hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
15+
clientSession := server.ClientSessionFromContext(ctx)
16+
// wait until all the notifications are handled
17+
for len(clientSession.NotificationChannel()) > 0 {
18+
}
19+
time.Sleep(time.Millisecond * 50)
20+
})
21+
22+
// Create MCP server with capabilities
23+
mcpServer := server.NewMCPServer(
24+
"test-server",
25+
"1.0.0",
26+
server.WithToolCapabilities(true),
27+
server.WithHooks(hooks),
28+
)
29+
30+
mcpServer.AddTool(
31+
mcp.NewTool("notify"),
32+
func(
33+
ctx context.Context,
34+
request mcp.CallToolRequest,
35+
) (*mcp.CallToolResult, error) {
36+
server := server.ServerFromContext(ctx)
37+
err := server.SendNotificationToClient(
38+
ctx,
39+
"notifications/progress",
40+
map[string]any{
41+
"progress": 10,
42+
"total": 10,
43+
"progressToken": 0,
44+
},
45+
)
46+
if err != nil {
47+
return nil, fmt.Errorf("failed to send notification: %w", err)
48+
}
49+
50+
return &mcp.CallToolResult{
51+
Content: []mcp.Content{
52+
mcp.TextContent{
53+
Type: "text",
54+
Text: "notification sent successfully",
55+
},
56+
},
57+
}, nil
58+
},
59+
)
60+
61+
testServer := server.NewTestStreamableHTTPServer(mcpServer)
62+
defer testServer.Close()
63+
64+
t.Run("Can receive notification from server", func(t *testing.T) {
65+
client, err := NewStreamableHttpClient(testServer.URL)
66+
if err != nil {
67+
t.Fatalf("create client failed %v", err)
68+
return
69+
}
70+
71+
notificationNum := 0
72+
client.OnNotification(func(notification mcp.JSONRPCNotification) {
73+
notificationNum += 1
74+
})
75+
76+
ctx := context.Background()
77+
78+
if err := client.Start(ctx); err != nil {
79+
t.Fatalf("Failed to start client: %v", err)
80+
return
81+
}
82+
83+
// Initialize
84+
initRequest := mcp.InitializeRequest{}
85+
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
86+
initRequest.Params.ClientInfo = mcp.Implementation{
87+
Name: "test-client",
88+
Version: "1.0.0",
89+
}
90+
91+
_, err = client.Initialize(ctx, initRequest)
92+
if err != nil {
93+
t.Fatalf("Failed to initialize: %v\n", err)
94+
}
95+
96+
request := mcp.CallToolRequest{}
97+
request.Params.Name = "notify"
98+
result, err := client.CallTool(ctx, request)
99+
if err != nil {
100+
t.Fatalf("CallTool failed: %v", err)
101+
}
102+
103+
if len(result.Content) != 1 {
104+
t.Errorf("Expected 1 content item, got %d", len(result.Content))
105+
}
106+
107+
if notificationNum != 1 {
108+
t.Errorf("Expected 1 notification item, got %d", notificationNum)
109+
}
110+
})
111+
}

server/session.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ type SessionWithClientInfo interface {
4848
SetClientInfo(clientInfo mcp.Implementation)
4949
}
5050

51+
// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations
52+
type SessionWithStreamableHTTPConfig interface {
53+
ClientSession
54+
// UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server
55+
// sends notifications to the client
56+
//
57+
// The protocol specification:
58+
// - If the server response contains any JSON-RPC notifications, it MUST either:
59+
// - Return Content-Type: text/event-stream to initiate an SSE stream, OR
60+
// - Return Content-Type: application/json for a single JSON object
61+
// - The client MUST support both response types.
62+
//
63+
// Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server
64+
UpgradeToSSEWhenReceiveNotification()
65+
}
66+
5167
// clientSessionKey is the context key for storing current client notification channel.
5268
type clientSessionKey struct{}
5369

@@ -146,6 +162,11 @@ func (s *MCPServer) SendNotificationToClient(
146162
return ErrNotificationNotInitialized
147163
}
148164

165+
// upgrades the client-server communication to SSE stream when the server sends notifications to the client
166+
if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
167+
sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
168+
}
169+
149170
notification := mcp.JSONRPCNotification{
150171
JSONRPC: mcp.JSONRPC_VERSION,
151172
Notification: mcp.Notification{
@@ -193,6 +214,11 @@ func (s *MCPServer) SendNotificationToSpecificClient(
193214
return ErrSessionNotInitialized
194215
}
195216

217+
// upgrades the client-server communication to SSE stream when the server sends notifications to the client
218+
if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
219+
sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
220+
}
221+
196222
notification := mcp.JSONRPCNotification{
197223
JSONRPC: mcp.JSONRPC_VERSION,
198224
Notification: mcp.Notification{

server/streamable_http.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http/httptest"
1010
"strings"
1111
"sync"
12+
"sync/atomic"
1213
"time"
1314

1415
"github.com/google/uuid"
@@ -243,9 +244,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
243244

244245
// handle potential notifications
245246
mu := sync.Mutex{}
246-
upgraded := false
247+
upgradedHeader := false
247248
done := make(chan struct{})
248-
defer close(done)
249249

250250
go func() {
251251
for {
@@ -254,20 +254,26 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
254254
func() {
255255
mu.Lock()
256256
defer mu.Unlock()
257+
// if the done chan is closed, as the request is terminated, just return
258+
select {
259+
case <-done:
260+
return
261+
default:
262+
}
257263
defer func() {
258264
flusher, ok := w.(http.Flusher)
259265
if ok {
260266
flusher.Flush()
261267
}
262268
}()
263269

264-
// if there's notifications, upgrade to SSE response
265-
if !upgraded {
266-
upgraded = true
270+
// if there's notifications, upgradedHeader to SSE response
271+
if !upgradedHeader {
267272
w.Header().Set("Content-Type", "text/event-stream")
268273
w.Header().Set("Connection", "keep-alive")
269274
w.Header().Set("Cache-Control", "no-cache")
270275
w.WriteHeader(http.StatusAccepted)
276+
upgradedHeader = true
271277
}
272278
err := writeSSEEvent(w, nt)
273279
if err != nil {
@@ -294,10 +300,20 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
294300
// Write response
295301
mu.Lock()
296302
defer mu.Unlock()
303+
// close the done chan before unlock
304+
defer close(done)
297305
if ctx.Err() != nil {
298306
return
299307
}
300-
if upgraded {
308+
// If client-server communication already upgraded to SSE stream
309+
if session.upgradeToSSE.Load() {
310+
if !upgradedHeader {
311+
w.Header().Set("Content-Type", "text/event-stream")
312+
w.Header().Set("Connection", "keep-alive")
313+
w.Header().Set("Cache-Control", "no-cache")
314+
w.WriteHeader(http.StatusAccepted)
315+
upgradedHeader = true
316+
}
301317
if err := writeSSEEvent(w, response); err != nil {
302318
s.logger.Errorf("Failed to write final SSE response event: %v", err)
303319
}
@@ -494,6 +510,7 @@ type streamableHttpSession struct {
494510
sessionID string
495511
notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
496512
tools *sessionToolsStore
513+
upgradeToSSE atomic.Bool
497514
}
498515

499516
func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession {
@@ -534,6 +551,12 @@ func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {
534551

535552
var _ SessionWithTools = (*streamableHttpSession)(nil)
536553

554+
func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
555+
s.upgradeToSSE.Store(true)
556+
}
557+
558+
var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
559+
537560
// --- session id manager ---
538561

539562
type SessionIdManager interface {

0 commit comments

Comments
 (0)