Skip to content

Commit 873706a

Browse files
author
王荣昌
committed
fix(client): optimize and standardize the readSSEStream function handling
1 parent 1b2f45c commit 873706a

File tree

4 files changed

+83
-102
lines changed

4 files changed

+83
-102
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.aider*
22
.env
3-
.idea
3+
.idea
4+
.vscode

client/transport/sse.go

Lines changed: 11 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
package transport
22

33
import (
4-
"bufio"
54
"bytes"
65
"context"
76
"encoding/json"
87
"fmt"
98
"io"
109
"net/http"
1110
"net/url"
12-
"strings"
1311
"sync"
1412
"sync/atomic"
1513
"time"
@@ -104,7 +102,7 @@ func (c *SSE) Start(ctx context.Context) error {
104102
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
105103
}
106104

107-
go c.readSSE(resp.Body)
105+
go c.readSSE(ctx, resp.Body)
108106

109107
// Wait for the endpoint to be received
110108
timeout := time.NewTimer(30 * time.Second)
@@ -125,56 +123,18 @@ func (c *SSE) Start(ctx context.Context) error {
125123

126124
// readSSE continuously reads the SSE stream and processes events.
127125
// It runs until the connection is closed or an error occurs.
128-
func (c *SSE) readSSE(reader io.ReadCloser) {
129-
defer reader.Close()
130-
131-
br := bufio.NewReader(reader)
132-
var event, data string
133-
134-
for {
135-
// when close or start's ctx cancel, the reader will be closed
136-
// and the for loop will break.
137-
line, err := br.ReadString('\n')
138-
if err != nil {
139-
if err == io.EOF {
140-
// Process any pending event before exit
141-
if event != "" && data != "" {
142-
c.handleSSEEvent(event, data)
143-
}
144-
break
145-
}
146-
if !c.closed.Load() {
147-
fmt.Printf("SSE stream error: %v\n", err)
148-
}
149-
return
150-
}
151-
152-
// Remove only newline markers
153-
line = strings.TrimRight(line, "\r\n")
154-
if line == "" {
155-
// Empty line means end of event
156-
if event != "" && data != "" {
157-
c.handleSSEEvent(event, data)
158-
event = ""
159-
data = ""
160-
}
161-
continue
162-
}
163-
164-
if strings.HasPrefix(line, "event:") {
165-
event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
166-
} else if strings.HasPrefix(line, "data:") {
167-
data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
168-
}
126+
func (c *SSE) readSSE(ctx context.Context, reader io.ReadCloser) {
127+
if err := ReadSSEStream(ctx, reader, c.handleSSEEvent); err != nil && !c.closed.Load() {
128+
fmt.Printf("SSE stream error: %v\n", err)
169129
}
170130
}
171131

172132
// handleSSEEvent processes SSE events based on their type.
173133
// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication.
174-
func (c *SSE) handleSSEEvent(event, data string) {
175-
switch event {
134+
func (c *SSE) handleSSEEvent(evt SSEEvent) {
135+
switch evt.event {
176136
case "endpoint":
177-
endpoint, err := c.baseURL.Parse(data)
137+
endpoint, err := c.baseURL.Parse(evt.data)
178138
if err != nil {
179139
fmt.Printf("Error parsing endpoint URL: %v\n", err)
180140
return
@@ -188,15 +148,15 @@ func (c *SSE) handleSSEEvent(event, data string) {
188148

189149
case "message":
190150
var baseMessage JSONRPCResponse
191-
if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
151+
if err := json.Unmarshal([]byte(evt.data), &baseMessage); err != nil {
192152
fmt.Printf("Error unmarshaling message: %v\n", err)
193153
return
194154
}
195155

196156
// Handle notification
197157
if baseMessage.ID == nil {
198158
var notification mcp.JSONRPCNotification
199-
if err := json.Unmarshal([]byte(data), &notification); err != nil {
159+
if err := json.Unmarshal([]byte(evt.data), &notification); err != nil {
200160
return
201161
}
202162
c.notifyMu.RLock()
@@ -255,7 +215,7 @@ func (c *SSE) SendRequest(
255215

256216
req, err := http.NewRequestWithContext(
257217
ctx,
258-
"POST",
218+
http.MethodPost,
259219
c.endpoint.String(),
260220
bytes.NewReader(requestBytes),
261221
)
@@ -333,7 +293,7 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
333293

334294
req, err := http.NewRequestWithContext(
335295
ctx,
336-
"POST",
296+
http.MethodPost,
337297
c.endpoint.String(),
338298
bytes.NewReader(notificationBytes),
339299
)

client/transport/sse_helper.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package transport
2+
3+
import (
4+
"bufio"
5+
"context"
6+
"fmt"
7+
"io"
8+
"strings"
9+
)
10+
11+
type SSEEvent struct {
12+
event string
13+
data string
14+
}
15+
16+
// ReadSSEStream continuously reads the SSE stream and processes events.
17+
func ReadSSEStream(ctx context.Context, reader io.ReadCloser, onEvent func(event SSEEvent)) error {
18+
defer func(reader io.ReadCloser) {
19+
err := reader.Close()
20+
if err != nil {
21+
fmt.Printf("Error closing reader: %v\n", err)
22+
}
23+
}(reader)
24+
25+
scanner := bufio.NewScanner(reader)
26+
var event, data strings.Builder
27+
28+
processEvent := func() {
29+
if event.Len() > 0 || data.Len() > 0 {
30+
onEvent(SSEEvent{event: event.String(), data: data.String()})
31+
event.Reset()
32+
data.Reset()
33+
}
34+
}
35+
36+
for scanner.Scan() {
37+
select {
38+
case <-ctx.Done():
39+
return nil
40+
default:
41+
line := scanner.Text()
42+
43+
switch {
44+
case strings.HasPrefix(line, "event:"):
45+
event.WriteString(strings.TrimSpace(strings.TrimPrefix(line, "event:")))
46+
case strings.HasPrefix(line, "data:"):
47+
if data.Len() > 0 {
48+
data.WriteString("\n")
49+
}
50+
data.WriteString(strings.TrimSpace(strings.TrimPrefix(line, "data:")))
51+
case line == "":
52+
processEvent()
53+
}
54+
}
55+
}
56+
// EOF Handle the last event after reaching EOF
57+
processEvent()
58+
59+
return scanner.Err()
60+
}

client/transport/streamable_http.go

Lines changed: 10 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
package transport
22

33
import (
4-
"bufio"
54
"bytes"
65
"context"
76
"encoding/json"
87
"fmt"
98
"io"
109
"net/http"
1110
"net/url"
12-
"strings"
1311
"sync"
1412
"sync/atomic"
1513
"time"
@@ -245,20 +243,20 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
245243
// only close responseChan after readingSSE()
246244
defer close(responseChan)
247245

248-
c.readSSE(ctx, reader, func(event, data string) {
246+
c.readSSE(ctx, reader, func(evt SSEEvent) {
249247

250248
// (unsupported: batching)
251-
249+
252250
var message JSONRPCResponse
253-
if err := json.Unmarshal([]byte(data), &message); err != nil {
251+
if err := json.Unmarshal([]byte(evt.data), &message); err != nil {
254252
fmt.Printf("failed to unmarshal message: %v\n", err)
255253
return
256254
}
257-
255+
258256
// Handle notification
259257
if message.ID == nil {
260258
var notification mcp.JSONRPCNotification
261-
if err := json.Unmarshal([]byte(data), &notification); err != nil {
259+
if err := json.Unmarshal([]byte(evt.data), &notification); err != nil {
262260
fmt.Printf("failed to unmarshal notification: %v\n", err)
263261
return
264262
}
@@ -269,7 +267,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
269267
c.notifyMu.RUnlock()
270268
return
271269
}
272-
270+
273271
responseChan <- &message
274272
})
275273
}()
@@ -288,52 +286,14 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
288286

289287
// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair.
290288
// It will end when the reader is closed (or the context is done).
291-
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) {
292-
defer reader.Close()
293-
294-
br := bufio.NewReader(reader)
295-
var event, data string
296-
297-
for {
289+
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(evt SSEEvent)) {
290+
if err := ReadSSEStream(ctx, reader, handler); err != nil {
298291
select {
299292
case <-ctx.Done():
300293
return
301294
default:
302-
line, err := br.ReadString('\n')
303-
if err != nil {
304-
if err == io.EOF {
305-
// Process any pending event before exit
306-
if event != "" && data != "" {
307-
handler(event, data)
308-
}
309-
return
310-
}
311-
select {
312-
case <-ctx.Done():
313-
return
314-
default:
315-
fmt.Printf("SSE stream error: %v\n", err)
316-
return
317-
}
318-
}
319-
320-
// Remove only newline markers
321-
line = strings.TrimRight(line, "\r\n")
322-
if line == "" {
323-
// Empty line means end of event
324-
if event != "" && data != "" {
325-
handler(event, data)
326-
event = ""
327-
data = ""
328-
}
329-
continue
330-
}
331-
332-
if strings.HasPrefix(line, "event:") {
333-
event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
334-
} else if strings.HasPrefix(line, "data:") {
335-
data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
336-
}
295+
fmt.Printf("SSE stream error: %v\n", err)
296+
return
337297
}
338298
}
339299
}

0 commit comments

Comments
 (0)