
Commit 544259b

johnstcn and kylecarbs authored
feat: add database tables and API routes for agentic chat feature (#17570)
Backend portion of the experimental `AgenticChat` feature:

- Adds database tables for chats and chat messages
- Adds functionality to stream messages from LLM providers using `kylecarbs/aisdk-go`
- Adds API routes with relevant functionality (list, create, and update chats; insert a chat message)
- Adds experiment `codersdk.AgenticChat`

Co-authored-by: Kyle Carberry <kyle@carberry.com>
1 parent 64b9bc1 commit 544259b

45 files changed (+4264 −16 lines)

cli/server.go

Lines changed: 89 additions & 0 deletions
```diff
@@ -61,6 +61,7 @@ import (
 	"github.com/coder/serpent"
 	"github.com/coder/wgtunnel/tunnelsdk"
 
+	"github.com/coder/coder/v2/coderd/ai"
 	"github.com/coder/coder/v2/coderd/entitlements"
 	"github.com/coder/coder/v2/coderd/notifications/reports"
 	"github.com/coder/coder/v2/coderd/runtimeconfig"
@@ -610,6 +611,22 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
 		)
 	}
 
+	aiProviders, err := ReadAIProvidersFromEnv(os.Environ())
+	if err != nil {
+		return xerrors.Errorf("read ai providers from env: %w", err)
+	}
+	vals.AI.Value.Providers = append(vals.AI.Value.Providers, aiProviders...)
+	for _, provider := range aiProviders {
+		logger.Debug(
+			ctx, "loaded ai provider",
+			slog.F("type", provider.Type),
+		)
+	}
+	languageModels, err := ai.ModelsFromConfig(ctx, vals.AI.Value.Providers)
+	if err != nil {
+		return xerrors.Errorf("create language models: %w", err)
+	}
+
 	realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
 	if err != nil {
 		return xerrors.Errorf("parse real ip config: %w", err)
@@ -640,6 +657,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
 		CacheDir:             cacheDir,
 		GoogleTokenValidator: googleTokenValidator,
 		ExternalAuthConfigs:  externalAuthConfigs,
+		LanguageModels:       languageModels,
 		RealIPConfig:         realIPConfig,
 		SSHKeygenAlgorithm:   sshKeygenAlgorithm,
 		TracerProvider:       tracerProvider,
@@ -2621,6 +2639,77 @@ func redirectHTTPToHTTPSDeprecation(ctx context.Context, logger slog.Logger, inv
 	}
 }
 
+func ReadAIProvidersFromEnv(environ []string) ([]codersdk.AIProviderConfig, error) {
+	// The index numbers must be in-order.
+	sort.Strings(environ)
+
+	var providers []codersdk.AIProviderConfig
+	for _, v := range serpent.ParseEnviron(environ, "CODER_AI_PROVIDER_") {
+		tokens := strings.SplitN(v.Name, "_", 2)
+		if len(tokens) != 2 {
+			return nil, xerrors.Errorf("invalid env var: %s", v.Name)
+		}
+
+		providerNum, err := strconv.Atoi(tokens[0])
+		if err != nil {
+			return nil, xerrors.Errorf("parse number: %s", v.Name)
+		}
+
+		var provider codersdk.AIProviderConfig
+		switch {
+		case len(providers) < providerNum:
+			return nil, xerrors.Errorf(
+				"provider num %v skipped: %s",
+				len(providers),
+				v.Name,
+			)
+		case len(providers) == providerNum:
+			// At the next provider.
+			providers = append(providers, provider)
+		case len(providers) == providerNum+1:
+			// At the current provider.
+			provider = providers[providerNum]
+		}
+
+		key := tokens[1]
+		switch key {
+		case "TYPE":
+			provider.Type = v.Value
+		case "API_KEY":
+			provider.APIKey = v.Value
+		case "BASE_URL":
+			provider.BaseURL = v.Value
+		case "MODELS":
+			provider.Models = strings.Split(v.Value, ",")
+		}
+		providers[providerNum] = provider
+	}
+	for _, envVar := range environ {
+		tokens := strings.SplitN(envVar, "=", 2)
+		if len(tokens) != 2 {
+			continue
+		}
+		switch tokens[0] {
+		case "OPENAI_API_KEY":
+			providers = append(providers, codersdk.AIProviderConfig{
+				Type:   "openai",
+				APIKey: tokens[1],
+			})
+		case "ANTHROPIC_API_KEY":
+			providers = append(providers, codersdk.AIProviderConfig{
+				Type:   "anthropic",
+				APIKey: tokens[1],
+			})
+		case "GOOGLE_API_KEY":
+			providers = append(providers, codersdk.AIProviderConfig{
+				Type:   "google",
+				APIKey: tokens[1],
+			})
+		}
+	}
+	return providers, nil
+}
+
 // ReadExternalAuthProvidersFromEnv is provided for compatibility purposes with
 // the viper CLI.
 func ReadExternalAuthProvidersFromEnv(environ []string) ([]codersdk.ExternalAuthConfig, error) {
```
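For reference, a minimal sketch (not part of this commit) of how `ReadAIProvidersFromEnv` groups the indexed `CODER_AI_PROVIDER_<n>_*` variables and also recognizes provider-native keys such as `ANTHROPIC_API_KEY`. The environment values below are placeholders; the call assumes the `cli` package from this commit is importable as `github.com/coder/coder/v2/cli`.

```go
package main

import (
	"fmt"
	"log"

	"github.com/coder/coder/v2/cli"
)

func main() {
	// Indexed variables describe one provider each; indices must be contiguous,
	// starting at 0. Provider-native API key variables are picked up afterwards.
	environ := []string{
		"CODER_AI_PROVIDER_0_TYPE=openai",
		"CODER_AI_PROVIDER_0_API_KEY=sk-example", // placeholder
		"CODER_AI_PROVIDER_0_MODELS=gpt-4o,gpt-4o-mini",
		"ANTHROPIC_API_KEY=sk-ant-example", // placeholder
	}

	providers, err := cli.ReadAIProvidersFromEnv(environ)
	if err != nil {
		log.Fatal(err)
	}
	for _, p := range providers {
		// Prints the openai entry with its two models first, then the
		// anthropic entry derived from ANTHROPIC_API_KEY.
		fmt.Printf("type=%s models=%v\n", p.Type, p.Models)
	}
}
```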

cli/testdata/server-config.yaml.golden

Lines changed: 3 additions & 0 deletions
```diff
@@ -519,6 +519,9 @@ client:
 # Support links to display in the top right drop down menu.
 # (default: <unset>, type: struct[[]codersdk.LinkConfig])
 supportLinks: []
+# Configure AI providers.
+# (default: <unset>, type: struct[codersdk.AIConfig])
+ai: {}
 # External Authentication providers.
 # (default: <unset>, type: struct[[]codersdk.ExternalAuthConfig])
 externalAuthProviders: []
```

coderd/ai/ai.go

Lines changed: 167 additions & 0 deletions
```diff
@@ -0,0 +1,167 @@
+package ai
+
+import (
+	"context"
+
+	"github.com/anthropics/anthropic-sdk-go"
+	anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
+	"github.com/kylecarbs/aisdk-go"
+	"github.com/openai/openai-go"
+	openaioption "github.com/openai/openai-go/option"
+	"golang.org/x/xerrors"
+	"google.golang.org/genai"
+
+	"github.com/coder/coder/v2/codersdk"
+)
+
+type LanguageModel struct {
+	codersdk.LanguageModel
+	StreamFunc StreamFunc
+}
+
+type StreamOptions struct {
+	SystemPrompt string
+	Model        string
+	Messages     []aisdk.Message
+	Thinking     bool
+	Tools        []aisdk.Tool
+}
+
+type StreamFunc func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error)
+
+// LanguageModels is a map of language model ID to language model.
+type LanguageModels map[string]LanguageModel
+
+func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig) (LanguageModels, error) {
+	models := make(LanguageModels)
+
+	for _, config := range configs {
+		var streamFunc StreamFunc
+
+		switch config.Type {
+		case "openai":
+			opts := []openaioption.RequestOption{
+				openaioption.WithAPIKey(config.APIKey),
+			}
+			if config.BaseURL != "" {
+				opts = append(opts, openaioption.WithBaseURL(config.BaseURL))
+			}
+			client := openai.NewClient(opts...)
+			streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) {
+				openaiMessages, err := aisdk.MessagesToOpenAI(options.Messages)
+				if err != nil {
+					return nil, err
+				}
+				tools := aisdk.ToolsToOpenAI(options.Tools)
+				if options.SystemPrompt != "" {
+					openaiMessages = append([]openai.ChatCompletionMessageParamUnion{
+						openai.SystemMessage(options.SystemPrompt),
+					}, openaiMessages...)
+				}
+
+				return aisdk.OpenAIToDataStream(client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
+					Messages:  openaiMessages,
+					Model:     options.Model,
+					Tools:     tools,
+					MaxTokens: openai.Int(8192),
+				})), nil
+			}
+			if config.Models == nil {
+				models, err := client.Models.List(ctx)
+				if err != nil {
+					return nil, err
+				}
+				config.Models = make([]string, len(models.Data))
+				for i, model := range models.Data {
+					config.Models[i] = model.ID
+				}
+			}
+		case "anthropic":
+			client := anthropic.NewClient(anthropicoption.WithAPIKey(config.APIKey))
+			streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) {
+				anthropicMessages, systemMessage, err := aisdk.MessagesToAnthropic(options.Messages)
+				if err != nil {
+					return nil, err
+				}
+				if options.SystemPrompt != "" {
+					systemMessage = []anthropic.TextBlockParam{
+						*anthropic.NewTextBlock(options.SystemPrompt).OfRequestTextBlock,
+					}
+				}
+				return aisdk.AnthropicToDataStream(client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
+					Messages:  anthropicMessages,
+					Model:     options.Model,
+					System:    systemMessage,
+					Tools:     aisdk.ToolsToAnthropic(options.Tools),
+					MaxTokens: 8192,
+				})), nil
+			}
+			if config.Models == nil {
+				models, err := client.Models.List(ctx, anthropic.ModelListParams{})
+				if err != nil {
+					return nil, err
+				}
+				config.Models = make([]string, len(models.Data))
+				for i, model := range models.Data {
+					config.Models[i] = model.ID
+				}
+			}
+		case "google":
+			client, err := genai.NewClient(ctx, &genai.ClientConfig{
+				APIKey:  config.APIKey,
+				Backend: genai.BackendGeminiAPI,
+			})
+			if err != nil {
+				return nil, err
+			}
+			streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) {
+				googleMessages, err := aisdk.MessagesToGoogle(options.Messages)
+				if err != nil {
+					return nil, err
+				}
+				tools, err := aisdk.ToolsToGoogle(options.Tools)
+				if err != nil {
+					return nil, err
+				}
+				var systemInstruction *genai.Content
+				if options.SystemPrompt != "" {
+					systemInstruction = &genai.Content{
+						Parts: []*genai.Part{
+							genai.NewPartFromText(options.SystemPrompt),
+						},
+						Role: "model",
+					}
+				}
+				return aisdk.GoogleToDataStream(client.Models.GenerateContentStream(ctx, options.Model, googleMessages, &genai.GenerateContentConfig{
+					SystemInstruction: systemInstruction,
+					Tools:             tools,
+				})), nil
+			}
+			if config.Models == nil {
+				models, err := client.Models.List(ctx, &genai.ListModelsConfig{})
+				if err != nil {
+					return nil, err
+				}
+				config.Models = make([]string, len(models.Items))
+				for i, model := range models.Items {
+					config.Models[i] = model.Name
+				}
+			}
+		default:
+			return nil, xerrors.Errorf("unsupported model type: %s", config.Type)
+		}
+
+		for _, model := range config.Models {
+			models[model] = LanguageModel{
+				LanguageModel: codersdk.LanguageModel{
+					ID:          model,
+					DisplayName: model,
+					Provider:    config.Type,
+				},
+				StreamFunc: streamFunc,
+			}
+		}
+	}
+
+	return models, nil
+}
```
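To show how the types added in `coderd/ai/ai.go` compose, here is a minimal sketch that is not part of this commit: it builds a model registry with `ai.ModelsFromConfig`, looks up a model by ID, and invokes its `StreamFunc`. The provider config, model ID, API key, and prompt strings are placeholders, and the chat API routes that would persist messages and forward the resulting `aisdk.DataStream` to the client are not shown in this excerpt.

```go
package main

import (
	"context"
	"log"

	"github.com/kylecarbs/aisdk-go"

	"github.com/coder/coder/v2/coderd/ai"
	"github.com/coder/coder/v2/codersdk"
)

func main() {
	ctx := context.Background()

	// Placeholder provider config; in the server this comes from
	// ReadAIProvidersFromEnv and the deployment values.
	models, err := ai.ModelsFromConfig(ctx, []codersdk.AIProviderConfig{
		{Type: "openai", APIKey: "sk-example", Models: []string{"gpt-4o"}},
	})
	if err != nil {
		log.Fatal(err)
	}

	// Look up a configured model by its ID.
	model, ok := models["gpt-4o"]
	if !ok {
		log.Fatal("model not configured")
	}

	// Chat history would normally be loaded from the new chat message tables.
	var history []aisdk.Message

	// StreamFunc returns an aisdk.DataStream that the chat API routes
	// (not shown in this excerpt) stream back to the client.
	stream, err := model.StreamFunc(ctx, ai.StreamOptions{
		Model:        model.ID,
		SystemPrompt: "You are a helpful assistant.",
		Messages:     history,
	})
	if err != nil {
		log.Fatal(err)
	}
	_ = stream
}
```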
