Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/jsonrpc2/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ func (c *Connection) write(ctx context.Context, msg Message) error {

// For cancelled or rejected requests, we don't set the writeErr (which would
// break the connection). They can just be returned to the caller.
if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) && !errors.Is(err, ErrUnsupportedProtocolVersion) && !errors.Is(err, ErrMethodNotFound) {
if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) {
// The call to Write failed, and since ctx.Err() is nil we can't attribute
// the failure (even indirectly) to Context cancellation. The writer appears
// to be broken, and future writes are likely to also fail.
Expand Down
17 changes: 6 additions & 11 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,10 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp
return cs, nil
}

var werr *jsonrpc.Error
if !errors.As(err, &werr) {
return nil, err
}
// Try to negotiate a mutually supported version if the server
// reports an UnsupportedProtocolVersionError with a supported version.
if werr.Code == CodeUnsupportedProtocolVersion && werr.Data != nil {
var werr *jsonrpc.Error
if errors.As(err, &werr) && werr.Code == CodeUnsupportedProtocolVersion && werr.Data != nil {
var data UnsupportedProtocolVersionData
if err := json.Unmarshal(werr.Data, &data); err == nil {
if negotiatedVersion := negotiateMutuallySupportedVersion(data.Supported); negotiatedVersion != "" && negotiatedVersion >= protocolVersion20260630 {
Expand All @@ -318,13 +315,11 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp
}
}
}
// MethodNotFound and UnsupportedProtocolVersion trigger a fallback to legacy initialize.
if werr.Code == jsonrpc.CodeMethodNotFound || werr.Code == CodeUnsupportedProtocolVersion {
break
}
return nil, err
// Per the spec, fall back to the legacy initialize handshake on any
// non-modern error from server/discover.
break
}
// Fallback to the legacy initialize handshake with the legacy protocol version.
// Use the latest legacy protocol version for the fallback initialize.
protocolVersion = protocolVersion20251125
}

Expand Down
69 changes: 0 additions & 69 deletions mcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,6 @@ func TestClientConnectDiscover(t *testing.T) {
// Returning (nil, nil) means "let the default stub handle it" (which
// returns ErrMethodNotFound).
discoverHandler func() (Result, error)
wantConnectErr bool
// wantInitialize is true if the legacy initialize handshake should
// have run (i.e. discover signaled "not supported").
wantInitialize bool
Expand Down Expand Up @@ -710,16 +709,6 @@ func TestClientConnectDiscover(t *testing.T) {
wantInitialize: true,
wantVersion: latestProtocolVersion,
},
{
name: "unexpected error propagates and aborts Connect",
discoverHandler: func() (Result, error) {
return nil, &jsonrpc.Error{
Code: jsonrpc.CodeInternalError,
Message: "boom",
}
},
wantConnectErr: true,
},
}

for _, tc := range tests {
Expand Down Expand Up @@ -754,19 +743,6 @@ func TestClientConnectDiscover(t *testing.T) {

c := NewClient(testImpl, nil)
cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630})
if tc.wantConnectErr {
if err == nil {
_ = cs.Close()
t.Fatal("Connect succeeded, want error")
}
if !gotDiscover.Load() {
t.Error("server did not receive server/discover")
}
if gotInitialize.Load() {
t.Error("server received initialize but discover should have aborted Connect")
}
return
}
if err != nil {
t.Fatalf("Connect: %v", err)
}
Expand Down Expand Up @@ -1061,51 +1037,6 @@ func TestInMemory_E2E_DiscoverFallback_UnsupportedProtocolVersion(t *testing.T)
}
}

// TestInMemory_E2E_DiscoverPropagatesOtherErrors verifies that an unrelated
// error from the discover handler aborts Connect and does NOT silently
// fall back.
func TestInMemory_E2E_DiscoverPropagatesOtherErrors(t *testing.T) {
ctx := context.Background()

orig := supportedProtocolVersions
supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...)
t.Cleanup(func() { supportedProtocolVersions = orig })

var sawInitialize atomic.Bool
server := NewServer(&Implementation{Name: "broken-server", Version: "v1"}, nil)
server.AddReceivingMiddleware(func(next MethodHandler) MethodHandler {
return func(ctx context.Context, method string, req Request) (Result, error) {
switch method {
case methodDiscover:
return nil, &jsonrpc.Error{
Code: jsonrpc.CodeInternalError,
Message: "boom",
}
case methodInitialize:
sawInitialize.Store(true)
}
return next(ctx, method, req)
}
})

ct, st := NewInMemoryTransports()
ss, err := server.Connect(ctx, st, nil)
if err != nil {
t.Fatalf("server.Connect: %v", err)
}
defer ss.Close()

client := NewClient(&Implementation{Name: "new-client", Version: "v1"}, nil)
cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630})
if err == nil {
_ = cs.Close()
t.Fatal("Connect succeeded; want propagated discover error")
}
if sawInitialize.Load() {
t.Error("server received initialize; Connect should have aborted on the discover error")
}
}

// TestClientConnectDiscover_UnsupportedVersionNegotiation verifies the
// per SEP-2575 Version Negotiation Flow: when the client probes server/discover
// with a protocol version the server doesn't implement, the server's
Expand Down
26 changes: 12 additions & 14 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -2103,12 +2103,14 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
}

var requestSummary string
var requestMethod string
var forCall *jsonrpc.Request
switch msg := msg.(type) {
case *jsonrpc.Request:
requestSummary = fmt.Sprintf("sending %q", msg.Method)
if msg.IsCall() {
forCall = msg
requestMethod = msg.Method
}
case *jsonrpc.Response:
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
Expand Down Expand Up @@ -2184,10 +2186,14 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
}

if err := c.checkResponse(ctx, requestSummary, resp); err != nil {
// Only fail the connection for non-transient errors.
// Transient errors (wrapped with ErrRejected) should not break the connection.
// ErrMethodNotFound and ErrUnsupportedProtocolVersion should not break the connection as they trigger the initialize fallback.
if !errors.Is(err, jsonrpc2.ErrRejected) && !errors.Is(err, jsonrpc2.ErrMethodNotFound) && !errors.Is(err, jsonrpc2.ErrUnsupportedProtocolVersion) {
if requestMethod == methodDiscover {
// Wrap the discover failure with ErrRejected so the jsonrpc2 layer
// doesn't set writeErr, which would prevent the legacy initialize
// fallback from succeeding on the same connection.
err = fmt.Errorf("%w: %w", err, jsonrpc2.ErrRejected)
} else if !errors.Is(err, jsonrpc2.ErrRejected) {
// Only fail the connection for non-transient errors.
// Transient errors (wrapped with ErrRejected) should not break the connection.
c.fail(err)
}
return err
Expand Down Expand Up @@ -2400,10 +2406,8 @@ func (c *streamableClientConn) checkResponse(ctx context.Context, requestSummary
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
// By default, always try to decode the body and surface the underlying
// JSON-RPC error (or detect vPre servers that reject "server/discover"
// as an unsupported method with a plain HTTP 400) regardless of the
// negotiated protocol version. Setting MCPGODEBUG=noprotocolerrorbody=1
// restores the previous behavior.
// JSON-RPC error.
// Setting MCPGODEBUG=noprotocolerrorbody=1 restores the previous behavior.
if noprotocolerrorbody == "1" {
return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode))
}
Expand All @@ -2412,12 +2416,6 @@ func (c *streamableClientConn) checkResponse(ctx context.Context, requestSummary
if response, ok := msg.(*jsonrpc.Response); ok && response.Error != nil {
return fmt.Errorf("%s: %w: %v", requestSummary, response.Error, http.StatusText(resp.StatusCode))
}
if strings.Contains(string(body), fmt.Sprintf("%s: %q unsupported", jsonrpc2.ErrNotHandled, methodDiscover)) {
return fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrMethodNotFound, http.StatusText(resp.StatusCode))
}
if strings.Contains(string(body), "Unsupported protocol version") {
return fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrUnsupportedProtocolVersion, http.StatusText(resp.StatusCode))
}
return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode))
}
return nil
Expand Down
118 changes: 65 additions & 53 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1470,59 +1470,6 @@ func TestStreamableClientConnect_DiscoverUnsupportedVersion(t *testing.T) {
}
}

// TestStreamableClientConnect_DiscoverPropagatesOtherErrors verifies that
// Client.Connect does NOT fall back to initialize when server/discover
// returns an unrelated JSON-RPC error (here, CodeInternalError). The Connect
// call should fail with the propagated error rather than masking it.
func TestStreamableClientConnect_DiscoverPropagatesOtherErrors(t *testing.T) {
ctx := context.Background()

var sawInitialize atomic.Bool
fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodDiscover, ""}: {
header: header{"Content-Type": "application/json"},
wantProtocolVersion: protocolVersion20260630,
responseFunc: func(r *jsonrpc.Request) (string, int) {
return jsonBody(t, &jsonrpc.Response{
ID: r.ID,
Error: &jsonrpc.Error{
Code: jsonrpc.CodeInternalError,
Message: "boom",
},
}), http.StatusOK
},
},
{"POST", "", methodInitialize, ""}: {
responseFunc: func(r *jsonrpc.Request) (string, int) {
sawInitialize.Store(true)
return jsonBody(t, &jsonrpc.Response{ID: r.ID, Result: mustMarshal(initResult)}), http.StatusOK
},
header: header{
"Content-Type": "application/json",
sessionIDHeader: "fallback",
},
optional: true,
},
},
}

httpServer := httptest.NewServer(fake)
defer httpServer.Close()

transport := &StreamableClientTransport{Endpoint: httpServer.URL}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, &ClientSessionOptions{protocolVersion: protocolVersion20260630})
if err == nil {
_ = session.Close()
t.Fatal("Connect succeeded; want propagated error")
}
if sawInitialize.Load() {
t.Error("server received initialize; Connect should have aborted on the discover error")
}
}

// TestStreamableClientConnect_DiscoverMethodNotFoundVPre verifies that
// Client.Connect falls back to the legacy initialize handshake when a
// pre-SEP-2575 (vPre) server rejects server/discover.
Expand Down Expand Up @@ -1635,3 +1582,68 @@ func TestStreamableClientConnect_DiscoverUnsupportedVersionVPre(t *testing.T) {
t.Errorf("InitializeResult.ProtocolVersion = %q, want %q (initialize fallback)", got, latestProtocolVersion)
}
}

// TestStreamableClientConnect_DiscoverUnsupportedVersionNegotiation verifies that
// when Client.Connect over a streamable transport receives an
// UnsupportedProtocolVersion error containing Data.Supported, it negotiates a
// mutually supported version and retries server/discover.
func TestStreamableClientConnect_DiscoverUnsupportedVersionNegotiation(t *testing.T) {
ctx := context.Background()

oldSupported := supportedProtocolVersions
supportedProtocolVersions = append([]string{protocolVersion20260630}, supportedProtocolVersions...)
t.Cleanup(func() {
supportedProtocolVersions = oldSupported
})

const unsupportedClientVersion = "2099-12-31"

var discoverCalls atomic.Int32

fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodDiscover, ""}: {
header: header{
"Content-Type": "application/json",
},
responseFunc: func(r *jsonrpc.Request) (string, int) {
n := discoverCalls.Add(1)
if n == 1 {
data, _ := json.Marshal(UnsupportedProtocolVersionData{
Supported: []string{protocolVersion20260630},
})
respMsg := &jsonrpc.Response{
ID: r.ID,
Error: &jsonrpc.Error{
Code: CodeUnsupportedProtocolVersion,
Message: "unsupported protocol version",
Data: data,
},
}
return jsonBody(t, respMsg), http.StatusOK
}
return jsonBody(t, &jsonrpc.Response{ID: r.ID, Result: mustMarshal(discoverResult)}), http.StatusOK
},
},
},
}

httpServer := httptest.NewServer(fake)
defer httpServer.Close()

transport := &StreamableClientTransport{Endpoint: httpServer.URL}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, &ClientSessionOptions{protocolVersion: unsupportedClientVersion})
if err != nil {
t.Fatalf("Connect: %v", err)
}
defer session.Close()

if got, want := discoverCalls.Load(), int32(2); got != want {
t.Errorf("discover call count = %d, want %d", got, want)
}
if got := session.InitializeResult().ProtocolVersion; got != protocolVersion20260630 {
t.Errorf("InitializeResult.ProtocolVersion = %q, want %q", got, protocolVersion20260630)
}
}
Loading