diff --git a/client/client.go b/client/client.go index 60fe0cbf..63986328 100644 --- a/client/client.go +++ b/client/client.go @@ -16,11 +16,21 @@ import ( type Client struct { transport transport.Interface - initialized bool - notifications []func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - requestID atomic.Int64 - capabilities mcp.ServerCapabilities + initialized bool + notifications []func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + requestID atomic.Int64 + clientCapabilities mcp.ClientCapabilities + serverCapabilities mcp.ServerCapabilities +} + +type ClientOption func(*Client) + +// WithClientCapabilities sets the client capabilities for the client. +func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption { + return func(c *Client) { + c.clientCapabilities = capabilities + } } // NewClient creates a new MCP client with the given transport. @@ -31,10 +41,16 @@ type Client struct { // if err != nil { // log.Fatalf("Failed to create client: %v", err) // } -func NewClient(transport transport.Interface) *Client { - return &Client{ +func NewClient(transport transport.Interface, options ...ClientOption) *Client { + client := &Client{ transport: transport, } + + for _, opt := range options { + opt(client) + } + + return client } // Start initiates the connection to the server. @@ -115,7 +131,7 @@ func (c *Client) Initialize( params := struct { ProtocolVersion string `json:"protocolVersion"` ClientInfo mcp.Implementation `json:"clientInfo"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` + Capabilities mcp.ClientCapabilities `json:"serverCapabilities"` }{ ProtocolVersion: request.Params.ProtocolVersion, ClientInfo: request.Params.ClientInfo, @@ -132,8 +148,8 @@ func (c *Client) Initialize( return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - // Store capabilities - c.capabilities = result.Capabilities + // Store serverCapabilities + c.serverCapabilities = result.Capabilities // Send initialized notification notification := mcp.JSONRPCNotification{ @@ -406,3 +422,13 @@ func listByPage[T any]( func (c *Client) GetTransport() transport.Interface { return c.transport } + +// GetServerCapabilities returns the server capabilities. +func (c *Client) GetServerCapabilities() mcp.ServerCapabilities { + return c.serverCapabilities +} + +// GetClientCapabilities returns the client capabilities. +func (c *Client) GetClientCapabilities() mcp.ClientCapabilities { + return c.clientCapabilities +} diff --git a/client/inprocess.go b/client/inprocess.go new file mode 100644 index 00000000..5d8559de --- /dev/null +++ b/client/inprocess.go @@ -0,0 +1,12 @@ +package client + +import ( + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/server" +) + +// NewInProcessClient connect directly to a mcp server object in the same process +func NewInProcessClient(server *server.MCPServer) (*Client, error) { + inProcessTransport := transport.NewInProcessTransport(server) + return NewClient(inProcessTransport), nil +} diff --git a/client/inprocess_test.go b/client/inprocess_test.go new file mode 100644 index 00000000..de447602 --- /dev/null +++ b/client/inprocess_test.go @@ -0,0 +1,407 @@ +package client + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func TestInProcessMCPClient(t *testing.T) { + mcpServer := server.NewMCPServer( + "test-server", + "1.0.0", + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + ) + + // Add a test tool + mcpServer.AddTool(mcp.NewTool( + "test-tool", + mcp.WithDescription("Test tool"), + mcp.WithString("parameter-1", mcp.Description("A string tool parameter")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: "Test Tool Annotation Title", + ReadOnlyHint: true, + DestructiveHint: false, + IdempotentHint: true, + OpenWorldHint: false, + }), + ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string), + }, + }, + }, nil + }) + + mcpServer.AddResource( + mcp.Resource{ + URI: "resource://testresource", + Name: "My Resource", + }, + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "resource://testresource", + MIMEType: "text/plain", + Text: "test content", + }, + }, nil + }, + ) + + mcpServer.AddPrompt( + mcp.Prompt{ + Name: "test-prompt", + Description: "A test prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "arg1", + Description: "First argument", + }, + }, + }, + func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return &mcp.GetPromptResult{ + Messages: []mcp.PromptMessage{ + { + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Test prompt with arg1: " + request.Params.Arguments["arg1"], + }, + }, + }, + }, nil + }, + ) + + t.Run("Can initialize and make requests", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Start the client + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + result, err := client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + if result.ServerInfo.Name != "test-server" { + t.Errorf( + "Expected server name 'test-server', got '%s'", + result.ServerInfo.Name, + ) + } + + // Test Ping + if err := client.Ping(context.Background()); err != nil { + t.Errorf("Ping failed: %v", err) + } + + // Test ListTools + toolsRequest := mcp.ListToolsRequest{} + toolListResult, err := client.ListTools(context.Background(), toolsRequest) + if err != nil { + t.Errorf("ListTools failed: %v", err) + } + if toolListResult == nil || len((*toolListResult).Tools) == 0 { + t.Errorf("Expected one tool") + } + testToolAnnotations := (*toolListResult).Tools[0].Annotations + if testToolAnnotations.Title != "Test Tool Annotation Title" || + testToolAnnotations.ReadOnlyHint != true || + testToolAnnotations.DestructiveHint != false || + testToolAnnotations.IdempotentHint != true || + testToolAnnotations.OpenWorldHint != false { + t.Errorf("The annotations of the tools are invalid") + } + }) + + t.Run("Handles errors properly", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Try to make a request without initializing + toolsRequest := mcp.ListToolsRequest{} + _, err = client.ListTools(context.Background(), toolsRequest) + if err == nil { + t.Error("Expected error when making request before initialization") + } + }) + + t.Run("CallTool", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + request := mcp.CallToolRequest{} + request.Params.Name = "test-tool" + request.Params.Arguments = map[string]interface{}{ + "parameter-1": "value1", + } + + result, err := client.CallTool(context.Background(), request) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + + if len(result.Content) != 1 { + t.Errorf("Expected 1 content item, got %d", len(result.Content)) + } + }) + + t.Run("Ping", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + err = client.Ping(context.Background()) + if err != nil { + t.Errorf("Ping failed: %v", err) + } + }) + + t.Run("ListResources", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + request := mcp.ListResourcesRequest{} + result, err := client.ListResources(context.Background(), request) + if err != nil { + t.Errorf("ListResources failed: %v", err) + } + + if len(result.Resources) != 1 { + t.Errorf("Expected 1 resource, got %d", len(result.Resources)) + } + }) + + t.Run("ReadResource", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + request := mcp.ReadResourceRequest{} + request.Params.URI = "resource://testresource" + + result, err := client.ReadResource(context.Background(), request) + if err != nil { + t.Errorf("ReadResource failed: %v", err) + } + + if len(result.Contents) != 1 { + t.Errorf("Expected 1 content item, got %d", len(result.Contents)) + } + }) + + t.Run("ListPrompts", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + request := mcp.ListPromptsRequest{} + result, err := client.ListPrompts(context.Background(), request) + if err != nil { + t.Errorf("ListPrompts failed: %v", err) + } + + if len(result.Prompts) != 1 { + t.Errorf("Expected 1 prompt, got %d", len(result.Prompts)) + } + }) + + t.Run("GetPrompt", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + request := mcp.GetPromptRequest{} + request.Params.Name = "test-prompt" + + result, err := client.GetPrompt(context.Background(), request) + if err != nil { + t.Errorf("GetPrompt failed: %v", err) + } + + if len(result.Messages) != 1 { + t.Errorf("Expected 1 message, got %d", len(result.Messages)) + } + }) + + t.Run("ListTools", func(t *testing.T) { + client, err := NewInProcessClient(mcpServer) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(context.Background(), initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + request := mcp.ListToolsRequest{} + result, err := client.ListTools(context.Background(), request) + if err != nil { + t.Errorf("ListTools failed: %v", err) + } + + if len(result.Tools) != 1 { + t.Errorf("Expected 1 tool, got %d", len(result.Tools)) + } + }) +} diff --git a/client/transport/inprocess.go b/client/transport/inprocess.go new file mode 100644 index 00000000..90fc2fae --- /dev/null +++ b/client/transport/inprocess.go @@ -0,0 +1,70 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +type InProcessTransport struct { + server *server.MCPServer + + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex +} + +func NewInProcessTransport(server *server.MCPServer) *InProcessTransport { + return &InProcessTransport{ + server: server, + } +} + +func (c *InProcessTransport) Start(ctx context.Context) error { + return nil +} + +func (c *InProcessTransport) SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + respMessage := c.server.HandleMessage(ctx, requestBytes) + respByte, err := json.Marshal(respMessage) + if err != nil { + return nil, fmt.Errorf("failed to marshal response message: %w", err) + } + rpcResp := JSONRPCResponse{} + err = json.Unmarshal(respByte, &rpcResp) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal response message: %w", err) + } + + return &rpcResp, nil +} + +func (c *InProcessTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + notificationBytes, err := json.Marshal(notification) + if err != nil { + return fmt.Errorf("failed to marshal notification: %w", err) + } + notificationBytes = append(notificationBytes, '\n') + c.server.HandleMessage(ctx, notificationBytes) + + return nil +} + +func (c *InProcessTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.onNotification = handler +} + +func (*InProcessTransport) Close() error { + return nil +} diff --git a/mcp/types.go b/mcp/types.go index c940a460..2b2c6f00 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -12,40 +12,54 @@ type MCPMethod string const ( // Initiates connection and negotiates protocol capabilities. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization + // https://modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization MethodInitialize MCPMethod = "initialize" // Verifies connection liveness between client and server. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ + // https://modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ MethodPing MCPMethod = "ping" // Lists all available server resources. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ MethodResourcesList MCPMethod = "resources/list" // Provides URI templates for constructing resource URIs. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ MethodResourcesTemplatesList MCPMethod = "resources/templates/list" // Retrieves content of a specific resource by URI. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ MethodResourcesRead MCPMethod = "resources/read" // Lists all available prompt templates. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ MethodPromptsList MCPMethod = "prompts/list" // Retrieves a specific prompt template with filled parameters. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ MethodPromptsGet MCPMethod = "prompts/get" // Lists all available executable tools. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ MethodToolsList MCPMethod = "tools/list" // Invokes a specific tool with provided parameters. - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/ + // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ MethodToolsCall MCPMethod = "tools/call" + + // Notifies when the list of available resources changes. + // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification + MethodNotificationResourcesListChanged = "notifications/resources/list_changed" + + MethodNotificationResourceUpdated = "notifications/resources/updated" + + // Notifies when the list of available prompt templates changes. + // https://modelcontextprotocol.io/specification/2025-03-26/server/prompts#list-changed-notification + MethodNotificationPromptsListChanged = "notifications/prompts/list_changed" + + // Notifies when the list of available tools changes. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/list_changed/ + MethodNotificationToolsListChanged = "notifications/tools/list_changed" ) type URITemplate struct { @@ -226,6 +240,11 @@ const ( INTERNAL_ERROR = -32603 ) +// MCP error codes +const ( + RESOURCE_NOT_FOUND = -32002 +) + /* Empty result */ // EmptyResult represents a response that indicates success but carries no data. diff --git a/server/hooks.go b/server/hooks.go index ce976a6c..30519d4c 100644 --- a/server/hooks.go +++ b/server/hooks.go @@ -11,6 +11,9 @@ import ( // OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) +// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered. +type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession) + // BeforeAnyHookFunc is a function that is called after the request is // parsed but before the method is called. type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any) @@ -33,7 +36,7 @@ type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, m // } // // // Use errors.As to get specific error types -// var parseErr = &UnparseableMessageError{} +// var parseErr = &UnparsableMessageError{} // if errors.As(err, &parseErr) { // // Access specific methods/fields of the error type // log.Printf("Failed to parse message for method %s: %v", @@ -83,6 +86,7 @@ type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallTool type Hooks struct { OnRegisterSession []OnRegisterSessionHookFunc + OnUnregisterSession []OnUnregisterSessionHookFunc OnBeforeAny []BeforeAnyHookFunc OnSuccess []OnSuccessHookFunc OnError []OnErrorHookFunc @@ -135,9 +139,9 @@ func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { // } // // // For parsing errors -// var parseErr = &UnparseableMessageError{} +// var parseErr = &UnparsableMessageError{} // if errors.As(err, &parseErr) { -// // Handle unparseable message errors +// // Handle unparsable message errors // fmt.Printf("Failed to parse %s request: %v\n", // parseErr.GetMethod(), parseErr.Unwrap()) // errChan <- parseErr @@ -191,7 +195,7 @@ func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, mes // // Common error types include: // - ErrUnsupported: When a capability is not enabled -// - UnparseableMessageError: When request parsing fails +// - UnparsableMessageError: When request parsing fails // - ErrResourceNotFound: When a resource is not found // - ErrPromptNotFound: When a prompt is not found // - ErrToolNotFound: When a tool is not found @@ -216,6 +220,19 @@ func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { hook(ctx, session) } } + +func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) { + c.OnUnregisterSession = append(c.OnUnregisterSession, hook) +} + +func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnUnregisterSession { + hook(ctx, session) + } +} func (c *Hooks) AddBeforeInitialize(hook OnBeforeInitializeFunc) { c.OnBeforeInitialize = append(c.OnBeforeInitialize, hook) } diff --git a/server/internal/gen/hooks.go.tmpl b/server/internal/gen/hooks.go.tmpl index 4a8dcf1b..9451589d 100644 --- a/server/internal/gen/hooks.go.tmpl +++ b/server/internal/gen/hooks.go.tmpl @@ -14,6 +14,8 @@ import ( // OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) +// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered. +type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession) // BeforeAnyHookFunc is a function that is called after the request is // parsed but before the method is called. @@ -36,7 +38,7 @@ type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, m // } // // // Use errors.As to get specific error types -// var parseErr = &UnparseableMessageError{} +// var parseErr = &UnparsableMessageError{} // if errors.As(err, &parseErr) { // // Access specific methods/fields of the error type // log.Printf("Failed to parse message for method %s: %v", @@ -63,7 +65,8 @@ type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{. {{end}} type Hooks struct { - OnRegisterSession []OnRegisterSessionHookFunc + OnRegisterSession []OnRegisterSessionHookFunc + OnUnregisterSession []OnUnregisterSessionHookFunc OnBeforeAny []BeforeAnyHookFunc OnSuccess []OnSuccessHookFunc OnError []OnErrorHookFunc @@ -101,9 +104,9 @@ func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { // } // // // For parsing errors -// var parseErr = &UnparseableMessageError{} +// var parseErr = &UnparsableMessageError{} // if errors.As(err, &parseErr) { -// // Handle unparseable message errors +// // Handle unparsable message errors // fmt.Printf("Failed to parse %s request: %v\n", // parseErr.GetMethod(), parseErr.Unwrap()) // errChan <- parseErr @@ -157,7 +160,7 @@ func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, mes // // Common error types include: // - ErrUnsupported: When a capability is not enabled -// - UnparseableMessageError: When request parsing fails +// - UnparsableMessageError: When request parsing fails // - ErrResourceNotFound: When a resource is not found // - ErrPromptNotFound: When a prompt is not found // - ErrToolNotFound: When a tool is not found @@ -183,6 +186,19 @@ func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { } } +func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) { + c.OnUnregisterSession = append(c.OnUnregisterSession, hook) +} + +func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnUnregisterSession { + hook(ctx, session) + } +} + {{- range .}} func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) { c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook) diff --git a/server/internal/gen/request_handler.go.tmpl b/server/internal/gen/request_handler.go.tmpl index 5c69f5fa..e78f2799 100644 --- a/server/internal/gen/request_handler.go.tmpl +++ b/server/internal/gen/request_handler.go.tmpl @@ -78,7 +78,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.before{{.HookName}}(ctx, baseMessage.ID, &request) diff --git a/server/request_handler.go b/server/request_handler.go index 55d2d19e..0d0e68e8 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -70,7 +70,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) @@ -89,7 +89,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforePing(ctx, baseMessage.ID, &request) @@ -114,7 +114,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListResources(ctx, baseMessage.ID, &request) @@ -139,7 +139,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) @@ -164,7 +164,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) @@ -189,7 +189,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) @@ -214,7 +214,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) @@ -239,7 +239,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeListTools(ctx, baseMessage.ID, &request) @@ -264,7 +264,7 @@ func (s *MCPServer) HandleMessage( err = &requestError{ id: baseMessage.ID, code: mcp.INVALID_REQUEST, - err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) diff --git a/server/server.go b/server/server.go index 8ebd40bd..430f8d53 100644 --- a/server/server.go +++ b/server/server.go @@ -73,27 +73,27 @@ func ClientSessionFromContext(ctx context.Context) ClientSession { return nil } -// UnparseableMessageError is attached to the RequestError when json.Unmarshal +// UnparsableMessageError is attached to the RequestError when json.Unmarshal // fails on the request. -type UnparseableMessageError struct { +type UnparsableMessageError struct { message json.RawMessage method mcp.MCPMethod err error } -func (e *UnparseableMessageError) Error() string { - return fmt.Sprintf("unparseable %s request: %s", e.method, e.err) +func (e *UnparsableMessageError) Error() string { + return fmt.Sprintf("unparsable %s request: %s", e.method, e.err) } -func (e *UnparseableMessageError) Unwrap() error { +func (e *UnparsableMessageError) Unwrap() error { return e.err } -func (e *UnparseableMessageError) GetMessage() json.RawMessage { +func (e *UnparsableMessageError) GetMessage() json.RawMessage { return e.message } -func (e *UnparseableMessageError) GetMethod() mcp.MCPMethod { +func (e *UnparsableMessageError) GetMethod() mcp.MCPMethod { return e.method } @@ -206,13 +206,15 @@ func (s *MCPServer) RegisterSession( // UnregisterSession removes from storage session that is shut down. func (s *MCPServer) UnregisterSession( + ctx context.Context, sessionID string, ) { - s.sessions.Delete(sessionID) + session, _ := s.sessions.LoadAndDelete(sessionID) + s.hooks.UnregisterSession(ctx, session.(ClientSession)) } -// sendNotificationToAllClients sends a notification to all the currently active clients. -func (s *MCPServer) sendNotificationToAllClients( +// SendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) SendNotificationToAllClients( method string, params map[string]any, ) { @@ -417,6 +419,12 @@ func (s *MCPServer) AddResource( resource: resource, handler: handler, } + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification + if s.capabilities.resources.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } } // RemoveResource removes a resource from the server @@ -427,7 +435,7 @@ func (s *MCPServer) RemoveResource(uri string) { // Send notification to all initialized sessions if listChanged capability is enabled if s.capabilities.resources != nil && s.capabilities.resources.listChanged { - s.sendNotificationToAllClients("resources/list_changed", nil) + s.SendNotificationToAllClients("resources/list_changed", nil) } } @@ -448,6 +456,12 @@ func (s *MCPServer) AddResourceTemplate( template: template, handler: handler, } + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification + if s.capabilities.resources.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } } // AddPrompt registers a new prompt handler with the given name @@ -462,6 +476,12 @@ func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { defer s.promptsMu.Unlock() s.prompts[prompt.Name] = prompt s.promptHandlers[prompt.Name] = handler + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.prompts.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + } } // AddTool registers a new tool and its handler @@ -483,14 +503,17 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { } s.toolsMu.Unlock() - // Send notification to all initialized sessions - s.sendNotificationToAllClients("notifications/tools/list_changed", nil) + // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.tools.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + } } // SetTools replaces all existing tools with the provided list func (s *MCPServer) SetTools(tools ...ServerTool) { s.toolsMu.Lock() - s.tools = make(map[string]ServerTool) + s.tools = make(map[string]ServerTool, len(tools)) s.toolsMu.Unlock() s.AddTools(tools...) } @@ -503,8 +526,11 @@ func (s *MCPServer) DeleteTools(names ...string) { } s.toolsMu.Unlock() - // Send notification to all initialized sessions - s.sendNotificationToAllClients("notifications/tools/list_changed", nil) + // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.tools.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + } } // AddNotificationHandler registers a new handler for incoming notifications @@ -712,7 +738,7 @@ func (s *MCPServer) handleReadResource( matched = true matchedVars := template.URITemplate.Match(request.Params.URI) // Convert matched variables to a map - request.Params.Arguments = make(map[string]interface{}) + request.Params.Arguments = make(map[string]interface{}, len(matchedVars)) for name, value := range matchedVars { request.Params.Arguments[name] = value.V } @@ -735,7 +761,7 @@ func (s *MCPServer) handleReadResource( return nil, &requestError{ id: id, - code: mcp.INVALID_PARAMS, + code: mcp.RESOURCE_NOT_FOUND, err: fmt.Errorf("handler not found for resource URI '%s': %w", request.Params.URI, ErrResourceNotFound), } } diff --git a/server/server_test.go b/server/server_test.go index e55008f1..641a3c88 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -199,7 +199,7 @@ func TestMCPServer_Tools(t *testing.T) { }, expectedNotifications: 1, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { - assert.Equal(t, "notifications/tools/list_changed", notifications[0].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) assert.Equal(t, "test-tool-1", tools[0].Name) @@ -241,7 +241,7 @@ func TestMCPServer_Tools(t *testing.T) { expectedNotifications: 5, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { for _, notification := range notifications { - assert.Equal(t, "notifications/tools/list_changed", notification.Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notification.Method) } tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) @@ -269,8 +269,8 @@ func TestMCPServer_Tools(t *testing.T) { }, expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { - assert.Equal(t, "notifications/tools/list_changed", notifications[0].Method) - assert.Equal(t, "notifications/tools/list_changed", notifications[1].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[1].Method) tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) assert.Equal(t, "test-tool-1", tools[0].Name) @@ -294,9 +294,9 @@ func TestMCPServer_Tools(t *testing.T) { expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { // One for SetTools - assert.Equal(t, "notifications/tools/list_changed", notifications[0].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) // One for DeleteTools - assert.Equal(t, "notifications/tools/list_changed", notifications[1].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[1].Method) // Expect a successful response with an empty list of tools resp, ok := toolsList.(mcp.JSONRPCResponse) @@ -312,7 +312,7 @@ func TestMCPServer_Tools(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - server := NewMCPServer("test-server", "1.0.0") + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) _ = server.HandleMessage(ctx, []byte(`{ "jsonrpc": "2.0", "id": 1, @@ -340,7 +340,6 @@ func TestMCPServer_Tools(t *testing.T) { }`)) tt.validate(t, notifications, toolsList.(mcp.JSONRPCMessage)) }) - } } @@ -573,6 +572,75 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) { } } +func TestMCPServer_SendNotificationToAllClients(t *testing.T) { + + contextPrepare := func(ctx context.Context, srv *MCPServer) context.Context { + // Create 5 active sessions + for i := 0; i < 5; i++ { + err := srv.RegisterSession(ctx, &fakeSession{ + sessionID: fmt.Sprintf("test%d", i), + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + }) + require.NoError(t, err) + } + return ctx + } + + validate := func(t *testing.T, ctx context.Context, srv *MCPServer) { + // Send 10 notifications to all sessions + for i := 0; i < 10; i++ { + srv.SendNotificationToAllClients("method", map[string]any{ + "count": i, + }) + } + + // Verify each session received all 10 notifications + srv.sessions.Range(func(k, v any) bool { + session := v.(ClientSession) + fakeSess := session.(*fakeSession) + notificationCount := 0 + + // Read all notifications from the channel + for notificationCount < 10 { + select { + case notification := <-fakeSess.notificationChannel: + // Verify notification method + assert.Equal(t, "method", notification.Method) + // Verify count parameter + count, ok := notification.Params.AdditionalFields["count"] + assert.True(t, ok, "count parameter not found") + assert.Equal(t, notificationCount, count.(int), "count should match notification count") + notificationCount++ + case <-time.After(100 * time.Millisecond): + t.Errorf("timeout waiting for notification %d for session %s", notificationCount, session.SessionID()) + return false + } + } + + // Verify no more notifications + select { + case notification := <-fakeSess.notificationChannel: + t.Errorf("unexpected notification received: %v", notification) + default: + // Channel empty as expected + } + return true + }) + } + + t.Run("all sessions", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + ctx := contextPrepare(context.Background(), server) + _ = server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }`)) + validate(t, ctx, server) + }) +} + func TestMCPServer_PromptHandling(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true), @@ -725,11 +793,11 @@ func TestMCPServer_HandleInvalidMessages(t *testing.T) { message: `{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": "invalid"}`, expectedErr: mcp.INVALID_REQUEST, validateErr: func(t *testing.T, err error) { - var unparseableErr = &UnparseableMessageError{} - var ok = errors.As(err, &unparseableErr) - assert.True(t, ok, "Error should be UnparseableMessageError") - assert.Equal(t, mcp.MethodInitialize, unparseableErr.GetMethod()) - assert.Equal(t, json.RawMessage(`{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": "invalid"}`), unparseableErr.GetMessage()) + unparsableErr := &UnparsableMessageError{} + ok := errors.As(err, &unparsableErr) + assert.True(t, ok, "Error should be UnparsableMessageError") + assert.Equal(t, mcp.MethodInitialize, unparsableErr.GetMethod()) + assert.Equal(t, json.RawMessage(`{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": "invalid"}`), unparsableErr.GetMessage()) }, }, { @@ -861,7 +929,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) { "uri": "undefined-resource" } }`, - expectedErr: mcp.INVALID_PARAMS, + expectedErr: mcp.RESOURCE_NOT_FOUND, validateCallbacks: func(t *testing.T, err error, beforeResults beforeResult) { assert.Equal(t, mcp.MethodResourcesRead, beforeResults.method) assert.True(t, errors.Is(err, ErrResourceNotFound)) @@ -1125,7 +1193,6 @@ func TestMCPServer_ResourceTemplates(t *testing.T) { assert.Equal(t, "test://something/test-resource/a/b/c", resultContent.URI) assert.Equal(t, "text/plain", resultContent.MIMEType) assert.Equal(t, "test content: something", resultContent.Text) - }) } @@ -1353,6 +1420,76 @@ func TestMCPServer_WithHooks(t *testing.T) { assert.IsType(t, afterPingData[0].res, onSuccessData[0].res, "OnSuccess result should be same type as AfterPing result") } +func TestMCPServer_SessionHooks(t *testing.T) { + var ( + registerCalled bool + unregisterCalled bool + + registeredContext context.Context + unregisteredContext context.Context + + registeredSession ClientSession + unregisteredSession ClientSession + ) + + hooks := &Hooks{} + hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) { + registerCalled = true + registeredContext = ctx + registeredSession = session + }) + hooks.AddOnUnregisterSession(func(ctx context.Context, session ClientSession) { + unregisterCalled = true + unregisteredContext = ctx + unregisteredSession = session + }) + + server := NewMCPServer( + "test-server", + "1.0.0", + WithHooks(hooks), + ) + + testSession := &fakeSession{ + sessionID: "test-session-id", + notificationChannel: make(chan mcp.JSONRPCNotification, 5), + initialized: false, + } + + ctx := context.WithoutCancel(context.Background()) + err := server.RegisterSession(ctx, testSession) + require.NoError(t, err) + + assert.True(t, registerCalled, "Register session hook was not called") + assert.Equal(t, testSession.SessionID(), registeredSession.SessionID(), + "Register hook received wrong session") + + server.UnregisterSession(ctx, testSession.SessionID()) + + assert.True(t, unregisterCalled, "Unregister session hook was not called") + assert.Equal(t, testSession.SessionID(), unregisteredSession.SessionID(), + "Unregister hook received wrong session") + + assert.Equal(t, ctx, unregisteredContext, "Unregister hook received wrong context") + assert.Equal(t, ctx, registeredContext, "Register hook received wrong context") +} + +func TestMCPServer_SessionHooks_NilHooks(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + testSession := &fakeSession{ + sessionID: "test-session-id", + notificationChannel: make(chan mcp.JSONRPCNotification, 5), + initialized: false, + } + + ctx := context.WithoutCancel(context.Background()) + err := server.RegisterSession(ctx, testSession) + require.NoError(t, err) + + server.UnregisterSession(ctx, testSession.SessionID()) +} + func TestMCPServer_WithRecover(t *testing.T) { panicToolHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { panic("test panic") diff --git a/server/sse.go b/server/sse.go index b6ae2144..9a419150 100644 --- a/server/sse.go +++ b/server/sse.go @@ -179,10 +179,7 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { // NewTestServer creates a test server for testing purposes func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { - sseServer := NewSSEServer(server) - for _, opt := range opts { - opt(sseServer) - } + sseServer := NewSSEServer(server, opts...) testServer := httptest.NewServer(sseServer) sseServer.baseURL = testServer.URL @@ -259,7 +256,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError) return } - defer s.server.UnregisterSession(sessionID) + defer s.server.UnregisterSession(r.Context(), sessionID) // Start notification handler for this session go func() { @@ -324,6 +321,8 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { case <-r.Context().Done(): close(session.done) return + case <-session.done: + return } } } @@ -438,6 +437,7 @@ func (s *SSEServer) SendEventToSession( return fmt.Errorf("event queue full") } } + func (s *SSEServer) GetUrlPath(input string) (string, error) { parse, err := url.Parse(input) if err != nil { @@ -449,6 +449,7 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) { func (s *SSEServer) CompleteSseEndpoint() string { return s.baseURL + s.basePath + s.sseEndpoint } + func (s *SSEServer) CompleteSsePath() string { path, err := s.GetUrlPath(s.CompleteSseEndpoint()) if err != nil { @@ -460,6 +461,7 @@ func (s *SSEServer) CompleteSsePath() string { func (s *SSEServer) CompleteMessageEndpoint() string { return s.baseURL + s.basePath + s.messageEndpoint } + func (s *SSEServer) CompleteMessagePath() string { path, err := s.GetUrlPath(s.CompleteMessageEndpoint()) if err != nil { diff --git a/server/stdio.go b/server/stdio.go index 43d9570c..0de9f347 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -204,7 +204,7 @@ func (s *StdioServer) Listen( if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil { return fmt.Errorf("register session: %w", err) } - defer s.server.UnregisterSession(stdioSessionInstance.SessionID()) + defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) ctx = s.server.WithContext(ctx, &stdioSessionInstance) // Add in any custom context. pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy