Skip to content

Commit 3ba0c91

Browse files
feat(sse): Add SessionWithTools support to SSEServer (#232)
Implement SessionWithTools interface for sseSession to support session-specific tools: - Add tools field to sseSession struct - Implement GetSessionTools and SetSessionTools methods
1 parent eadd702 commit 3ba0c91

File tree

2 files changed

+141
-2
lines changed

2 files changed

+141
-2
lines changed

server/sse.go

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ type sseSession struct {
2828
requestID atomic.Int64
2929
notificationChannel chan mcp.JSONRPCNotification
3030
initialized atomic.Bool
31+
tools sync.Map // stores session-specific tools
3132
}
3233

3334
// SSEContextFunc is a function that takes an existing context and the current
@@ -58,7 +59,34 @@ func (s *sseSession) Initialized() bool {
5859
return s.initialized.Load()
5960
}
6061

61-
var _ ClientSession = (*sseSession)(nil)
62+
func (s *sseSession) GetSessionTools() map[string]ServerTool {
63+
tools := make(map[string]ServerTool)
64+
s.tools.Range(func(key, value interface{}) bool {
65+
if tool, ok := value.(ServerTool); ok {
66+
tools[key.(string)] = tool
67+
}
68+
return true
69+
})
70+
return tools
71+
}
72+
73+
func (s *sseSession) SetSessionTools(tools map[string]ServerTool) {
74+
// Clear existing tools
75+
s.tools.Range(func(key, _ interface{}) bool {
76+
s.tools.Delete(key)
77+
return true
78+
})
79+
80+
// Set new tools
81+
for name, tool := range tools {
82+
s.tools.Store(name, tool)
83+
}
84+
}
85+
86+
var (
87+
_ ClientSession = (*sseSession)(nil)
88+
_ SessionWithTools = (*sseSession)(nil)
89+
)
6290

6391
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
6492
// It provides real-time communication capabilities over HTTP using the SSE protocol.

server/sse_test.go

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,8 @@ func TestSSEServer(t *testing.T) {
666666
t.Fatalf("Failed to marshal tool request: %v", err)
667667
}
668668

669-
req, err := http.NewRequest(http.MethodPost, messageURL, bytes.NewBuffer(requestBody))
669+
var req *http.Request
670+
req, err = http.NewRequest(http.MethodPost, messageURL, bytes.NewBuffer(requestBody))
670671
if err != nil {
671672
t.Fatalf("Failed to create tool request: %v", err)
672673
}
@@ -1129,6 +1130,116 @@ func TestSSEServer(t *testing.T) {
11291130
})
11301131
}
11311132
})
1133+
1134+
t.Run("SessionWithTools implementation", func(t *testing.T) {
1135+
// Create hooks to track sessions
1136+
hooks := &Hooks{}
1137+
var registeredSession *sseSession
1138+
hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) {
1139+
if s, ok := session.(*sseSession); ok {
1140+
registeredSession = s
1141+
}
1142+
})
1143+
1144+
mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks))
1145+
testServer := NewTestServer(mcpServer)
1146+
defer testServer.Close()
1147+
1148+
// Connect to SSE endpoint
1149+
sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL))
1150+
if err != nil {
1151+
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
1152+
}
1153+
defer sseResp.Body.Close()
1154+
1155+
// Read the endpoint event to ensure session is established
1156+
_, err = readSeeEvent(sseResp)
1157+
if err != nil {
1158+
t.Fatalf("Failed to read SSE response: %v", err)
1159+
}
1160+
1161+
// Verify we got a session
1162+
if registeredSession == nil {
1163+
t.Fatal("Session was not registered via hook")
1164+
}
1165+
1166+
// Test setting and getting tools
1167+
tools := map[string]ServerTool{
1168+
"test_tool": {
1169+
Tool: mcp.Tool{
1170+
Name: "test_tool",
1171+
Description: "A test tool",
1172+
Annotations: mcp.ToolAnnotation{
1173+
Title: "Test Tool",
1174+
},
1175+
},
1176+
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1177+
return mcp.NewToolResultText("test"), nil
1178+
},
1179+
},
1180+
}
1181+
1182+
// Test SetSessionTools
1183+
registeredSession.SetSessionTools(tools)
1184+
1185+
// Test GetSessionTools
1186+
retrievedTools := registeredSession.GetSessionTools()
1187+
if len(retrievedTools) != 1 {
1188+
t.Errorf("Expected 1 tool, got %d", len(retrievedTools))
1189+
}
1190+
if tool, exists := retrievedTools["test_tool"]; !exists {
1191+
t.Error("Expected test_tool to exist")
1192+
} else if tool.Tool.Name != "test_tool" {
1193+
t.Errorf("Expected tool name test_tool, got %s", tool.Tool.Name)
1194+
}
1195+
1196+
// Test concurrent access
1197+
var wg sync.WaitGroup
1198+
for i := 0; i < 10; i++ {
1199+
wg.Add(2)
1200+
go func(i int) {
1201+
defer wg.Done()
1202+
tools := map[string]ServerTool{
1203+
fmt.Sprintf("tool_%d", i): {
1204+
Tool: mcp.Tool{
1205+
Name: fmt.Sprintf("tool_%d", i),
1206+
Description: fmt.Sprintf("Tool %d", i),
1207+
Annotations: mcp.ToolAnnotation{
1208+
Title: fmt.Sprintf("Tool %d", i),
1209+
},
1210+
},
1211+
},
1212+
}
1213+
registeredSession.SetSessionTools(tools)
1214+
}(i)
1215+
go func() {
1216+
defer wg.Done()
1217+
_ = registeredSession.GetSessionTools()
1218+
}()
1219+
}
1220+
wg.Wait()
1221+
1222+
// Verify we can still get and set tools after concurrent access
1223+
finalTools := map[string]ServerTool{
1224+
"final_tool": {
1225+
Tool: mcp.Tool{
1226+
Name: "final_tool",
1227+
Description: "Final Tool",
1228+
Annotations: mcp.ToolAnnotation{
1229+
Title: "Final Tool",
1230+
},
1231+
},
1232+
},
1233+
}
1234+
registeredSession.SetSessionTools(finalTools)
1235+
retrievedTools = registeredSession.GetSessionTools()
1236+
if len(retrievedTools) != 1 {
1237+
t.Errorf("Expected 1 tool, got %d", len(retrievedTools))
1238+
}
1239+
if _, exists := retrievedTools["final_tool"]; !exists {
1240+
t.Error("Expected final_tool to exist")
1241+
}
1242+
})
11321243
}
11331244

11341245
func readSeeEvent(sseResp *http.Response) (string, error) {

0 commit comments

Comments
 (0)