@@ -560,17 +560,22 @@ func (s *Server) disconnect(cc *ServerSession) {
560
560
}
561
561
562
562
type SessionOptions struct {
563
- SessionID string
563
+ // SessionID is the session's unique ID.
564
+ SessionID string
565
+ // SessionState is the current state of the session. The default is the initial
566
+ // state.
564
567
SessionState * SessionState
568
+ // SessionStore stores SessionStates. By default it is a MemorySessionStore.
565
569
SessionStore SessionStore
566
570
}
567
571
568
572
// Connect connects the MCP server over the given transport and starts handling
569
573
// messages.
574
+ // It returns a [ServerSession] for interacting with a [Client].
570
575
//
571
- // It returns a connection object that may be used to terminate the connection
572
- // (with [Connection.Close]), or await client termination (with
573
- // [Connection.Wait]) .
576
+ // [SessionOptions.SessionStore] should be nil only for single-session transports,
577
+ // like [StdioTransport]. Multi-session transports, like [StreamableServerTransport],
578
+ // must provide a [SessionStore] .
574
579
func (s * Server ) Connect (ctx context.Context , t Transport , opts * SessionOptions ) (* ServerSession , error ) {
575
580
if opts != nil && opts .SessionState == nil && opts .SessionStore != nil {
576
581
return nil , errors .New ("ServerSessionOptions has store but no state" )
@@ -579,11 +584,21 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *SessionOptions)
579
584
if err != nil {
580
585
return nil , err
581
586
}
587
+ var state * SessionState
582
588
if opts != nil {
583
- ss .opts = * opts
589
+ ss .sessionID = opts .SessionID
590
+ ss .store = opts .SessionStore
591
+ state = opts .SessionState
584
592
}
585
- if ss .opts .SessionState == nil {
586
- ss .opts .SessionState = & SessionState {}
593
+ if ss .store == nil {
594
+ ss .store = NewMemorySessionStore ()
595
+ }
596
+ if state == nil {
597
+ state = & SessionState {}
598
+ }
599
+ // TODO(jba): This store is redundant with the one in StreamableHTTPHandler.ServeHTTP; dedup.
600
+ if err := ss .storeState (ctx , state ); err != nil {
601
+ return nil , err
587
602
}
588
603
return ss , nil
589
604
}
@@ -592,20 +607,21 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar
592
607
if ss .server .opts .KeepAlive > 0 {
593
608
ss .startKeepalive (ss .server .opts .KeepAlive )
594
609
}
595
- ss .mu .Lock ()
596
- hasParams := ss .opts .SessionState .InitializeParams != nil
597
- wasInitialized := ss ._initialized
598
- if hasParams {
599
- ss ._initialized = true
610
+ // TODO(jba): optimistic locking
611
+ state , err := ss .loadState (ctx )
612
+ if err != nil {
613
+ return nil , err
600
614
}
601
- ss .mu .Unlock ()
602
-
603
- if ! hasParams {
615
+ if state .InitializeParams == nil {
604
616
return nil , fmt .Errorf ("%q before %q" , notificationInitialized , methodInitialize )
605
617
}
606
- if wasInitialized {
618
+ if state . Initialized {
607
619
return nil , fmt .Errorf ("duplicate %q received" , notificationInitialized )
608
620
}
621
+ state .Initialized = true
622
+ if err := ss .storeState (ctx , state ); err != nil {
623
+ return nil , err
624
+ }
609
625
return callNotificationHandler (ctx , ss .server .opts .InitializedHandler , ss , params )
610
626
}
611
627
@@ -634,19 +650,26 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot
634
650
type ServerSession struct {
635
651
server * Server
636
652
conn * jsonrpc2.Connection
637
- opts SessionOptions
638
653
mu sync.Mutex
639
654
logLevel LoggingLevel
640
- _initialized bool
641
655
keepaliveCancel context.CancelFunc
642
656
sessionID string
657
+ store SessionStore
643
658
}
644
659
645
660
func (ss * ServerSession ) setConn (c Connection ) {
646
661
}
647
662
648
663
func (ss * ServerSession ) ID () string {
649
- return ss .opts .SessionID
664
+ return ss .sessionID
665
+ }
666
+
667
+ func (ss * ServerSession ) loadState (ctx context.Context ) (* SessionState , error ) {
668
+ return ss .store .Load (ctx , ss .sessionID )
669
+ }
670
+
671
+ func (ss * ServerSession ) storeState (ctx context.Context , state * SessionState ) error {
672
+ return ss .store .Store (ctx , ss .sessionID , state )
650
673
}
651
674
652
675
// Ping pings the client.
@@ -669,16 +692,19 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag
669
692
// The message is not sent if the client has not called SetLevel, or if its level
670
693
// is below that of the last SetLevel.
671
694
func (ss * ServerSession ) Log (ctx context.Context , params * LoggingMessageParams ) error {
672
- ss .mu .Lock ()
673
- logLevel := ss .opts .SessionState .LogLevel
674
- ss .mu .Unlock ()
675
- if logLevel == "" {
695
+ // TODO: Loading the state on every log message can be expensive. Consider caching it briefly, perhaps for the
696
+ // duration of a request.
697
+ state , err := ss .loadState (ctx )
698
+ if err != nil {
699
+ return err
700
+ }
701
+ if state .LogLevel == "" {
676
702
// The spec is unclear, but seems to imply that no log messages are sent until the client
677
703
// sets the level.
678
704
// TODO(jba): read other SDKs, possibly file an issue.
679
705
return nil
680
706
}
681
- if compareLevels (params .Level , logLevel ) < 0 {
707
+ if compareLevels (params .Level , state . LogLevel ) < 0 {
682
708
return nil
683
709
}
684
710
return handleNotify (ctx , ss , notificationLoggingMessage , params )
@@ -759,37 +785,44 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn }
759
785
760
786
// handle invokes the method described by the given JSON RPC request.
761
787
func (ss * ServerSession ) handle (ctx context.Context , req * jsonrpc.Request ) (any , error ) {
762
- ss .mu .Lock ()
763
- initialized := ss ._initialized
764
- ss .mu .Unlock ()
788
+ state , err := ss .loadState (ctx )
789
+ if err != nil {
790
+ return nil , err
791
+ }
765
792
// From the spec:
766
793
// "The client SHOULD NOT send requests other than pings before the server
767
794
// has responded to the initialize request."
768
795
switch req .Method {
769
796
case methodInitialize , methodPing , notificationInitialized :
770
797
default :
771
- if ! initialized {
798
+ if ! state . Initialized {
772
799
return nil , fmt .Errorf ("method %q is invalid during session initialization" , req .Method )
773
800
}
774
801
}
775
802
// For the streamable transport, we need the request ID to correlate
776
803
// server->client calls and notifications to the incoming request from which
777
804
// they originated. See [idContextKey] for details.
778
805
ctx = context .WithValue (ctx , idContextKey {}, req .ID )
806
+ // TODO(jba): pass the state down so that it doesn't get loaded again.
779
807
return handleReceive (ctx , ss , req )
780
808
}
781
809
782
810
func (ss * ServerSession ) initialize (ctx context.Context , params * InitializeParams ) (* InitializeResult , error ) {
783
811
if params == nil {
784
812
return nil , fmt .Errorf ("%w: \" params\" must be be provided" , jsonrpc2 .ErrInvalidParams )
785
813
}
786
- ss .mu .Lock ()
787
- ss .opts .SessionState .InitializeParams = params
788
- ss .mu .Unlock ()
789
- if store := ss .opts .SessionStore ; store != nil {
790
- if err := store .Store (ctx , ss .opts .SessionID , ss .opts .SessionState ); err != nil {
791
- return nil , fmt .Errorf ("storing session state: %w" , err )
792
- }
814
+
815
+ // TODO(jba): optimistic locking
816
+ state , err := ss .loadState (ctx )
817
+ if err != nil {
818
+ return nil , err
819
+ }
820
+ if state .InitializeParams != nil {
821
+ return nil , fmt .Errorf ("session %s already initialized" , ss .sessionID )
822
+ }
823
+ state .InitializeParams = params
824
+ if err := ss .storeState (ctx , state ); err != nil {
825
+ return nil , fmt .Errorf ("storing session state: %w" , err )
793
826
}
794
827
795
828
// If we support the client's version, reply with it. Otherwise, reply with our
@@ -816,11 +849,14 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error
816
849
func (ss * ServerSession ) setLevel (ctx context.Context , params * SetLevelParams ) (* emptyResult , error ) {
817
850
ss .mu .Lock ()
818
851
defer ss .mu .Unlock ()
819
- ss .opts .SessionState .LogLevel = params .Level
820
- if store := ss .opts .SessionStore ; store != nil {
821
- if err := store .Store (ctx , ss .opts .SessionID , ss .opts .SessionState ); err != nil {
822
- return nil , err
823
- }
852
+
853
+ state , err := ss .loadState (ctx )
854
+ if err != nil {
855
+ return nil , err
856
+ }
857
+ state .LogLevel = params .Level
858
+ if err := ss .storeState (ctx , state ); err != nil {
859
+ return nil , err
824
860
}
825
861
return & emptyResult {}, nil
826
862
}
0 commit comments