Skip to content

Commit 567d395

Browse files
committed
Add more MCP stuff
1 parent 0c07739 commit 567d395

File tree

13 files changed

+1369
-48
lines changed

13 files changed

+1369
-48
lines changed

coderd/ai/ai.go

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ type LanguageModel struct {
1919
}
2020

2121
type StreamOptions struct {
22-
Model string
23-
Messages []aisdk.Message
24-
Thinking bool
25-
Tools []aisdk.Tool
22+
SystemPrompt string
23+
Model string
24+
Messages []aisdk.Message
25+
Thinking bool
26+
Tools []aisdk.Tool
2627
}
2728

2829
type StreamFunc func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error)
@@ -45,6 +46,12 @@ func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig)
4546
return nil, err
4647
}
4748
tools := aisdk.ToolsToOpenAI(options.Tools)
49+
if options.SystemPrompt != "" {
50+
openaiMessages = append([]openai.ChatCompletionMessageParamUnion{
51+
openai.SystemMessage(options.SystemPrompt),
52+
}, openaiMessages...)
53+
}
54+
4855
return aisdk.OpenAIToDataStream(client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
4956
Messages: openaiMessages,
5057
Model: options.Model,
@@ -70,6 +77,11 @@ func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig)
7077
if err != nil {
7178
return nil, err
7279
}
80+
if options.SystemPrompt != "" {
81+
systemMessage = []anthropic.TextBlockParam{
82+
*anthropic.NewTextBlock(options.SystemPrompt).OfRequestTextBlock,
83+
}
84+
}
7385
return aisdk.AnthropicToDataStream(client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
7486
Messages: anthropicMessages,
7587
Model: options.Model,
@@ -106,8 +118,18 @@ func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig)
106118
if err != nil {
107119
return nil, err
108120
}
121+
var systemInstruction *genai.Content
122+
if options.SystemPrompt != "" {
123+
systemInstruction = &genai.Content{
124+
Parts: []*genai.Part{
125+
genai.NewPartFromText(options.SystemPrompt),
126+
},
127+
Role: "model",
128+
}
129+
}
109130
return aisdk.GoogleToDataStream(client.Models.GenerateContentStream(ctx, options.Model, googleMessages, &genai.GenerateContentConfig{
110-
Tools: tools,
131+
SystemInstruction: systemInstruction,
132+
Tools: tools,
111133
})), nil
112134
}
113135
if config.Models == nil {

coderd/chat.go

Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package coderd
22

33
import (
44
"encoding/json"
5+
"io"
56
"net/http"
67
"time"
78

@@ -12,11 +13,9 @@ import (
1213
"github.com/coder/coder/v2/coderd/httpapi"
1314
"github.com/coder/coder/v2/coderd/httpmw"
1415
"github.com/coder/coder/v2/codersdk"
15-
codermcp "github.com/coder/coder/v2/mcp"
16+
"github.com/coder/coder/v2/codersdk/toolsdk"
1617
"github.com/google/uuid"
1718
"github.com/kylecarbs/aisdk-go"
18-
"github.com/mark3labs/mcp-go/mcp"
19-
"github.com/mark3labs/mcp-go/server"
2019
)
2120

2221
// postChats creates a new chat.
@@ -157,31 +156,17 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
157156
}
158157
messages = append(messages, req.Message)
159158

160-
toolMap := codermcp.AllTools()
161-
toolsByName := make(map[string]server.ToolHandlerFunc)
162159
client := codersdk.New(api.AccessURL)
163160
client.SetSessionToken(httpmw.APITokenFromRequest(r))
164-
toolDeps := codermcp.ToolDeps{
165-
Client: client,
166-
Logger: &api.Logger,
167-
}
168-
for _, tool := range toolMap {
169-
toolsByName[tool.Tool.Name] = tool.MakeHandler(toolDeps)
170-
}
171-
convertedTools := make([]aisdk.Tool, len(toolMap))
172-
for i, tool := range toolMap {
173-
schema := aisdk.Schema{
174-
Required: tool.Tool.InputSchema.Required,
175-
Properties: tool.Tool.InputSchema.Properties,
176-
}
177-
if tool.Tool.InputSchema.Required == nil {
178-
schema.Required = []string{}
179-
}
180-
convertedTools[i] = aisdk.Tool{
181-
Name: tool.Tool.Name,
182-
Description: tool.Tool.Description,
183-
Schema: schema,
161+
162+
tools := make([]aisdk.Tool, len(toolsdk.All))
163+
handlers := map[string]toolsdk.HandlerFunc[any]{}
164+
for i, tool := range toolsdk.All {
165+
if tool.Tool.Schema.Required == nil {
166+
tool.Tool.Schema.Required = []string{}
184167
}
168+
tools[i] = tool.Tool
169+
handlers[tool.Tool.Name] = tool.Handler
185170
}
186171

187172
provider, ok := api.LanguageModels[req.Model]
@@ -192,6 +177,43 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
192177
return
193178
}
194179

180+
// If it's the user's first message, generate a title for the chat.
181+
if len(messages) == 1 {
182+
var acc aisdk.DataStreamAccumulator
183+
stream, err := provider.StreamFunc(ctx, ai.StreamOptions{
184+
Model: req.Model,
185+
SystemPrompt: `- You will generate a short title based on the user's message.
186+
- It should be maximum of 40 characters.
187+
- Do not use quotes, colons, special characters, or emojis.`,
188+
Messages: messages,
189+
Tools: tools,
190+
})
191+
if err != nil {
192+
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
193+
Message: "Failed to create stream",
194+
Detail: err.Error(),
195+
})
196+
}
197+
stream = stream.WithAccumulator(&acc)
198+
err = stream.Pipe(io.Discard)
199+
if err != nil {
200+
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
201+
Message: "Failed to pipe stream",
202+
Detail: err.Error(),
203+
})
204+
}
205+
err = api.Database.UpdateChatByID(ctx, database.UpdateChatByIDParams{
206+
ID: chat.ID,
207+
Title: acc.Messages()[0].Content,
208+
})
209+
if err != nil {
210+
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
211+
Message: "Failed to update chat title",
212+
Detail: err.Error(),
213+
})
214+
}
215+
}
216+
195217
// Write headers for the data stream!
196218
aisdk.WriteDataStreamHeaders(w)
197219

@@ -224,7 +246,11 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
224246
stream, err := provider.StreamFunc(ctx, ai.StreamOptions{
225247
Model: req.Model,
226248
Messages: messages,
227-
Tools: convertedTools,
249+
Tools: tools,
250+
SystemPrompt: `You are a chat assistant for Coder. You will attempt to resolve the user's
251+
request to the maximum utilization of your tools.
252+
253+
Try your best to not ask the user for help - solve the task with your tools!`,
228254
})
229255
if err != nil {
230256
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
@@ -234,28 +260,17 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
234260
return
235261
}
236262
stream = stream.WithToolCalling(func(toolCall aisdk.ToolCall) any {
237-
tool, ok := toolsByName[toolCall.Name]
263+
tool, ok := handlers[toolCall.Name]
238264
if !ok {
239265
return nil
240266
}
241-
result, err := tool(ctx, mcp.CallToolRequest{
242-
Params: struct {
243-
Name string "json:\"name\""
244-
Arguments map[string]interface{} "json:\"arguments,omitempty\""
245-
Meta *struct {
246-
ProgressToken mcp.ProgressToken "json:\"progressToken,omitempty\""
247-
} "json:\"_meta,omitempty\""
248-
}{
249-
Name: toolCall.Name,
250-
Arguments: toolCall.Args,
251-
},
252-
})
267+
result, err := tool(toolsdk.WithClient(ctx, client), toolCall.Args)
253268
if err != nil {
254269
return map[string]any{
255270
"error": err.Error(),
256271
}
257272
}
258-
return result.Content
273+
return result
259274
}).WithAccumulator(&acc)
260275

261276
err = stream.Pipe(w)

coderd/database/dbauthz/dbauthz.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3993,7 +3993,10 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
39933993
}
39943994

39953995
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) error {
3996-
panic("not implemented")
3996+
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat.WithID(arg.ID)); err != nil {
3997+
return err
3998+
}
3999+
return q.db.UpdateChatByID(ctx, arg)
39974000
}
39984001

39994002
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy