Skip to content

Commit 0c07739

Browse files
committed
And we have chat!
1 parent 9ac4643 commit 0c07739

27 files changed

+2902
-305
lines changed

cli/server.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ import (
6161
"github.com/coder/serpent"
6262
"github.com/coder/wgtunnel/tunnelsdk"
6363

64+
"github.com/coder/coder/v2/coderd/ai"
6465
"github.com/coder/coder/v2/coderd/entitlements"
6566
"github.com/coder/coder/v2/coderd/notifications/reports"
6667
"github.com/coder/coder/v2/coderd/runtimeconfig"
@@ -610,6 +611,22 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
610611
)
611612
}
612613

614+
aiProviders, err := ReadAIProvidersFromEnv(os.Environ())
615+
if err != nil {
616+
return xerrors.Errorf("read ai providers from env: %w", err)
617+
}
618+
vals.AI.Value.Providers = append(vals.AI.Value.Providers, aiProviders...)
619+
for _, provider := range aiProviders {
620+
logger.Debug(
621+
ctx, "loaded ai provider",
622+
slog.F("type", provider.Type),
623+
)
624+
}
625+
languageModels, err := ai.ModelsFromConfig(ctx, vals.AI.Value.Providers)
626+
if err != nil {
627+
return xerrors.Errorf("create language models: %w", err)
628+
}
629+
613630
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
614631
if err != nil {
615632
return xerrors.Errorf("parse real ip config: %w", err)
@@ -640,6 +657,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
640657
CacheDir: cacheDir,
641658
GoogleTokenValidator: googleTokenValidator,
642659
ExternalAuthConfigs: externalAuthConfigs,
660+
LanguageModels: languageModels,
643661
RealIPConfig: realIPConfig,
644662
SSHKeygenAlgorithm: sshKeygenAlgorithm,
645663
TracerProvider: tracerProvider,
@@ -2655,6 +2673,29 @@ func ReadAIProvidersFromEnv(environ []string) ([]codersdk.AIProviderConfig, erro
26552673
}
26562674
providers[providerNum] = provider
26572675
}
2676+
for _, envVar := range environ {
2677+
tokens := strings.SplitN(envVar, "=", 2)
2678+
if len(tokens) != 2 {
2679+
continue
2680+
}
2681+
switch tokens[0] {
2682+
case "OPENAI_API_KEY":
2683+
providers = append(providers, codersdk.AIProviderConfig{
2684+
Type: "openai",
2685+
APIKey: tokens[1],
2686+
})
2687+
case "ANTHROPIC_API_KEY":
2688+
providers = append(providers, codersdk.AIProviderConfig{
2689+
Type: "anthropic",
2690+
APIKey: tokens[1],
2691+
})
2692+
case "GOOGLE_API_KEY":
2693+
providers = append(providers, codersdk.AIProviderConfig{
2694+
Type: "google",
2695+
APIKey: tokens[1],
2696+
})
2697+
}
2698+
}
26582699
return providers, nil
26592700
}
26602701

coderd/ai/ai.go

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,140 @@ package ai
22

33
import (
44
"context"
5+
"fmt"
56

7+
"github.com/anthropics/anthropic-sdk-go"
8+
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
9+
"github.com/coder/coder/v2/codersdk"
610
"github.com/kylecarbs/aisdk-go"
11+
"github.com/openai/openai-go"
12+
openaioption "github.com/openai/openai-go/option"
13+
"google.golang.org/genai"
714
)
815

9-
type Provider func(ctx context.Context, messages []aisdk.Message) (aisdk.DataStream, error)
16+
type LanguageModel struct {
17+
codersdk.LanguageModel
18+
StreamFunc StreamFunc
19+
}
20+
21+
type StreamOptions struct {
22+
Model string
23+
Messages []aisdk.Message
24+
Thinking bool
25+
Tools []aisdk.Tool
26+
}
27+
28+
type StreamFunc func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error)
29+
30+
// LanguageModels is a map of language model ID to language model.
31+
type LanguageModels map[string]LanguageModel
32+
33+
func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig) (LanguageModels, error) {
34+
models := make(LanguageModels)
35+
36+
for _, config := range configs {
37+
var streamFunc StreamFunc
38+
39+
switch config.Type {
40+
case "openai":
41+
client := openai.NewClient(openaioption.WithAPIKey(config.APIKey))
42+
streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) {
43+
openaiMessages, err := aisdk.MessagesToOpenAI(options.Messages)
44+
if err != nil {
45+
return nil, err
46+
}
47+
tools := aisdk.ToolsToOpenAI(options.Tools)
48+
return aisdk.OpenAIToDataStream(client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
49+
Messages: openaiMessages,
50+
Model: options.Model,
51+
Tools: tools,
52+
MaxTokens: openai.Int(8192),
53+
})), nil
54+
}
55+
if config.Models == nil {
56+
models, err := client.Models.List(ctx)
57+
if err != nil {
58+
return nil, err
59+
}
60+
config.Models = make([]string, len(models.Data))
61+
for i, model := range models.Data {
62+
config.Models[i] = model.ID
63+
}
64+
}
65+
break
66+
case "anthropic":
67+
client := anthropic.NewClient(anthropicoption.WithAPIKey(config.APIKey))
68+
streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) {
69+
anthropicMessages, systemMessage, err := aisdk.MessagesToAnthropic(options.Messages)
70+
if err != nil {
71+
return nil, err
72+
}
73+
return aisdk.AnthropicToDataStream(client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
74+
Messages: anthropicMessages,
75+
Model: options.Model,
76+
System: systemMessage,
77+
Tools: aisdk.ToolsToAnthropic(options.Tools),
78+
MaxTokens: 8192,
79+
})), nil
80+
}
81+
if config.Models == nil {
82+
models, err := client.Models.List(ctx, anthropic.ModelListParams{})
83+
if err != nil {
84+
return nil, err
85+
}
86+
config.Models = make([]string, len(models.Data))
87+
for i, model := range models.Data {
88+
config.Models[i] = model.ID
89+
}
90+
}
91+
break
92+
case "google":
93+
client, err := genai.NewClient(ctx, &genai.ClientConfig{
94+
APIKey: config.APIKey,
95+
Backend: genai.BackendGeminiAPI,
96+
})
97+
if err != nil {
98+
return nil, err
99+
}
100+
streamFunc = func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error) {
101+
googleMessages, err := aisdk.MessagesToGoogle(options.Messages)
102+
if err != nil {
103+
return nil, err
104+
}
105+
tools, err := aisdk.ToolsToGoogle(options.Tools)
106+
if err != nil {
107+
return nil, err
108+
}
109+
return aisdk.GoogleToDataStream(client.Models.GenerateContentStream(ctx, options.Model, googleMessages, &genai.GenerateContentConfig{
110+
Tools: tools,
111+
})), nil
112+
}
113+
if config.Models == nil {
114+
models, err := client.Models.List(ctx, &genai.ListModelsConfig{})
115+
if err != nil {
116+
return nil, err
117+
}
118+
config.Models = make([]string, len(models.Items))
119+
for i, model := range models.Items {
120+
config.Models[i] = model.Name
121+
}
122+
}
123+
break
124+
default:
125+
return nil, fmt.Errorf("unsupported model type: %s", config.Type)
126+
}
127+
128+
for _, model := range config.Models {
129+
models[model] = LanguageModel{
130+
LanguageModel: codersdk.LanguageModel{
131+
ID: model,
132+
DisplayName: model,
133+
Provider: config.Type,
134+
},
135+
StreamFunc: streamFunc,
136+
}
137+
}
138+
}
139+
140+
return models, nil
141+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy