@@ -2,8 +2,140 @@ package ai
22
33import (
44 "context"
5+ "fmt"
56
7+ "github.com/anthropics/anthropic-sdk-go"
8+ anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
9+ "github.com/coder/coder/v2/codersdk"
610 "github.com/kylecarbs/aisdk-go"
11+ "github.com/openai/openai-go"
12+ openaioption "github.com/openai/openai-go/option"
13+ "google.golang.org/genai"
714)
815
9- type Provider func (ctx context.Context , messages []aisdk.Message ) (aisdk.DataStream , error )
16+ type LanguageModel struct {
17+ codersdk.LanguageModel
18+ StreamFunc StreamFunc
19+ }
20+
21+ type StreamOptions struct {
22+ Model string
23+ Messages []aisdk.Message
24+ Thinking bool
25+ Tools []aisdk.Tool
26+ }
27+
28+ type StreamFunc func (ctx context.Context , options StreamOptions ) (aisdk.DataStream , error )
29+
30+ // LanguageModels is a map of language model ID to language model.
31+ type LanguageModels map [string ]LanguageModel
32+
33+ func ModelsFromConfig (ctx context.Context , configs []codersdk.AIProviderConfig ) (LanguageModels , error ) {
34+ models := make (LanguageModels )
35+
36+ for _ , config := range configs {
37+ var streamFunc StreamFunc
38+
39+ switch config .Type {
40+ case "openai" :
41+ client := openai .NewClient (openaioption .WithAPIKey (config .APIKey ))
42+ streamFunc = func (ctx context.Context , options StreamOptions ) (aisdk.DataStream , error ) {
43+ openaiMessages , err := aisdk .MessagesToOpenAI (options .Messages )
44+ if err != nil {
45+ return nil , err
46+ }
47+ tools := aisdk .ToolsToOpenAI (options .Tools )
48+ return aisdk .OpenAIToDataStream (client .Chat .Completions .NewStreaming (ctx , openai.ChatCompletionNewParams {
49+ Messages : openaiMessages ,
50+ Model : options .Model ,
51+ Tools : tools ,
52+ MaxTokens : openai .Int (8192 ),
53+ })), nil
54+ }
55+ if config .Models == nil {
56+ models , err := client .Models .List (ctx )
57+ if err != nil {
58+ return nil , err
59+ }
60+ config .Models = make ([]string , len (models .Data ))
61+ for i , model := range models .Data {
62+ config .Models [i ] = model .ID
63+ }
64+ }
65+ break
66+ case "anthropic" :
67+ client := anthropic .NewClient (anthropicoption .WithAPIKey (config .APIKey ))
68+ streamFunc = func (ctx context.Context , options StreamOptions ) (aisdk.DataStream , error ) {
69+ anthropicMessages , systemMessage , err := aisdk .MessagesToAnthropic (options .Messages )
70+ if err != nil {
71+ return nil , err
72+ }
73+ return aisdk .AnthropicToDataStream (client .Messages .NewStreaming (ctx , anthropic.MessageNewParams {
74+ Messages : anthropicMessages ,
75+ Model : options .Model ,
76+ System : systemMessage ,
77+ Tools : aisdk .ToolsToAnthropic (options .Tools ),
78+ MaxTokens : 8192 ,
79+ })), nil
80+ }
81+ if config .Models == nil {
82+ models , err := client .Models .List (ctx , anthropic.ModelListParams {})
83+ if err != nil {
84+ return nil , err
85+ }
86+ config .Models = make ([]string , len (models .Data ))
87+ for i , model := range models .Data {
88+ config .Models [i ] = model .ID
89+ }
90+ }
91+ break
92+ case "google" :
93+ client , err := genai .NewClient (ctx , & genai.ClientConfig {
94+ APIKey : config .APIKey ,
95+ Backend : genai .BackendGeminiAPI ,
96+ })
97+ if err != nil {
98+ return nil , err
99+ }
100+ streamFunc = func (ctx context.Context , options StreamOptions ) (aisdk.DataStream , error ) {
101+ googleMessages , err := aisdk .MessagesToGoogle (options .Messages )
102+ if err != nil {
103+ return nil , err
104+ }
105+ tools , err := aisdk .ToolsToGoogle (options .Tools )
106+ if err != nil {
107+ return nil , err
108+ }
109+ return aisdk .GoogleToDataStream (client .Models .GenerateContentStream (ctx , options .Model , googleMessages , & genai.GenerateContentConfig {
110+ Tools : tools ,
111+ })), nil
112+ }
113+ if config .Models == nil {
114+ models , err := client .Models .List (ctx , & genai.ListModelsConfig {})
115+ if err != nil {
116+ return nil , err
117+ }
118+ config .Models = make ([]string , len (models .Items ))
119+ for i , model := range models .Items {
120+ config .Models [i ] = model .Name
121+ }
122+ }
123+ break
124+ default :
125+ return nil , fmt .Errorf ("unsupported model type: %s" , config .Type )
126+ }
127+
128+ for _ , model := range config .Models {
129+ models [model ] = LanguageModel {
130+ LanguageModel : codersdk.LanguageModel {
131+ ID : model ,
132+ DisplayName : model ,
133+ Provider : config .Type ,
134+ },
135+ StreamFunc : streamFunc ,
136+ }
137+ }
138+ }
139+
140+ return models , nil
141+ }
0 commit comments