Skip to content

feat: client-side streamable-http transport supports resumability #380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions client/transport/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,9 @@ func TestSSE(t *testing.T) {
t.Run("SSEEventWithoutEventField", func(t *testing.T) {
// Test that SSE events with only data field (no event field) are processed correctly
// This tests the fix for issue #369

var messageReceived chan struct{}

// Create a custom mock server that sends SSE events without event field
sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
Expand Down Expand Up @@ -449,7 +449,7 @@ func TestSSE(t *testing.T) {
messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)

// Signal that message was received
close(messageReceived)
})
Expand Down
177 changes: 162 additions & 15 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
}
}

// WithResumption enables the client to attempt resuming broken connections.
// This can help to reduce network congestion as the server does not need
// to redeliver messages that have already been sent on the previous broken
// connection.
//
// As the retry itself might fail, the retry count can be set and it must be a value >= 1.
// If the value is < 1, it will be set to 1.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery
// NOTICE: Even enabled, the server may not support this feature.
func WithResumption(maxRetryCount int) StreamableHTTPCOption {
if maxRetryCount < 1 {
maxRetryCount = 1
}
return func(sc *StreamableHTTP) {
sc.resumptionEnabled = true
sc.maxRetryCount = maxRetryCount
}
}

// StreamableHTTP implements Streamable HTTP transport.
//
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
Expand All @@ -66,14 +85,14 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
// - batching
// - continuously listening for server notifications when no request is in flight
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
// - resuming stream
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
// - server -> client request
type StreamableHTTP struct {
serverURL *url.URL
httpClient *http.Client
headers map[string]string
headerFunc HTTPHeaderFunc
serverURL *url.URL
httpClient *http.Client
headers map[string]string
headerFunc HTTPHeaderFunc
resumptionEnabled bool
maxRetryCount int

sessionID atomic.Value // string

Expand Down Expand Up @@ -159,7 +178,8 @@ func (c *StreamableHTTP) Close() error {
}

const (
headerKeySessionID = "Mcp-Session-Id"
headerKeySessionID = "Mcp-Session-Id"
headerKeyLastEventID = "Last-Event-Id"
)

// ErrOAuthAuthorizationRequired is a sentinel error for OAuth authorization required
Expand Down Expand Up @@ -300,16 +320,131 @@ func (c *StreamableHTTP) SendRequest(

case "text/event-stream":
// Server is using SSE for streaming responses
return c.handleSSEResponse(ctx, resp.Body)

if !c.resumptionEnabled {
return c.handleSSEResponse(ctx, resp.Body, nil)
}
var lastEventId string
resumptionCallback := func(id string) {
lastEventId = id
}
resp, err := c.handleSSEResponse(ctx, resp.Body, resumptionCallback)
if err == nil || lastEventId == "" {
return resp, err
}
for range c.maxRetryCount {
resp, err, canRetry := c.performSSEResumption(ctx, lastEventId, resumptionCallback)
if err == nil || lastEventId == "" || !canRetry {
return resp, err
}
}
return nil, fmt.Errorf("failed to retrieve response after attempting resumption %d times", c.maxRetryCount)
default:
return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
}
}

// performSSEResumption sends a request to the server including both the
// session id and the last event id, expecting the server to return an
// SSE streaming response sending the events with ids after the last event id.
// It returns the final result for the request once received, or an error.
// A boolean is returned as well to indicate whether it is appropriate to
// attempt resumption if an error is returned.
func (c *StreamableHTTP) performSSEResumption(
ctx context.Context,
lastEventId string,
resumptionCallback func(string),
) (*JSONRPCResponse, error, bool) {

ctx, cancel := context.WithCancel(ctx)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.serverURL.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err), false
}

req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
sessionID := c.sessionID.Load()
if sessionID != "" {
req.Header.Set(headerKeySessionID, sessionID.(string))
}
if lastEventId == "" {
return nil, fmt.Errorf("sse resumption request requires a last event id"), false
}
req.Header.Set(headerKeyLastEventID, lastEventId)

for k, v := range c.headers {
req.Header.Set(k, v)
}

// Add OAuth authorization if configured
if c.oauthHandler != nil {
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
if err != nil {
// If we get an authorization error, return a specific error that can be handled by the client
if err.Error() == "no valid token available, authorization required" {
return nil, &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
}, false
}
return nil, fmt.Errorf("failed to get authorization header: %w", err), false
}
req.Header.Set("Authorization", authHeader)
}

if c.headerFunc != nil {
for k, v := range c.headerFunc(ctx) {
req.Header.Set(k, v)
}
}

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to connect to SSE stream: %w", err), true
}

// Check if we got an error response
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
// handle session closed
if resp.StatusCode == http.StatusNotFound {
c.sessionID.CompareAndSwap(sessionID, "")
return nil, fmt.Errorf("session terminated (404). need to re-initialize"), false
}

// Handle OAuth unauthorized error
if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
return nil, &OAuthAuthorizationRequiredError{
Handler: c.oauthHandler,
}, false
}

// handle error response
var errResponse JSONRPCResponse
body, _ := io.ReadAll(resp.Body)
if err := json.Unmarshal(body, &errResponse); err == nil {
return &errResponse, nil, false
}
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body), false
}

mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
switch mediaType {
case "text/event-stream":
resp, err := c.handleSSEResponse(ctx, resp.Body, resumptionCallback)
return resp, err, true
default:
return nil, fmt.Errorf("unexpected content type for sse resumption response: %s", resp.Header.Get("Content-Type")), false
}
}

// handleSSEResponse processes an SSE stream for a specific request.
// It returns the final result for the request once received, or an error.
func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
func (c *StreamableHTTP) handleSSEResponse(
ctx context.Context,
reader io.ReadCloser,
resumptionCallback func(string),
) (*JSONRPCResponse, error) {

// Create a channel for this specific request
responseChan := make(chan *JSONRPCResponse, 1)
Expand All @@ -322,7 +457,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
// only close responseChan after readingSSE()
defer close(responseChan)

c.readSSE(ctx, reader, func(event, data string) {
c.readSSE(ctx, reader, func(event, data, id string) {

// (unsupported: batching)

Expand All @@ -332,6 +467,10 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
return
}

if id != "" && resumptionCallback != nil {
resumptionCallback(id)
}

// Handle notification
if message.ID.IsNil() {
var notification mcp.JSONRPCNotification
Expand Down Expand Up @@ -365,11 +504,11 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl

// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair.
// It will end when the reader is closed (or the context is done).
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) {
func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data, id string)) {
defer reader.Close()

br := bufio.NewReader(reader)
var event, data string
var event, data, id string

for {
select {
Expand All @@ -385,7 +524,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
if event == "" {
event = "message"
}
handler(event, data)
handler(event, data, id)
}
return
}
Expand All @@ -407,9 +546,10 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
if event == "" {
event = "message"
}
handler(event, data)
handler(event, data, id)
event = ""
data = ""
id = ""
}
continue
}
Expand All @@ -418,6 +558,13 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand
event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
} else if strings.HasPrefix(line, "data:") {
data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
} else if strings.HasPrefix(line, "id:") {
eventId := strings.TrimSpace(strings.TrimPrefix(line, "id:"))
if strings.Contains(eventId, "\x00") {
// will be sent back in HTTP header, a null byte in header breaks HTTP standard
continue
}
id = eventId
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions client/transport/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ func TestStreamableHTTP(t *testing.T) {
t.Run("SSEEventWithoutEventField", func(t *testing.T) {
// Test that SSE events with only data field (no event field) are processed correctly
// This tests the fix for issue #369

// Create a custom mock server that sends SSE events without event field
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
Expand All @@ -437,7 +437,7 @@ func TestStreamableHTTP(t *testing.T) {
// This should be processed as a "message" event according to SSE spec
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)

response := map[string]any{
"jsonrpc": "2.0",
"id": request["id"],
Expand Down