Skip to content

Commit 5ebbb94

Browse files
committed
load/save state for each method; add test
1 parent 5b44363 commit 5ebbb94

File tree

6 files changed

+188
-57
lines changed

6 files changed

+188
-57
lines changed

mcp/logging.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,12 @@ func NewLoggingHandler(ss *ServerSession, opts *LoggingHandlerOptions) *LoggingH
112112
return lh
113113
}
114114

115-
// Enabled implements [slog.Handler.Enabled] by comparing level to the [ServerSession]'s level.
115+
// Enabled implements [slog.Handler.Enabled].
116116
func (h *LoggingHandler) Enabled(ctx context.Context, level slog.Level) bool {
117-
// This is also checked in ServerSession.LoggingMessage, so checking it here
118-
// is just an optimization that skips building the JSON.
119-
h.ss.mu.Lock()
120-
mcpLevel := h.ss.opts.SessionState.LogLevel
121-
h.ss.mu.Unlock()
122-
return level >= mcpLevelToSlog(mcpLevel)
117+
// This is already checked in ServerSession.LoggingMessage. Checking it here
118+
// would be an optimization that skips building the JSON, but it would
119+
// end up loading the SessionState twice, so don't do it.
120+
return true
123121
}
124122

125123
// WithAttrs implements [slog.Handler.WithAttrs].

mcp/server.go

Lines changed: 76 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -560,17 +560,22 @@ func (s *Server) disconnect(cc *ServerSession) {
560560
}
561561

562562
type SessionOptions struct {
563-
SessionID string
563+
// SessionID is the session's unique ID.
564+
SessionID string
565+
// SessionState is the current state of the session. The default is the initial
566+
// state.
564567
SessionState *SessionState
568+
// SessionStore stores SessionStates. By default it is a MemorySessionStore.
565569
SessionStore SessionStore
566570
}
567571

568572
// Connect connects the MCP server over the given transport and starts handling
569573
// messages.
574+
// It returns a [ServerSession] for interacting with a [Client].
570575
//
571-
// It returns a connection object that may be used to terminate the connection
572-
// (with [Connection.Close]), or await client termination (with
573-
// [Connection.Wait]).
576+
// [SessionOptions.SessionStore] should be nil only for single-session transports,
577+
// like [StdioTransport]. Multi-session transports, like [StreamableServerTransport],
578+
// must provide a [SessionStore].
574579
func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions) (*ServerSession, error) {
575580
if opts != nil && opts.SessionState == nil && opts.SessionStore != nil {
576581
return nil, errors.New("ServerSessionOptions has store but no state")
@@ -579,11 +584,21 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions)
579584
if err != nil {
580585
return nil, err
581586
}
587+
var state *SessionState
582588
if opts != nil {
583-
ss.opts = *opts
589+
ss.sessionID = opts.SessionID
590+
ss.store = opts.SessionStore
591+
state = opts.SessionState
584592
}
585-
if ss.opts.SessionState == nil {
586-
ss.opts.SessionState = &SessionState{}
593+
if ss.store == nil {
594+
ss.store = NewMemorySessionStore()
595+
}
596+
if state == nil {
597+
state = &SessionState{}
598+
}
599+
// TODO(jba): This store is redundant with the one in StreamableHTTPHandler.ServeHTTP; dedup.
600+
if err := ss.storeState(ctx, state); err != nil {
601+
return nil, err
587602
}
588603
return ss, nil
589604
}
@@ -592,20 +607,21 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar
592607
if ss.server.opts.KeepAlive > 0 {
593608
ss.startKeepalive(ss.server.opts.KeepAlive)
594609
}
595-
ss.mu.Lock()
596-
hasParams := ss.opts.SessionState.InitializeParams != nil
597-
wasInitialized := ss._initialized
598-
if hasParams {
599-
ss._initialized = true
610+
// TODO(jba): optimistic locking
611+
state, err := ss.loadState(ctx)
612+
if err != nil {
613+
return nil, err
600614
}
601-
ss.mu.Unlock()
602-
603-
if !hasParams {
615+
if state.InitializeParams == nil {
604616
return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize)
605617
}
606-
if wasInitialized {
618+
if state.Initialized {
607619
return nil, fmt.Errorf("duplicate %q received", notificationInitialized)
608620
}
621+
state.Initialized = true
622+
if err := ss.storeState(ctx, state); err != nil {
623+
return nil, err
624+
}
609625
return callNotificationHandler(ctx, ss.server.opts.InitializedHandler, ss, params)
610626
}
611627

@@ -634,19 +650,26 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot
634650
type ServerSession struct {
635651
server *Server
636652
conn *jsonrpc2.Connection
637-
opts SessionOptions
638653
mu sync.Mutex
639654
logLevel LoggingLevel
640-
_initialized bool
641655
keepaliveCancel context.CancelFunc
642656
sessionID string
657+
store SessionStore
643658
}
644659

645660
func (ss *ServerSession) setConn(c Connection) {
646661
}
647662

648663
func (ss *ServerSession) ID() string {
649-
return ss.opts.SessionID
664+
return ss.sessionID
665+
}
666+
667+
func (ss *ServerSession) loadState(ctx context.Context) (*SessionState, error) {
668+
return ss.store.Load(ctx, ss.sessionID)
669+
}
670+
671+
func (ss *ServerSession) storeState(ctx context.Context, state *SessionState) error {
672+
return ss.store.Store(ctx, ss.sessionID, state)
650673
}
651674

652675
// Ping pings the client.
@@ -669,16 +692,19 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag
669692
// The message is not sent if the client has not called SetLevel, or if its level
670693
// is below that of the last SetLevel.
671694
func (ss *ServerSession) Log(ctx context.Context, params *LoggingMessageParams) error {
672-
ss.mu.Lock()
673-
logLevel := ss.opts.SessionState.LogLevel
674-
ss.mu.Unlock()
675-
if logLevel == "" {
695+
// TODO: Loading the state on every log message can be expensive. Consider caching it briefly, perhaps for the
696+
// duration of a request.
697+
state, err := ss.loadState(ctx)
698+
if err != nil {
699+
return err
700+
}
701+
if state.LogLevel == "" {
676702
// The spec is unclear, but seems to imply that no log messages are sent until the client
677703
// sets the level.
678704
// TODO(jba): read other SDKs, possibly file an issue.
679705
return nil
680706
}
681-
if compareLevels(params.Level, logLevel) < 0 {
707+
if compareLevels(params.Level, state.LogLevel) < 0 {
682708
return nil
683709
}
684710
return handleNotify(ctx, ss, notificationLoggingMessage, params)
@@ -759,37 +785,44 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn }
759785

760786
// handle invokes the method described by the given JSON RPC request.
761787
func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
762-
ss.mu.Lock()
763-
initialized := ss._initialized
764-
ss.mu.Unlock()
788+
state, err := ss.loadState(ctx)
789+
if err != nil {
790+
return nil, err
791+
}
765792
// From the spec:
766793
// "The client SHOULD NOT send requests other than pings before the server
767794
// has responded to the initialize request."
768795
switch req.Method {
769796
case methodInitialize, methodPing, notificationInitialized:
770797
default:
771-
if !initialized {
798+
if !state.Initialized {
772799
return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method)
773800
}
774801
}
775802
// For the streamable transport, we need the request ID to correlate
776803
// server->client calls and notifications to the incoming request from which
777804
// they originated. See [idContextKey] for details.
778805
ctx = context.WithValue(ctx, idContextKey{}, req.ID)
806+
// TODO(jba): pass the state down so that it doesn't get loaded again.
779807
return handleReceive(ctx, ss, req)
780808
}
781809

782810
func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) {
783811
if params == nil {
784812
return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)
785813
}
786-
ss.mu.Lock()
787-
ss.opts.SessionState.InitializeParams = params
788-
ss.mu.Unlock()
789-
if store := ss.opts.SessionStore; store != nil {
790-
if err := store.Store(ctx, ss.opts.SessionID, ss.opts.SessionState); err != nil {
791-
return nil, fmt.Errorf("storing session state: %w", err)
792-
}
814+
815+
// TODO(jba): optimistic locking
816+
state, err := ss.loadState(ctx)
817+
if err != nil {
818+
return nil, err
819+
}
820+
if state.InitializeParams != nil {
821+
return nil, fmt.Errorf("session %s already initialized", ss.sessionID)
822+
}
823+
state.InitializeParams = params
824+
if err := ss.storeState(ctx, state); err != nil {
825+
return nil, fmt.Errorf("storing session state: %w", err)
793826
}
794827

795828
// If we support the client's version, reply with it. Otherwise, reply with our
@@ -816,11 +849,14 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error
816849
func (ss *ServerSession) setLevel(ctx context.Context, params *SetLevelParams) (*emptyResult, error) {
817850
ss.mu.Lock()
818851
defer ss.mu.Unlock()
819-
ss.opts.SessionState.LogLevel = params.Level
820-
if store := ss.opts.SessionStore; store != nil {
821-
if err := store.Store(ctx, ss.opts.SessionID, ss.opts.SessionState); err != nil {
822-
return nil, err
823-
}
852+
853+
state, err := ss.loadState(ctx)
854+
if err != nil {
855+
return nil, err
856+
}
857+
state.LogLevel = params.Level
858+
if err := ss.storeState(ctx, state); err != nil {
859+
return nil, err
824860
}
825861
return &emptyResult{}, nil
826862
}

mcp/session.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,32 @@ package mcp
66

77
import (
88
"context"
9-
"io/fs"
9+
"errors"
10+
"fmt"
1011
"sync"
1112
)
1213

1314
// SessionState is the state of a session.
1415
type SessionState struct {
1516
// InitializeParams are the parameters from the initialize request.
1617
InitializeParams *InitializeParams `json:"initializeParams"`
18+
// Initialized reports whether the session received an "initialized" notification
19+
// from the client.
20+
Initialized bool
1721

1822
// LogLevel is the logging level for the session.
1923
LogLevel LoggingLevel `json:"logLevel"`
2024

2125
// TODO: resource subscriptions
2226
}
2327

28+
// ErrNotSession indicates that a session is not in a SessionStore.
29+
var ErrNoSession = errors.New("no such session")
30+
2431
// SessionStore is an interface for storing and retrieving session state.
2532
type SessionStore interface {
2633
// Load retrieves the session state for the given session ID.
27-
// If there is none, it returns nil, fs.ErrNotExist.
34+
// If there is none, it returns nil and an error wrapping ErrNoSession.
2835
Load(ctx context.Context, sessionID string) (*SessionState, error)
2936
// Store saves the session state for the given session ID.
3037
Store(ctx context.Context, sessionID string, state *SessionState) error
@@ -52,7 +59,7 @@ func (s *MemorySessionStore) Load(ctx context.Context, sessionID string) (*Sessi
5259
defer s.mu.Unlock()
5360
state, ok := s.store[sessionID]
5461
if !ok {
55-
return nil, fs.ErrNotExist
62+
return nil, fmt.Errorf("session ID %q: %w", sessionID, ErrNoSession)
5663
}
5764
return state, nil
5865
}

mcp/session_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ package mcp
66

77
import (
88
"context"
9-
"io/fs"
9+
"errors"
1010
"testing"
1111
)
1212

@@ -39,8 +39,8 @@ func TestMemorySessionStore(t *testing.T) {
3939
}
4040

4141
deletedState, err := store.Load(ctx, sessionID)
42-
if err != fs.ErrNotExist {
43-
t.Fatalf("Load() after Delete(): got %v, want fs.ErrNotExist", err)
42+
if !errors.Is(err, ErrNoSession) {
43+
t.Fatalf("Load() after Delete(): got %v, want ErrNoSession", err)
4444
}
4545
if deletedState != nil {
4646
t.Error("Load() after Delete() returned non-nil state")

mcp/streamable.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
163163
return
164164
}
165165
} else {
166+
// New session: store an empty state.
166167
state = &SessionState{}
167-
sessionID = randText()
168+
sessionID = newSessionID()
168169
if err := h.opts.SessionStore.Store(req.Context(), sessionID, state); err != nil {
169170
http.Error(w, fmt.Sprintf("SessionStore.Store, new session: %v", err), http.StatusInternalServerError)
170171
return
@@ -1233,3 +1234,6 @@ func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time
12331234

12341235
return backoffDuration + jitter
12351236
}
1237+
1238+
// For testing.
1239+
var newSessionID = randText

0 commit comments

Comments
 (0)