diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 97b4f6647f05..ebbc17856430 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,8 +9,8 @@ updates: - package-ecosystem: "nuget" directory: "dotnet/" schedule: - interval: "weekly" - day: "monday" + interval: "cron" + cronjob: "0 8 * * 4,0" # Every Thursday(4) and Sunday(0) at 8:00 UTC ignore: # For all System.* and Microsoft.Extensions/Bcl.* packages, ignore all major version updates - dependency-name: "System.*" @@ -24,20 +24,6 @@ updates: - ".NET" - "dependencies" - # Maintain dependencies for nuget - - package-ecosystem: "nuget" - directory: "samples/dotnet" - schedule: - interval: "weekly" - day: "monday" - - # Maintain dependencies for npm - - package-ecosystem: "npm" - directory: "samples/apps" - schedule: - interval: "weekly" - day: "monday" - # Maintain dependencies for pip - package-ecosystem: "pip" directory: "python/" @@ -55,4 +41,4 @@ updates: directory: "/" schedule: interval: "weekly" - day: "monday" + day: "tuesday" diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index af629d01a3a8..c60c6aedd669 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -5,26 +5,25 @@ true - - - - - - - - - - - - + + + + + + + + + + + + - - + - + @@ -33,7 +32,7 @@ - + @@ -48,15 +47,7 @@ - - - - - - - - @@ -64,52 +55,51 @@ - - - + + - - - + + + - + - - + + - - - - - - + + + + + + - + - - + - + + + + - - - + - + @@ -134,62 +124,62 @@ - + - + - + - + - - + + - + - + - - - - - - + + + + + + - + - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive @@ -199,23 +189,23 @@ all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + diff --git a/dotnet/SK-dotnet.slnx b/dotnet/SK-dotnet.slnx index ee828c730ad8..b662baf25562 100644 --- a/dotnet/SK-dotnet.slnx +++ b/dotnet/SK-dotnet.slnx @@ -38,6 +38,7 @@ + @@ -49,6 +50,10 @@ + + + + @@ -87,6 +92,7 @@ + diff --git a/dotnet/docs/EXPERIMENTS.md b/dotnet/docs/EXPERIMENTS.md index 56b5e073a0f0..e51bfd54b04a 100644 --- a/dotnet/docs/EXPERIMENTS.md +++ b/dotnet/docs/EXPERIMENTS.md @@ -76,13 +76,6 @@ You can use the following diagnostic IDs to ignore warnings or errors for a part | SKEXP0060 | Handlebars planner | | SKEXP0060 | OpenAI Stepwise planner | | | | | | | | | -| SKEXP0070 | Ollama AI connector | | | | | | -| SKEXP0070 | Gemini AI connector | | | | | | -| SKEXP0070 | Mistral AI connector | | | | | | -| SKEXP0070 | ONNX AI connector | | | | | | -| SKEXP0070 | Hugging Face AI connector | | | | | | -| SKEXP0070 | Amazon AI connector | | | | | | -| | | | | | | | | SKEXP0080 | Process Framework | | SKEXP0081 | Process Framework - Foundry Process | | | | | | | | diff --git a/dotnet/nuget/nuget-package.props b/dotnet/nuget/nuget-package.props index b28ab2ff0693..f5aac39af5c1 100644 --- a/dotnet/nuget/nuget-package.props +++ b/dotnet/nuget/nuget-package.props @@ -1,7 +1,7 @@ - 1.57.0 + 1.60.0 $(VersionPrefix)-$(VersionSuffix) $(VersionPrefix) @@ -9,7 +9,7 @@ true - 1.56.0 + 1.59.0 $(NoWarn);CP0003 diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index 18c9411c171e..97460d3e7c79 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -8,7 +8,7 @@ false true - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101,SKEXP0110,OPENAI001,CA1724,IDE1006,IDE0009,MEVD9000 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0101,SKEXP0110,OPENAI001,CA1724,IDE1006,IDE0009,MEVD9000 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/Concepts/Resources/Plugins/CopilotAgentPlugins/RetrievalPlugin/retrieval-apiplugin.json b/dotnet/samples/Concepts/Resources/Plugins/CopilotAgentPlugins/RetrievalPlugin/retrieval-apiplugin.json new file mode 100644 index 000000000000..adcd81020dda --- /dev/null +++ b/dotnet/samples/Concepts/Resources/Plugins/CopilotAgentPlugins/RetrievalPlugin/retrieval-apiplugin.json @@ -0,0 +1,36 @@ +{ + "$schema": "https://developer.microsoft.com/json-schemas/copilot/plugin/v2.1/schema.json", + "schema_version": "v2.1", + "name_for_human": "OData Service for namespace microsoft.graph", + "description_for_human": "This OData service is located at https://graph.microsoft.com/beta", + "description_for_model": "This OData service is located at https://graph.microsoft.com/beta", + "contact_email": "publisher-email@example.com", + "namespace": "Retrieval", + "capabilities": { + "conversation_starters": [ + { + "text": "Invoke action retrieval" + } + ] + }, + "functions": [ + { + "name": "copilot_retrieval", + "description": "Invoke action retrieval" + } + ], + "runtimes": [ + { + "type": "OpenApi", + "auth": { + "type": "None" + }, + "spec": { + "url": "retrieval-openapi.yml" + }, + "run_for_functions": [ + "copilot_retrieval" + ] + } + ] +} \ No newline at end of file diff --git a/dotnet/samples/Concepts/Resources/Plugins/CopilotAgentPlugins/RetrievalPlugin/retrieval-openapi.yml b/dotnet/samples/Concepts/Resources/Plugins/CopilotAgentPlugins/RetrievalPlugin/retrieval-openapi.yml new file mode 100644 index 000000000000..1ed7e8e7953a --- /dev/null +++ b/dotnet/samples/Concepts/Resources/Plugins/CopilotAgentPlugins/RetrievalPlugin/retrieval-openapi.yml @@ -0,0 +1,162 @@ +openapi: 3.0.4 +info: + title: OData Service for namespace microsoft.graph - Subset + description: This OData service is located at https://graph.microsoft.com/beta + version: beta +servers: + - url: https://graph.microsoft.com/beta +paths: + /copilot/retrieval: + post: + tags: + - copilot.copilotRoot.Actions + summary: Invoke action retrieval + operationId: copilot_retrieval + requestBody: + description: Action parameters + content: + application/json: + schema: + type: object + properties: + queryString: + type: string + dataSource: + title: retrievalDataSource + enum: + - sharePoint + - oneDriveBusiness + - externalItem + - mail + - calendar + - teams + - people + - sharePointEmbedded + - unknownFutureValue + type: string + filterExpression: + type: string + nullable: true + resourceMetadata: + type: array + items: + type: string + nullable: true + maximumNumberOfResults: + maximum: 2147483647 + minimum: -2147483648 + type: number + format: int32 + nullable: true + required: true + responses: + 2XX: + description: Success + content: + application/json: + schema: + $ref: '#/components/schemas/microsoft.graph.retrievalResponse' + deprecated: true + x-ms-deprecation: + removalDate: '2025-12-31T00:00:00.0000000+00:00' + date: '2024-02-23T00:00:00.0000000+00:00' + version: 2024-12/PrivatePreview:retrievalAPI +components: + schemas: + microsoft.graph.retrievalResponse: + title: retrievalResponse + required: + - '@odata.type' + type: object + properties: + retrievalHits: + type: array + items: + $ref: '#/components/schemas/microsoft.graph.retrievalHit' + '@odata.type': + type: string + microsoft.graph.retrievalHit: + title: retrievalHit + required: + - '@odata.type' + type: object + properties: + extracts: + type: array + items: + $ref: '#/components/schemas/microsoft.graph.retrievalExtract' + resourceMetadata: + $ref: '#/components/schemas/microsoft.graph.searchResourceMetadataDictionary' + resourceType: + title: retrievalEntityType + enum: + - site + - list + - listItem + - drive + - driveItem + - externalItem + - unknownFutureValue + type: string + sensitivityLabel: + $ref: '#/components/schemas/microsoft.graph.searchSensitivityLabelInfo' + webUrl: + type: string + nullable: true + '@odata.type': + type: string + microsoft.graph.retrievalExtract: + title: retrievalExtract + required: + - '@odata.type' + type: object + properties: + text: + type: string + nullable: true + '@odata.type': + type: string + microsoft.graph.searchResourceMetadataDictionary: + title: searchResourceMetadataDictionary + required: + - '@odata.type' + type: object + properties: + '@odata.type': + type: string + microsoft.graph.searchSensitivityLabelInfo: + title: searchSensitivityLabelInfo + required: + - '@odata.type' + type: object + properties: + color: + type: string + nullable: true + readOnly: true + displayName: + type: string + nullable: true + readOnly: true + isEncrypted: + type: boolean + nullable: true + readOnly: true + priority: + maximum: 2147483647 + minimum: -2147483648 + type: number + format: int32 + nullable: true + readOnly: true + sensitivityLabelId: + type: string + nullable: true + readOnly: true + tooltip: + type: string + nullable: true + readOnly: true + '@odata.type': + type: string + description: "Represents a sensitivityLabel.\nThis model is shared with the CCS retrieval API and search where it is already unhidden." diff --git a/dotnet/samples/Demos/A2AClientServer/A2AClient/A2AClient.csproj b/dotnet/samples/Demos/A2AClientServer/A2AClient/A2AClient.csproj new file mode 100644 index 000000000000..bc5375269c8a --- /dev/null +++ b/dotnet/samples/Demos/A2AClientServer/A2AClient/A2AClient.csproj @@ -0,0 +1,24 @@ + + + + Exe + net8.0 + enable + enable + 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 + $(NoWarn);CS1591;VSTHRD111;CA2007;SKEXP0110 + + + + + + + + + + + + + + + diff --git a/dotnet/samples/Demos/A2AClientServer/A2AClient/HostClientAgent.cs b/dotnet/samples/Demos/A2AClientServer/A2AClient/HostClientAgent.cs new file mode 100644 index 000000000000..ed5be9a77d5f --- /dev/null +++ b/dotnet/samples/Demos/A2AClientServer/A2AClient/HostClientAgent.cs @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.A2A; +using SharpA2A.Core; + +namespace A2A; + +internal sealed class HostClientAgent +{ + internal HostClientAgent(ILogger logger) + { + this._logger = logger; + } + internal async Task InitializeAgentAsync(string modelId, string apiKey, string[] agentUrls) + { + try + { + this._logger.LogInformation("Initializing Semantic Kernel agent with model: {ModelId}", modelId); + + // Connect to the remote agents via A2A + var createAgentTasks = agentUrls.Select(agentUrl => this.CreateAgentAsync(agentUrl)); + var agents = await Task.WhenAll(createAgentTasks); + var agentFunctions = agents.Select(agent => AgentKernelFunctionFactory.CreateFromAgent(agent)).ToList(); + var agentPlugin = KernelPluginFactory.CreateFromFunctions("AgentPlugin", agentFunctions); + + // Define the Host agent + var builder = Kernel.CreateBuilder(); + builder.AddOpenAIChatCompletion(modelId, apiKey); + builder.Plugins.Add(agentPlugin); + var kernel = builder.Build(); + kernel.FunctionInvocationFilters.Add(new ConsoleOutputFunctionInvocationFilter()); + + this.Agent = new ChatCompletionAgent() + { + Kernel = kernel, + Name = "HostClient", + Instructions = + """ + You specialize in handling queries for users and using your tools to provide answers. + """, + Arguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), + }; + } + catch (Exception ex) + { + this._logger.LogError(ex, "Failed to initialize HostClientAgent"); + throw; + } + } + + /// + /// The associated + /// + public Agent? Agent { get; private set; } + + #region private + private readonly ILogger _logger; + + private async Task CreateAgentAsync(string agentUri) + { + var httpClient = new HttpClient + { + BaseAddress = new Uri(agentUri), + Timeout = TimeSpan.FromSeconds(60) + }; + + var client = new A2AClient(httpClient); + var cardResolver = new A2ACardResolver(httpClient); + var agentCard = await cardResolver.GetAgentCardAsync(); + + return new A2AAgent(client, agentCard!); + } + #endregion +} + +internal sealed class ConsoleOutputFunctionInvocationFilter() : IFunctionInvocationFilter +{ + private static string IndentMultilineString(string multilineText, int indentLevel = 1, int spacesPerIndent = 4) + { + // Create the indentation string + var indentation = new string(' ', indentLevel * spacesPerIndent); + + // Split the text into lines, add indentation, and rejoin + char[] NewLineChars = { '\r', '\n' }; + string[] lines = multilineText.Split(NewLineChars, StringSplitOptions.None); + + return string.Join(Environment.NewLine, lines.Select(line => indentation + line)); + } + public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func next) + { + Console.ForegroundColor = ConsoleColor.DarkGray; + + Console.WriteLine($"\nCalling Agent {context.Function.Name} with arguments:"); + Console.ForegroundColor = ConsoleColor.Gray; + + foreach (var kvp in context.Arguments) + { + Console.WriteLine(IndentMultilineString($" {kvp.Key}: {kvp.Value}")); + } + + await next(context); + + if (context.Result.GetValue() is ChatMessageContent[] chatMessages) + { + Console.ForegroundColor = ConsoleColor.DarkGray; + + Console.WriteLine($"Response from Agent {context.Function.Name}:"); + foreach (var message in chatMessages) + { + Console.ForegroundColor = ConsoleColor.Gray; + + Console.WriteLine(IndentMultilineString($"{message}")); + } + } + Console.ResetColor(); + } +} diff --git a/dotnet/samples/Demos/A2AClientServer/A2AClient/Program.cs b/dotnet/samples/Demos/A2AClientServer/A2AClient/Program.cs new file mode 100644 index 000000000000..7aaa078b6735 --- /dev/null +++ b/dotnet/samples/Demos/A2AClientServer/A2AClient/Program.cs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.CommandLine; +using System.CommandLine.Invocation; +using System.Reflection; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; + +namespace A2A; + +public static class Program +{ + public static async Task Main(string[] args) + { + // Create root command with options + var rootCommand = new RootCommand("A2AClient"); + rootCommand.SetHandler(HandleCommandsAsync); + + // Run the command + return await rootCommand.InvokeAsync(args); + } + + public static async System.Threading.Tasks.Task HandleCommandsAsync(InvocationContext context) + { + await RunCliAsync(); + } + + #region private + private static async System.Threading.Tasks.Task RunCliAsync() + { + // Set up the logging + using var loggerFactory = LoggerFactory.Create(builder => + { + builder.AddConsole(); + builder.SetMinimumLevel(LogLevel.Information); + }); + var logger = loggerFactory.CreateLogger("A2AClient"); + + // Retrieve configuration settings + IConfigurationRoot configRoot = new ConfigurationBuilder() + .AddEnvironmentVariables() + .AddUserSecrets(Assembly.GetExecutingAssembly()) + .Build(); + var apiKey = configRoot["A2AClient:ApiKey"] ?? throw new ArgumentException("A2AClient:ApiKey must be provided"); + var modelId = configRoot["A2AClient:ModelId"] ?? "gpt-4.1"; + var agentUrls = configRoot["A2AClient:AgentUrls"] ?? "http://localhost:5000/;http://localhost:5001/;http://localhost:5002/"; + + // Create the Host agent + var hostAgent = new HostClientAgent(logger); + await hostAgent.InitializeAgentAsync(modelId, apiKey, agentUrls!.Split(";")); + AgentThread thread = new ChatHistoryAgentThread(); + try + { + while (true) + { + // Get user message + Console.Write("\nUser (:q or quit to exit): "); + string? message = Console.ReadLine(); + if (string.IsNullOrWhiteSpace(message)) + { + Console.WriteLine("Request cannot be empty."); + continue; + } + + if (message == ":q" || message == "quit") + { + break; + } + + await foreach (AgentResponseItem response in hostAgent.Agent!.InvokeAsync(message, thread)) + { + Console.ForegroundColor = ConsoleColor.Cyan; + Console.WriteLine($"\nAgent: {response.Message.Content}"); + Console.ResetColor(); + + thread = response.Thread; + } + } + } + catch (Exception ex) + { + logger.LogError(ex, "An error occurred while running the A2AClient"); + return; + } + } + #endregion +} diff --git a/dotnet/samples/Demos/A2AClientServer/A2AClient/README.md b/dotnet/samples/Demos/A2AClientServer/A2AClient/README.md new file mode 100644 index 000000000000..d04c622a5ea4 --- /dev/null +++ b/dotnet/samples/Demos/A2AClientServer/A2AClient/README.md @@ -0,0 +1,26 @@ + +# A2A Client Sample +Show how to create an A2A Client with a command line interface which invokes agents using the A2A protocol. + +## Run the Sample + +To run the sample, follow these steps: + +1. Run the A2A client: + ```bash + cd A2AClient + dotnet run + ``` +2. Enter your request e.g. "Show me all invoices for Contoso?" + +## Set Secrets with Secret Manager + +The agent urls are provided as a ` ` delimited list of strings + +```text +cd dotnet/samples/Demos/A2AClientServer/A2AClient + +dotnet user-secrets set "A2AClient:ModelId" "..." +dotnet user-secrets set "A2AClient":ApiKey" "..." +dotnet user-secrets set "A2AClient:AgentUrls" "http://localhost:5000/policy;http://localhost:5000/invoice;http://localhost:5000/logistics" +``` \ No newline at end of file diff --git a/dotnet/samples/Demos/A2AClientServer/A2AServer/A2AServer.csproj b/dotnet/samples/Demos/A2AClientServer/A2AServer/A2AServer.csproj new file mode 100644 index 000000000000..b82136c3d4ef --- /dev/null +++ b/dotnet/samples/Demos/A2AClientServer/A2AServer/A2AServer.csproj @@ -0,0 +1,24 @@ + + + + Exe + net8.0 + enable + enable + 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 + $(NoWarn);CS1591;VSTHRD111;CA2007;SKEXP0110 + + + + + + + + + + + + + + + diff --git a/dotnet/samples/Demos/A2AClientServer/A2AServer/A2AServer.http b/dotnet/samples/Demos/A2AClientServer/A2AServer/A2AServer.http new file mode 100644 index 000000000000..7e543ae005ba --- /dev/null +++ b/dotnet/samples/Demos/A2AClientServer/A2AServer/A2AServer.http @@ -0,0 +1,82 @@ +### Each A2A agent is available at a different host address +@hostInvoice = http://localhost:5000 +@hostPolicy = http://localhost:5001 +@hostLogistics = http://localhost:5002 + +### Query agent card for the invoice agent +GET {{hostInvoice}}/.well-known/agent.json + +### Send a message to the invoice agent +POST {{hostInvoice}} +Content-Type: application/json + +{ + "id": "1", + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "id": "12345", + "message": { + "role": "user", + "messageId": "msg_1", + "parts": [ + { + "kind": "text", + "text": "Show me all invoices for Contoso?" + } + ] + } + } +} + +### Query agent card for the policy agent +GET {{hostPolicy}}/.well-known/agent.json + +### Send a message to the policy agent +POST {{hostPolicy}} +Content-Type: application/json + +{ + "id": "1", + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "id": "12345", + "message": { + "role": "user", + "messageId": "msg_1", + "parts": [ + { + "kind": "text", + "text": "What is the policy for short shipments?" + } + ] + } + } +} + +### Query agent card for the logistics agent +GET {{hostLogistics}}/.well-known/agent.json + +### Send a message to the logistics agent +POST {{hostLogistics}} +Content-Type: application/json + +{ + "id": "1", + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "id": "12345", + "message": { + "role": "user", + "messageId": "msg_1", + "parts": [ + { + "kind": "text", + "text": "What is the status for SHPMT-SAP-001?" + } + ] + } + } +} \ No newline at end of file diff --git a/dotnet/samples/Demos/A2AClientServer/A2AServer/HostAgentFactory.cs b/dotnet/samples/Demos/A2AClientServer/A2AServer/HostAgentFactory.cs new file mode 100644 index 000000000000..8fd9953fa0b1 --- /dev/null +++ b/dotnet/samples/Demos/A2AClientServer/A2AServer/HostAgentFactory.cs @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Azure.AI.Agents.Persistent; +using Azure.Identity; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.A2A; +using Microsoft.SemanticKernel.Agents.AzureAI; +using SharpA2A.Core; + +namespace A2AServer; + +internal static class HostAgentFactory +{ + internal static async Task CreateFoundryHostAgentAsync(string agentType, string modelId, string endpoint, string assistantId, IEnumerable? plugins = null) + { + var agentsClient = new PersistentAgentsClient(endpoint, new AzureCliCredential()); + PersistentAgent definition = await agentsClient.Administration.GetAgentAsync(assistantId); + + var agent = new AzureAIAgent(definition, agentsClient, plugins); + + AgentCard agentCard = agentType.ToUpperInvariant() switch + { + "INVOICE" => GetInvoiceAgentCard(), + "POLICY" => GetPolicyAgentCard(), + "LOGISTICS" => GetLogisticsAgentCard(), + _ => throw new ArgumentException($"Unsupported agent type: {agentType}"), + }; + + return new A2AHostAgent(agent, agentCard); + } + + internal static async Task CreateChatCompletionHostAgentAsync(string agentType, string modelId, string apiKey, string name, string instructions, IEnumerable? plugins = null) + { + var builder = Kernel.CreateBuilder(); + builder.AddOpenAIChatCompletion(modelId, apiKey); + if (plugins is not null) + { + foreach (var plugin in plugins) + { + builder.Plugins.Add(plugin); + } + } + var kernel = builder.Build(); + + var agent = new ChatCompletionAgent() + { + Kernel = kernel, + Name = name, + Instructions = instructions, + Arguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), + }; + + AgentCard agentCard = agentType.ToUpperInvariant() switch + { + "INVOICE" => GetInvoiceAgentCard(), + "POLICY" => GetPolicyAgentCard(), + "LOGISTICS" => GetLogisticsAgentCard(), + _ => throw new ArgumentException($"Unsupported agent type: {agentType}"), + }; + + return new A2AHostAgent(agent, agentCard); + } + + #region private + private static AgentCard GetInvoiceAgentCard() + { + var capabilities = new AgentCapabilities() + { + Streaming = false, + PushNotifications = false, + }; + + var invoiceQuery = new AgentSkill() + { + Id = "id_invoice_agent", + Name = "InvoiceQuery", + Description = "Handles requests relating to invoices.", + Tags = ["invoice", "semantic-kernel"], + Examples = + [ + "List the latest invoices for Contoso.", + ], + }; + + return new() + { + Name = "InvoiceAgent", + Description = "Handles requests relating to invoices.", + Version = "1.0.0", + DefaultInputModes = ["text"], + DefaultOutputModes = ["text"], + Capabilities = capabilities, + Skills = [invoiceQuery], + }; + } + + private static AgentCard GetPolicyAgentCard() + { + var capabilities = new AgentCapabilities() + { + Streaming = false, + PushNotifications = false, + }; + + var invoiceQuery = new AgentSkill() + { + Id = "id_policy_agent", + Name = "PolicyAgent", + Description = "Handles requests relating to policies and customer communications.", + Tags = ["policy", "semantic-kernel"], + Examples = + [ + "What is the policy for short shipments?", + ], + }; + + return new AgentCard() + { + Name = "PolicyAgent", + Description = "Handles requests relating to policies and customer communications.", + Version = "1.0.0", + DefaultInputModes = ["text"], + DefaultOutputModes = ["text"], + Capabilities = capabilities, + Skills = [invoiceQuery], + }; + } + + private static AgentCard GetLogisticsAgentCard() + { + var capabilities = new AgentCapabilities() + { + Streaming = false, + PushNotifications = false, + }; + + var invoiceQuery = new AgentSkill() + { + Id = "id_invoice_agent", + Name = "LogisticsQuery", + Description = "Handles requests relating to logistics.", + Tags = ["logistics", "semantic-kernel"], + Examples = + [ + "What is the status for SHPMT-SAP-001", + ], + }; + + return new AgentCard() + { + Name = "LogisticsAgent", + Description = "Handles requests relating to logistics.", + Version = "1.0.0", + DefaultInputModes = ["text"], + DefaultOutputModes = ["text"], + Capabilities = capabilities, + Skills = [invoiceQuery], + }; + } + #endregion +} diff --git a/dotnet/samples/Demos/A2AClientServer/A2AServer/Plugins/InvoiceQueryPlugin.cs b/dotnet/samples/Demos/A2AClientServer/A2AServer/Plugins/InvoiceQueryPlugin.cs new file mode 100644 index 000000000000..453f339005f8 --- /dev/null +++ b/dotnet/samples/Demos/A2AClientServer/A2AServer/Plugins/InvoiceQueryPlugin.cs @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.ComponentModel; +using Microsoft.SemanticKernel; + +namespace A2A; +/// +/// A simple invoice plugin that returns mock data. +/// +public class Product +{ + public string Name { get; set; } + public int Quantity { get; set; } + public decimal Price { get; set; } // Price per unit + + public Product(string name, int quantity, decimal price) + { + this.Name = name; + this.Quantity = quantity; + this.Price = price; + } + + public decimal TotalPrice() + { + return this.Quantity * this.Price; // Total price for this product + } +} + +public class Invoice +{ + public string TransactionId { get; set; } + public string InvoiceId { get; set; } + public string CompanyName { get; set; } + public DateTime InvoiceDate { get; set; } + public List Products { get; set; } // List of products + + public Invoice(string transactionId, string invoiceId, string companyName, DateTime invoiceDate, List products) + { + this.TransactionId = transactionId; + this.InvoiceId = invoiceId; + this.CompanyName = companyName; + this.InvoiceDate = invoiceDate; + this.Products = products; + } + + public decimal TotalInvoicePrice() + { + return this.Products.Sum(product => product.TotalPrice()); // Total price of all products in the invoice + } +} + +public class InvoiceQueryPlugin +{ + private readonly List _invoices; + private static readonly Random s_random = new(); + + public InvoiceQueryPlugin() + { + // Extended mock data with quantities and prices + this._invoices = + [ + new("TICKET-XYZ987", "INV789", "Contoso", GetRandomDateWithinLastTwoMonths(), new List + { + new("T-Shirts", 150, 10.00m), + new("Hats", 200, 15.00m), + new("Glasses", 300, 5.00m) + }), + new("TICKET-XYZ111", "INV111", "XStore", GetRandomDateWithinLastTwoMonths(), new List + { + new("T-Shirts", 2500, 12.00m), + new("Hats", 1500, 8.00m), + new("Glasses", 200, 20.00m) + }), + new("TICKET-XYZ222", "INV222", "Cymbal Direct", GetRandomDateWithinLastTwoMonths(), new List + { + new("T-Shirts", 1200, 14.00m), + new("Hats", 800, 7.00m), + new("Glasses", 500, 25.00m) + }), + new("TICKET-XYZ333", "INV333", "Contoso", GetRandomDateWithinLastTwoMonths(), new List + { + new("T-Shirts", 400, 11.00m), + new("Hats", 600, 15.00m), + new("Glasses", 700, 5.00m) + }), + new("TICKET-XYZ444", "INV444", "XStore", GetRandomDateWithinLastTwoMonths(), new List + { + new("T-Shirts", 800, 10.00m), + new("Hats", 500, 18.00m), + new("Glasses", 300, 22.00m) + }), + new("TICKET-XYZ555", "INV555", "Cymbal Direct", GetRandomDateWithinLastTwoMonths(), new List + { + new("T-Shirts", 1100, 9.00m), + new("Hats", 900, 12.00m), + new("Glasses", 1200, 15.00m) + }), + new("TICKET-XYZ666", "INV666", "Contoso", GetRandomDateWithinLastTwoMonths(), new List + { + new("T-Shirts", 2500, 8.00m), + new("Hats", 1200, 10.00m), + new("Glasses", 1000, 6.00m) + }), + new("TICKET-XYZ777", "INV777", "XStore", GetRandomDateWithinLastTwoMonths(), new List + { + new("T-Shirts", 1900, 13.00m), + new("Hats", 1300, 16.00m), + new("Glasses", 800, 19.00m) + }), + new("TICKET-XYZ888", "INV888", "Cymbal Direct", GetRandomDateWithinLastTwoMonths(), new List + { + new("T-Shirts", 2200, 11.00m), + new("Hats", 1700, 8.50m), + new("Glasses", 600, 21.00m) + }), + new("TICKET-XYZ999", "INV999", "Contoso", GetRandomDateWithinLastTwoMonths(), new List + { + new("T-Shirts", 1400, 10.50m), + new("Hats", 1100, 9.00m), + new("Glasses", 950, 12.00m) + }) + ]; + } + + public static DateTime GetRandomDateWithinLastTwoMonths() + { + // Get the current date and time + DateTime endDate = DateTime.Now; + + // Calculate the start date, which is two months before the current date + DateTime startDate = endDate.AddMonths(-2); + + // Generate a random number of days between 0 and the total number of days in the range + int totalDays = (endDate - startDate).Days; + int randomDays = s_random.Next(0, totalDays + 1); // +1 to include the end date + + // Return the random date + return startDate.AddDays(randomDays); + } + + [KernelFunction] + [Description("Retrieves invoices for the specified company and optionally within the specified time range")] + public IEnumerable QueryInvoices(string companyName, DateTime? startDate = null, DateTime? endDate = null) + { + var query = this._invoices.Where(i => i.CompanyName.Equals(companyName, StringComparison.OrdinalIgnoreCase)); + + if (startDate.HasValue) + { + query = query.Where(i => i.InvoiceDate >= startDate.Value); + } + + if (endDate.HasValue) + { + query = query.Where(i => i.InvoiceDate <= endDate.Value); + } + + return query.ToList(); + } + + [KernelFunction] + [Description("Retrieves invoice using the transaction id")] + public IEnumerable QueryByTransactionId(string transactionId) + { + var query = this._invoices.Where(i => i.TransactionId.Equals(transactionId, StringComparison.OrdinalIgnoreCase)); + + return query.ToList(); + } + + [KernelFunction] + [Description("Retrieves invoice using the invoice id")] + public IEnumerable QueryByInvoiceId(string invoiceId) + { + var query = this._invoices.Where(i => i.InvoiceId.Equals(invoiceId, StringComparison.OrdinalIgnoreCase)); + + return query.ToList(); + } +} diff --git a/dotnet/samples/Demos/A2AClientServer/A2AServer/Program.cs b/dotnet/samples/Demos/A2AClientServer/A2AServer/Program.cs new file mode 100644 index 000000000000..72680bbd36bb --- /dev/null +++ b/dotnet/samples/Demos/A2AClientServer/A2AServer/Program.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft. All rights reserved. +using A2A; +using A2AServer; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents.A2A; +using SharpA2A.AspNetCore; + +string agentId = string.Empty; +string agentType = string.Empty; + +for (var i = 0; i < args.Length; i++) +{ + if (args[i].StartsWith("--agentId", StringComparison.InvariantCultureIgnoreCase) && i + 1 < args.Length) + { + agentId = args[++i]; + } + else if (args[i].StartsWith("--agentType", StringComparison.InvariantCultureIgnoreCase) && i + 1 < args.Length) + { + agentType = args[++i]; + } +} + +var builder = WebApplication.CreateBuilder(args); +builder.Services.AddHttpClient().AddLogging(); +var app = builder.Build(); + +var httpClient = app.Services.GetRequiredService().CreateClient(); +var logger = app.Logger; + +IConfigurationRoot configuration = new ConfigurationBuilder() + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + +string? apiKey = configuration["A2AServer:ApiKey"]; +string? endpoint = configuration["A2AServer:Endpoint"]; +string modelId = configuration["A2AServer:ModelId"] ?? "gpt-4o-mini"; + +IEnumerable invoicePlugins = [KernelPluginFactory.CreateFromType()]; + +A2AHostAgent? hostAgent = null; +if (!string.IsNullOrEmpty(endpoint) && !string.IsNullOrEmpty(agentId)) +{ + hostAgent = agentType.ToUpperInvariant() switch + { + "INVOICE" => await HostAgentFactory.CreateFoundryHostAgentAsync(agentType, modelId, endpoint, agentId, invoicePlugins), + "POLICY" => await HostAgentFactory.CreateFoundryHostAgentAsync(agentType, modelId, endpoint, agentId), + "LOGISTICS" => await HostAgentFactory.CreateFoundryHostAgentAsync(agentType, modelId, endpoint, agentId), + _ => throw new ArgumentException($"Unsupported agent type: {agentType}"), + }; +} +else if (!string.IsNullOrEmpty(apiKey)) +{ + hostAgent = agentType.ToUpperInvariant() switch + { + "INVOICE" => await HostAgentFactory.CreateChatCompletionHostAgentAsync( + agentType, modelId, apiKey, "InvoiceAgent", + """ + You specialize in handling queries related to invoices. + """, invoicePlugins), + "POLICY" => await HostAgentFactory.CreateChatCompletionHostAgentAsync( + agentType, modelId, apiKey, "PolicyAgent", + """ + You specialize in handling queries related to policies and customer communications. + + Always reply with exactly this text: + + Policy: Short Shipment Dispute Handling Policy V2.1 + + Summary: "For short shipments reported by customers, first verify internal shipment records + (SAP) and physical logistics scan data (BigQuery). If discrepancy is confirmed and logistics data + shows fewer items packed than invoiced, issue a credit for the missing items. Document the + resolution in SAP CRM and notify the customer via email within 2 business days, referencing the + original invoice and the credit memo number. Use the 'Formal Credit Notification' email + template." + """, invoicePlugins), + "LOGISTICS" => await HostAgentFactory.CreateChatCompletionHostAgentAsync( + agentType, modelId, apiKey, "LogisticsAgent", + """ + You specialize in handling queries related to logistics. + + Always reply with exactly: + + Shipment number: SHPMT-SAP-001 + Item: TSHIRT-RED-L + Quantity: 900 + """, invoicePlugins), + _ => throw new ArgumentException($"Unsupported agent type: {agentType}"), + }; +} +else +{ + throw new ArgumentException("Either A2AServer:ApiKey or A2AServer:ConnectionString & agentId must be provided"); +} + +app.MapA2A(hostAgent!.TaskManager!, ""); + +await app.RunAsync(); diff --git a/dotnet/samples/Demos/A2AClientServer/README.md b/dotnet/samples/Demos/A2AClientServer/README.md new file mode 100644 index 000000000000..6038a44477f3 --- /dev/null +++ b/dotnet/samples/Demos/A2AClientServer/README.md @@ -0,0 +1,266 @@ +# A2A Client and Server samples + +> **Warning** +> The [A2A protocol](https://google.github.io/A2A/) is still under development and changing fast. +> We will try to keep these samples updated as the protocol evolves. + +These samples are built with [SharpA2A.Core](https://www.nuget.org/packages/SharpA2A.Core) and demonstrate: + +1. Creating an A2A Server which makes an agent available via the A2A protocol. +2. Creating an A2A Client with a command line interface which invokes agents using the A2A protocol. + +The demonstration has two components: + +1. `A2AServer` - You will run three instances of the server to correspond to three A2A servers each providing a single Agent i.e., the Invoice, Policy and Logistics agents. +2. `A2AClient` - This represents a client application which will connect to the remote A2A servers using the A2A protocol so that it can use those agents when answering questions you will ask. + +Demo Architecture + +## Configuring Secrets or Environment Variables + +The samples can be configured to use chat completion agents or Azure AI agents. + +### Configuring for use with Chat Completion Agents + +Provide your OpenAI API key via .Net secrets + +```bash +dotnet user-secrets set "A2AClient:ApiKey" "..." +``` + +Optionally if you want to use chat completion agents in the server then set the OpenAI key for the server to use. + +```bash +dotnet user-secrets set "A2AServer:ApiKey" "..." +``` + +Use the following commands to run each A2A server: + +```bash +cd A2AServer +dotnet run --urls "http://localhost:5000;https://localhost:5010" --agentType "invoice" +``` + +```bash +cd A2AServer +dotnet run --urls "http://localhost:5001;https://localhost:5011" --agentType "policy" +``` + +```bash +cd A2AServer +dotnet run --urls "http://localhost:5002;https://localhost:5012" --agentType "logistics" +``` + +### Configuring for use with Azure AI Agents + +You must create the agents in an Azure AI Foundry project and then provide the project endpoint and agents ids. The instructions for each agent are as follows: + +- Invoice Agent + ``` + You specialize in handling queries related to invoices. + ``` +- Policy Agent + ``` + You specialize in handling queries related to policies and customer communications. + + Always reply with exactly this text: + + Policy: Short Shipment Dispute Handling Policy V2.1 + + Summary: "For short shipments reported by customers, first verify internal shipment records + (SAP) and physical logistics scan data (BigQuery). If discrepancy is confirmed and logistics data + shows fewer items packed than invoiced, issue a credit for the missing items. Document the + resolution in SAP CRM and notify the customer via email within 2 business days, referencing the + original invoice and the credit memo number. Use the 'Formal Credit Notification' email + template." + ``` +- Logistics Agent + ``` + You specialize in handling queries related to logistics. + + Always reply with exactly: + + Shipment number: SHPMT-SAP-001 + Item: TSHIRT-RED-L + Quantity: 900" + ``` + +```bash +dotnet user-secrets set "A2AServer:Endpoint" "..." +``` + +Use the following commands to run each A2A server + +```bash +cd A2AServer +dotnet run --urls "http://localhost:5000;https://localhost:5010" --agentId "" --agentType "invoice" +``` + +```bash +cd A2AServer +dotnet run --urls "http://localhost:5001;https://localhost:5011" --agentId "" --agentType "policy" +``` + +```bash +cd A2AServer +dotnet run --urls "http://localhost:5002;https://localhost:5012" --agentId "" --agentType "logistics" +``` + +### Testing the Agents using the Rest Client + +This sample contains a [.http file](https://learn.microsoft.com/aspnet/core/test/http-files?view=aspnetcore-9.0) which can be used to test the agent. + +1. In Visual Studio open [./A2AServer/A2AServer.http](./A2AServer/A2AServer.http) +1. There are two sent requests for each agent, e.g., for the invoice agent: + 1. Query agent card for the invoice agent + `GET {{hostInvoice}}/.well-known/agent.json` + 1. Send a message to the invoice agent + ``` + POST {{hostInvoice}} + Content-Type: application/json + + { + "id": "1", + "jsonrpc": "2.0", + "method": "message/send", + "params": { + "id": "12345", + "message": { + "role": "user", + "messageId": "msg_1", + "parts": [ + { + "kind": "text", + "text": "Show me all invoices for Contoso?" + } + ] + } + } + } + ``` + +Sample output from the request to display the agent card: + +Agent Card + +Sample output from the request to send a message to the agent via A2A protocol: + +Send Message + +### Testing the Agents using the A2A Inspector + +The A2A Inspector is a web-based tool designed to help developers inspect, debug, and validate servers that implement the Google A2A (Agent-to-Agent) protocol. It provides a user-friendly interface to interact with an A2A agent, view communication, and ensure specification compliance. + +For more information go [here](https://github.com/a2aproject/a2a-inspector). + +Running the [inspector with Docker](https://github.com/a2aproject/a2a-inspector?tab=readme-ov-file#option-two-run-with-docker) is the easiest way to get started. + +1. Navigate to the A2A Inspector in your browser: [http://127.0.0.1:8080/](http://127.0.0.1:8080/) +1. Enter the URL of the Agent you are running e.g., [http://host.docker.internal:5000](http://host.docker.internal:5000) +1. Connect to the agent and the agent card will be displayed and validated. +1. Type a message and send it to the agent using A2A protocol. + 1. The response will be validated automatically and then displayed in the UI. + 1. You can select the response to view the raw json. + +Agent card after connecting to an agent using the A2A protocol: + +Agent Card + +Sample response after sending a message to the agent via A2A protocol: + +Send Message + +Raw JSON response from an A2A agent: + +Response Raw JSON + +### Configuring Agents for the A2A Client + +The A2A client will connect to remote agents using the A2A protocol. + +By default the client will connect to the invoice, policy and logistics agents provided by the sample A2A Server. + +These are available at the following URL's: + +- Invoice Agent: http://localhost:5000/ +- Policy Agent: http://localhost:5001/ +- Logistics Agent: http://localhost:5002/ + +If you want to change which agents are using then set the agents url as a space delimited string as follows: + +```bash +dotnet user-secrets set "A2AClient:AgentUrls" "http://localhost:5000/;http://localhost:5001/;http://localhost:5002/" +``` + +## Run the Sample + +To run the sample, follow these steps: + +1. Run the A2A server's using the commands shown earlier +2. Run the A2A client: + ```bash + cd A2AClient + dotnet run + ``` +3. Enter your request e.g. "Customer is disputing transaction TICKET-XYZ987 as they claim the received fewer t-shirts than ordered." +4. The host client agent will call the remote agents, these calls will be displayed as console output. The final answer will use information from the remote agents. The sample below includes all three agents but in your case you may only see the policy and invoice agent. + +Sample output from the A2A client: + +``` +A2AClient> dotnet run +info: A2AClient[0] + Initializing Semantic Kernel agent with model: gpt-4o-mini + +User (:q or quit to exit): Customer is disputing transaction TICKET-XYZ987 as they claim the received fewer t-shirts than ordered. + +Calling Agent InvoiceAgent with arguments: + query: TICKET-XYZ987 + instructions: Investigate the transaction details for TICKET-XYZ987 and verify the number of t-shirts ordered versus the number received. + +Response from Agent InvoiceAgent: + The invoice associated with the transaction ID TICKET-XYZ987 is for the company Contoso. It was issued on June 18, 2025. The products in the invoice include 150 T-Shirts priced at $10.00 each, 200 Hats priced at $15.00 each, and 300 Glasses priced at $5.00 each. If you need more details or a copy of the invoice, please let me know! + +Calling Agent LogisticsAgent with arguments: + query: TICKET-XYZ987 + instructions: Check the shipping details for TICKET-XYZ987, specifically the quantity of t-shirts dispatched to confirm if fewer t-shirts were sent. + +Response from Agent LogisticsAgent: + Shipment number: SHPMT-SAP-001 + Item: TSHIRT-RED-L + Quantity: 900 + +Calling Agent PolicyAgent with arguments: + query: TICKET-XYZ987 + instructions: Review the policy regarding disputes and claims related to shipment discrepancies, especially concerning t-shirts. + +Response from Agent PolicyAgent: + Policy: Short Shipment Dispute Handling Policy V2.1 + + Summary: "For short shipments reported by customers, first verify internal shipment records + (SAP) and physical logistics scan data (BigQuery). If discrepancy is confirmed and logistics data + shows fewer items packed than invoiced, issue a credit for the missing items. Document the + resolution in SAP CRM and notify the customer via email within 2 business days, referencing the + original invoice and the credit memo number. Use the 'Formal Credit Notification' email + template." + +Agent: Here's the investigation result for transaction TICKET-XYZ987: + +1. **Invoice Details**: The invoice for transaction TICKET-XYZ987 indicates that 150 t-shirts were ordered. + +2. **Shipment Details**: The logistics records show that a total of 900 t-shirts were dispatched under the shipment number SHPMT-SAP-001. + +There seems to be a significant discrepancy between the number of t-shirts ordered and the number shipped. According to the Short Shipment Dispute Handling Policy, the next steps are as follows: + +1. **Confirm Discrepancy**: Since the logistics data confirms that 900 t-shirts were packed, it is necessary to check if this aligns with the customer's claim. + +2. **Issue Credit**: If the customer is indeed correct and fewer items were actually received compared to what was invoiced, you would need to issue a credit for the missing items. + +3. **Document Resolution**: Ensure to document the resolution in SAP CRM. + +4. **Notify the Customer**: Notify the customer via email within 2 business days, using the 'Formal Credit Notification' email template, and reference both the original invoice and the credit memo number. + +Please let me know if you would like to proceed with any specific action! + +User (:q or quit to exit): +``` diff --git a/dotnet/samples/Demos/A2AClientServer/a2a-inspector-agent-card.png b/dotnet/samples/Demos/A2AClientServer/a2a-inspector-agent-card.png new file mode 100644 index 000000000000..8385a2b68a45 Binary files /dev/null and b/dotnet/samples/Demos/A2AClientServer/a2a-inspector-agent-card.png differ diff --git a/dotnet/samples/Demos/A2AClientServer/a2a-inspector-raw-json-response.png b/dotnet/samples/Demos/A2AClientServer/a2a-inspector-raw-json-response.png new file mode 100644 index 000000000000..038ef344edf7 Binary files /dev/null and b/dotnet/samples/Demos/A2AClientServer/a2a-inspector-raw-json-response.png differ diff --git a/dotnet/samples/Demos/A2AClientServer/a2a-inspector-send-message.png b/dotnet/samples/Demos/A2AClientServer/a2a-inspector-send-message.png new file mode 100644 index 000000000000..49fa857955cc Binary files /dev/null and b/dotnet/samples/Demos/A2AClientServer/a2a-inspector-send-message.png differ diff --git a/dotnet/samples/Demos/A2AClientServer/demo-architecture.png b/dotnet/samples/Demos/A2AClientServer/demo-architecture.png new file mode 100644 index 000000000000..6ae351907a64 Binary files /dev/null and b/dotnet/samples/Demos/A2AClientServer/demo-architecture.png differ diff --git a/dotnet/samples/Demos/A2AClientServer/rest-client-agent-card.png b/dotnet/samples/Demos/A2AClientServer/rest-client-agent-card.png new file mode 100644 index 000000000000..44651487a8ec Binary files /dev/null and b/dotnet/samples/Demos/A2AClientServer/rest-client-agent-card.png differ diff --git a/dotnet/samples/Demos/A2AClientServer/rest-client-send-message.png b/dotnet/samples/Demos/A2AClientServer/rest-client-send-message.png new file mode 100644 index 000000000000..fe65f5c92dd7 Binary files /dev/null and b/dotnet/samples/Demos/A2AClientServer/rest-client-send-message.png differ diff --git a/dotnet/samples/Demos/AIModelRouter/Program.cs b/dotnet/samples/Demos/AIModelRouter/Program.cs index d2ca630a8843..28e38e7a96ac 100644 --- a/dotnet/samples/Demos/AIModelRouter/Program.cs +++ b/dotnet/samples/Demos/AIModelRouter/Program.cs @@ -7,7 +7,6 @@ #pragma warning disable SKEXP0001 #pragma warning disable SKEXP0010 -#pragma warning disable SKEXP0070 namespace AIModelRouter; diff --git a/dotnet/samples/Demos/AgentFrameworkWithAspire/ChatWithAgent.AppHost/Program.cs b/dotnet/samples/Demos/AgentFrameworkWithAspire/ChatWithAgent.AppHost/Program.cs index 5e9671f1dac9..c148aaf805ec 100644 --- a/dotnet/samples/Demos/AgentFrameworkWithAspire/ChatWithAgent.AppHost/Program.cs +++ b/dotnet/samples/Demos/AgentFrameworkWithAspire/ChatWithAgent.AppHost/Program.cs @@ -44,25 +44,44 @@ static List> AddAIServices(IDist // Add chat deployment if (config.AIChatService == AzureOpenAIChatConfig.ConfigSectionName) { - chatResource = azureOpenAI.AddDeployment(new AzureOpenAIDeployment( - name: config.AzureOpenAIChat.DeploymentName, - modelName: config.AzureOpenAIChat.ModelName, - modelVersion: config.AzureOpenAIChat.ModelVersion, - skuName: config.AzureOpenAIChat.SkuName, - skuCapacity: config.AzureOpenAIChat.SkuCapacity) - ); + chatResource = azureOpenAI + .AddDeployment( + name: config.AzureOpenAIChat.DeploymentName, + modelName: config.AzureOpenAIChat.ModelName, + modelVersion: config.AzureOpenAIChat.ModelVersion) + .WithProperties((resource) => + { + if (config.AzureOpenAIChat.SkuName is { } skuName) + { + resource.SkuName = skuName; + } + + if (config.AzureOpenAIChat.SkuCapacity is { } skuCapacity) + { + resource.SkuCapacity = skuCapacity; + } + }); } // Add deployment if (config.Rag.AIEmbeddingService == AzureOpenAIEmbeddingsConfig.ConfigSectionName) { - embeddingsResource = azureOpenAI.AddDeployment(new AzureOpenAIDeployment( - name: config.AzureOpenAIEmbeddings.DeploymentName, - modelName: config.AzureOpenAIEmbeddings.ModelName, - modelVersion: config.AzureOpenAIEmbeddings.ModelVersion, - skuName: config.AzureOpenAIEmbeddings.SkuName, - skuCapacity: config.AzureOpenAIEmbeddings.SkuCapacity) - ); + embeddingsResource = azureOpenAI + .AddDeployment( + name: config.AzureOpenAIEmbeddings.DeploymentName, + modelName: config.AzureOpenAIEmbeddings.ModelName, + modelVersion: config.AzureOpenAIEmbeddings.ModelVersion) + .WithProperties((resource) => + { + if (config.AzureOpenAIEmbeddings.SkuName is { } skuName) + { + resource.SkuName = skuName; + } + if (config.AzureOpenAIEmbeddings.SkuCapacity is { } skuCapacity) + { + resource.SkuCapacity = skuCapacity; + } + }); } } else diff --git a/dotnet/samples/Demos/AmazonBedrockModels/AmazonBedrockAIModels.csproj b/dotnet/samples/Demos/AmazonBedrockModels/AmazonBedrockAIModels.csproj index 47306a16a84d..b74ff50cf227 100644 --- a/dotnet/samples/Demos/AmazonBedrockModels/AmazonBedrockAIModels.csproj +++ b/dotnet/samples/Demos/AmazonBedrockModels/AmazonBedrockAIModels.csproj @@ -8,7 +8,7 @@ 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 - $(NoWarn);SKEXP0001;SKEXP0070 + $(NoWarn);SKEXP0001 diff --git a/dotnet/samples/Demos/AotCompatibility/OnnxChatCompletionSamples.cs b/dotnet/samples/Demos/AotCompatibility/OnnxChatCompletionSamples.cs index 1157bc829886..38fa5e265f8e 100644 --- a/dotnet/samples/Demos/AotCompatibility/OnnxChatCompletionSamples.cs +++ b/dotnet/samples/Demos/AotCompatibility/OnnxChatCompletionSamples.cs @@ -1,7 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -#pragma warning disable SKEXP0070 - using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; diff --git a/dotnet/samples/Demos/BookingRestaurant/BookingRestaurant.csproj b/dotnet/samples/Demos/BookingRestaurant/BookingRestaurant.csproj index ec2f84a4fa67..cc3c77d903a5 100644 --- a/dotnet/samples/Demos/BookingRestaurant/BookingRestaurant.csproj +++ b/dotnet/samples/Demos/BookingRestaurant/BookingRestaurant.csproj @@ -17,7 +17,6 @@ - diff --git a/dotnet/samples/Demos/CopilotAgentPlugins/CopilotAgentPluginsDemoSample/DemoCommand.cs b/dotnet/samples/Demos/CopilotAgentPlugins/CopilotAgentPluginsDemoSample/DemoCommand.cs index 336b1832e455..abdd16856a3e 100644 --- a/dotnet/samples/Demos/CopilotAgentPlugins/CopilotAgentPluginsDemoSample/DemoCommand.cs +++ b/dotnet/samples/Demos/CopilotAgentPlugins/CopilotAgentPluginsDemoSample/DemoCommand.cs @@ -238,7 +238,6 @@ private static (Kernel, PromptExecutionSettings) InitializeKernelForOllama(IConf loggingBuilder.AddProvider(new SemanticKernelLoggerProvider()); }); } -#pragma warning disable SKEXP0070 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. #pragma warning disable SKEXP0001 return (builder.AddOllamaChatCompletion( chatModelId, @@ -253,7 +252,6 @@ private static (Kernel, PromptExecutionSettings) InitializeKernelForOllama(IConf ) }); #pragma warning restore SKEXP0001 -#pragma warning restore SKEXP0070 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. } private static (Kernel, PromptExecutionSettings) InitializeAzureOpenAiKernel(IConfiguration configuration, bool enableLogging) @@ -364,6 +362,7 @@ private async Task AddCopilotAgentPluginAsync(Kernel kernel, IConfigurationRoot FunctionExecutionParameters = new() { { "https://graph.microsoft.com/v1.0", new OpenApiFunctionExecutionParameters(authCallback: this._bearerAuthenticationProviderWithCancellationToken.AuthenticateRequestAsync, enableDynamicOperationPayload: false, enablePayloadNamespacing: true) { ParameterFilter = s_restApiParameterFilter} }, + { "https://graph.microsoft.com/beta", new OpenApiFunctionExecutionParameters(authCallback: this._bearerAuthenticationProviderWithCancellationToken.AuthenticateRequestAsync, enableDynamicOperationPayload: false, enablePayloadNamespacing: true) { ParameterFilter = s_restApiParameterFilter} }, { "https://api.nasa.gov/planetary", new OpenApiFunctionExecutionParameters(authCallback: GetApiKeyAuthProvider("DEMO_KEY", "api_key", false), enableDynamicOperationPayload: false, enablePayloadNamespacing: true)} }, }; @@ -500,8 +499,9 @@ private static void TrimPropertiesFromJsonNode(JsonNode jsonNode) { #pragma warning restore SKEXP0040 if (("me_sendMail".Equals(context.Operation.Id, StringComparison.OrdinalIgnoreCase) || - ("me_calendar_CreateEvents".Equals(context.Operation.Id, StringComparison.OrdinalIgnoreCase)) && - "payload".Equals(context.Parameter.Name, StringComparison.OrdinalIgnoreCase))) + ("me_calendar_CreateEvents".Equals(context.Operation.Id, StringComparison.OrdinalIgnoreCase) || + ("copilot_retrieval".Equals(context.Operation.Id, StringComparison.OrdinalIgnoreCase)) && + "payload".Equals(context.Parameter.Name, StringComparison.OrdinalIgnoreCase)))) { context.Parameter.Schema = TrimPropertiesFromRequestBody(context.Parameter.Schema); return context.Parameter; diff --git a/dotnet/samples/Demos/HuggingFaceImageToText/FormMain.cs b/dotnet/samples/Demos/HuggingFaceImageToText/FormMain.cs index e269c265489b..042cd3d19a87 100644 --- a/dotnet/samples/Demos/HuggingFaceImageToText/FormMain.cs +++ b/dotnet/samples/Demos/HuggingFaceImageToText/FormMain.cs @@ -7,7 +7,6 @@ namespace HuggingFaceImageTextDemo; #pragma warning disable SKEXP0001 // Type is for evaluation purposes only and is subject to change or removal in future updates. -#pragma warning disable SKEXP0070 // Type is for evaluation purposes only and is subject to change or removal in future updates. /// /// Main form of the application. diff --git a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/Extensions/ChatMessageContentExtensions.cs b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/Extensions/ChatMessageContentExtensions.cs index 9cf65d3c99c5..bae19daa1374 100644 --- a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/Extensions/ChatMessageContentExtensions.cs +++ b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/Extensions/ChatMessageContentExtensions.cs @@ -23,37 +23,33 @@ public static CreateMessageResult ToCreateMessageResult(this ChatMessageContent // ChatMessageContent can contain multiple items of different modalities, while the CreateMessageResult // can only have a single content type: text, image, or audio. First, look for image or audio content, // and if not found, fall back to the text content type by concatenating the text of all text contents. - Content? content = null; + ContentBlock? content = null; foreach (KernelContent item in chatMessageContent.Items) { if (item is ImageContent image) { - content = new Content + content = new ImageContentBlock { - Type = "image", Data = Convert.ToBase64String(image.Data!.Value.Span), - MimeType = image.MimeType + MimeType = image.MimeType ?? "image/jpeg" }; break; } else if (item is AudioContent audio) { - content = new Content + content = new AudioContentBlock { - Type = "audio", Data = Convert.ToBase64String(audio.Data!.Value.Span), - MimeType = audio.MimeType + MimeType = audio.MimeType ?? "audio/mpeg" }; break; } } - content ??= new Content + content ??= new TextContentBlock { - Type = "text", Text = string.Concat(chatMessageContent.Items.OfType()), - MimeType = "text/plain" }; return new CreateMessageResult diff --git a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/Extensions/ContentBlockExtensions.cs b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/Extensions/ContentBlockExtensions.cs new file mode 100644 index 000000000000..b4dff034ddfa --- /dev/null +++ b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/Extensions/ContentBlockExtensions.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.SemanticKernel; +using ModelContextProtocol.Protocol; + +namespace MCPClient; + +/// +/// Extension methods for the class. +/// +public static class ContentBlockExtensions +{ + /// + /// Converts a object to a object. + /// + /// The object to convert. + /// The corresponding object. + public static KernelContent ToKernelContent(this ContentBlock content) + { + return content switch + { + TextContentBlock textContentBlock => new TextContent(textContentBlock.Text), + ImageContentBlock imageContentBlock => new ImageContent(Convert.FromBase64String(imageContentBlock.Data!), imageContentBlock.MimeType), + AudioContentBlock audioContentBlock => new AudioContent(Convert.FromBase64String(audioContentBlock.Data!), audioContentBlock.MimeType), + _ => throw new InvalidOperationException($"Unexpected message content type '{content.Type}'"), + }; + } +} diff --git a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/Extensions/ContentExtensions.cs b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/Extensions/ContentExtensions.cs deleted file mode 100644 index 17723034dba9..000000000000 --- a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/Extensions/ContentExtensions.cs +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using Microsoft.SemanticKernel; -using ModelContextProtocol.Protocol; - -namespace MCPClient; - -/// -/// Extension methods for the class. -/// -public static class ContentExtensions -{ - /// - /// Converts a object to a object. - /// - /// The object to convert. - /// The corresponding object. - public static KernelContent ToKernelContent(this Content content) - { - return content.Type switch - { - "text" => new TextContent(content.Text), - "image" => new ImageContent(Convert.FromBase64String(content.Data!), content.MimeType), - "audio" => new AudioContent(Convert.FromBase64String(content.Data!), content.MimeType), - _ => throw new InvalidOperationException($"Unexpected message content type '{content.Type}'"), - }; - } -} diff --git a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/MCPClient.csproj b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/MCPClient.csproj index 5c5a2299c4a2..c523311cf38a 100644 --- a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/MCPClient.csproj +++ b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPClient/MCPClient.csproj @@ -12,7 +12,6 @@ - diff --git a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Prompts/PromptDefinition.cs b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Prompts/PromptDefinition.cs index 02eea7a4f600..e7f9fb987c5b 100644 --- a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Prompts/PromptDefinition.cs +++ b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Prompts/PromptDefinition.cs @@ -103,9 +103,8 @@ private static async Task GetPromptHandlerAsync(RequestContext< [ new PromptMessage() { - Content = new Content() + Content = new TextContentBlock() { - Type = "text", Text = renderedPrompt }, Role = Role.Assistant diff --git a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Tools/MailboxUtils.cs b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Tools/MailboxUtils.cs index 4918de25be2b..fe4ace51335f 100644 --- a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Tools/MailboxUtils.cs +++ b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Tools/MailboxUtils.cs @@ -55,7 +55,7 @@ public static async Task SummarizeUnreadEmailsAsync([FromKernelServices] CreateMessageResult result = await server.SampleAsync(request, cancellationToken: CancellationToken.None); // Assuming the response is a text message - return result.Content.Text!; + return (result.Content as TextContentBlock)!.Text; } /// @@ -72,11 +72,9 @@ private static List CreateMessagesFromEmails(params Email[] ema messages.Add(new SamplingMessage { Role = Role.User, - Content = new Content + Content = new TextContentBlock { Text = $"Email from {email.Sender} with subject {email.Subject}. Body: {email.Body}", - Type = "text", - MimeType = "text/plain" } }); @@ -87,9 +85,8 @@ private static List CreateMessagesFromEmails(params Email[] ema messages.Add(new SamplingMessage { Role = Role.User, - Content = new Content + Content = new ImageContentBlock { - Type = "image", Data = Convert.ToBase64String(attachment), MimeType = "image/png", } diff --git a/dotnet/samples/Demos/ModelContextProtocolPlugin/ModelContextProtocolPlugin.csproj b/dotnet/samples/Demos/ModelContextProtocolPlugin/ModelContextProtocolPlugin.csproj index 6eb850dbc343..bb6e26fa4df2 100644 --- a/dotnet/samples/Demos/ModelContextProtocolPlugin/ModelContextProtocolPlugin.csproj +++ b/dotnet/samples/Demos/ModelContextProtocolPlugin/ModelContextProtocolPlugin.csproj @@ -13,7 +13,6 @@ - diff --git a/dotnet/samples/Demos/ModelContextProtocolPluginAuth/ModelContextProtocolPluginAuth.csproj b/dotnet/samples/Demos/ModelContextProtocolPluginAuth/ModelContextProtocolPluginAuth.csproj new file mode 100644 index 000000000000..35e7321db7d2 --- /dev/null +++ b/dotnet/samples/Demos/ModelContextProtocolPluginAuth/ModelContextProtocolPluginAuth.csproj @@ -0,0 +1,28 @@ + + + + Exe + net8.0 + enable + enable + 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 + $(NoWarn);CA2249;CS0612;SKEXP0001;VSTHRD111;CA2007;RCS1263 + + + + + + + + + + + + + + + + + + + diff --git a/dotnet/samples/Demos/ModelContextProtocolPluginAuth/Program.cs b/dotnet/samples/Demos/ModelContextProtocolPluginAuth/Program.cs new file mode 100644 index 000000000000..29f6b0c509f8 --- /dev/null +++ b/dotnet/samples/Demos/ModelContextProtocolPluginAuth/Program.cs @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics; +using System.Net; +using System.Text; +using System.Web; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using ModelContextProtocol.Client; + +var config = new ConfigurationBuilder() + .AddUserSecrets() + .AddEnvironmentVariables() + .Build(); + +if (config["OpenAI:ApiKey"] is not { } apiKey) +{ + Console.Error.WriteLine("Please provide a valid OpenAI:ApiKey to run this sample. See the associated README.md for more details."); + return; +} + +// We can customize a shared HttpClient with a custom handler if desired +using var sharedHandler = new SocketsHttpHandler +{ + PooledConnectionLifetime = TimeSpan.FromMinutes(2), + PooledConnectionIdleTimeout = TimeSpan.FromMinutes(1) +}; +using var httpClient = new HttpClient(sharedHandler); + +var consoleLoggerFactory = LoggerFactory.Create(builder => +{ + builder.AddConsole(); +}); + +// Create SSE client transport for the MCP server +var serverUrl = "http://localhost:7071/"; +var transport = new SseClientTransport(new() +{ + Endpoint = new Uri(serverUrl), + Name = "Secure Weather Client", + OAuth = new() + { + ClientName = "ProtectedMcpClient", + RedirectUri = new Uri("http://localhost:1179/callback"), + AuthorizationRedirectDelegate = HandleAuthorizationUrlAsync, + } +}, httpClient, consoleLoggerFactory); + +// Create an MCPClient for the protected MCP server +await using var mcpClient = await McpClientFactory.CreateAsync(transport, loggerFactory: consoleLoggerFactory); + +// Retrieve the list of tools available on the GitHub server +var tools = await mcpClient.ListToolsAsync().ConfigureAwait(false); +foreach (var tool in tools) +{ + Console.WriteLine($"{tool.Name}: {tool.Description}"); +} + +// Prepare and build kernel with the MCP tools as Kernel functions +var builder = Kernel.CreateBuilder(); +builder.Services + .AddLogging(c => c.AddDebug().SetMinimumLevel(Microsoft.Extensions.Logging.LogLevel.Trace)) + .AddOpenAIChatCompletion( + modelId: config["OpenAI:ChatModelId"] ?? "gpt-4o-mini", + apiKey: apiKey); +Kernel kernel = builder.Build(); +kernel.Plugins.AddFromFunctions("WeatherApi", tools.Select(aiFunction => aiFunction.AsKernelFunction())); + +// Enable automatic function calling +OpenAIPromptExecutionSettings executionSettings = new() +{ + Temperature = 0, + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new() { RetainArgumentTypes = true }) +}; + +// Test using weather tools +var prompt = "Get current weather alerts for New York?"; +var result = await kernel.InvokePromptAsync(prompt, new(executionSettings)).ConfigureAwait(false); +Console.WriteLine($"\n\n{prompt}\n{result}"); + +// Define the agent +ChatCompletionAgent agent = new() +{ + Instructions = "Answer questions about weather alerts for US states.", + Name = "WeatherAgent", + Kernel = kernel, + Arguments = new KernelArguments(executionSettings), +}; + +// Respond to user input, invoking functions where appropriate. +ChatMessageContent response = await agent.InvokeAsync("Get the current weather alerts for Washington?").FirstAsync(); +Console.WriteLine($"\n\nResponse from WeatherAgent:\n{response.Content}"); + +/// +/// Handles the OAuth authorization URL by starting a local HTTP server and opening a browser. +/// This implementation demonstrates how SDK consumers can provide their own authorization flow. +/// +/// The authorization URL to open in the browser. +/// The redirect URI where the authorization code will be sent. +/// The cancellation token. +/// The authorization code extracted from the callback, or null if the operation failed. +static async Task HandleAuthorizationUrlAsync(Uri authorizationUrl, Uri redirectUri, CancellationToken cancellationToken) +{ + Console.WriteLine("Starting OAuth authorization flow..."); + Console.WriteLine($"Opening browser to: {authorizationUrl}"); + + var listenerPrefix = redirectUri.GetLeftPart(UriPartial.Authority); + if (!listenerPrefix.EndsWith("/", StringComparison.InvariantCultureIgnoreCase)) + { + listenerPrefix += "/"; + } + + using var listener = new HttpListener(); + listener.Prefixes.Add(listenerPrefix); + + try + { + listener.Start(); + Console.WriteLine($"Listening for OAuth callback on: {listenerPrefix}"); + + OpenBrowser(authorizationUrl); + + var context = await listener.GetContextAsync(); + var query = HttpUtility.ParseQueryString(context.Request.Url?.Query ?? string.Empty); + var code = query["code"]; + var error = query["error"]; + + string responseHtml = "

Authentication complete

You can close this window now.

"; + byte[] buffer = Encoding.UTF8.GetBytes(responseHtml); + context.Response.ContentLength64 = buffer.Length; + context.Response.ContentType = "text/html"; + context.Response.OutputStream.Write(buffer, 0, buffer.Length); + context.Response.Close(); + + if (!string.IsNullOrEmpty(error)) + { + Console.WriteLine($"Auth error: {error}"); + return null; + } + + if (string.IsNullOrEmpty(code)) + { + Console.WriteLine("No authorization code received"); + return null; + } + + Console.WriteLine("Authorization code received successfully."); + return code; + } + catch (Exception ex) + { + Console.WriteLine($"Error getting auth code: {ex.Message}"); + return null; + } + finally + { + if (listener.IsListening) + { + listener.Stop(); + } + } +} + +/// +/// Opens the specified URL in the default browser. +/// +/// The URL to open. +static void OpenBrowser(Uri url) +{ + try + { + var psi = new ProcessStartInfo + { + FileName = url.ToString(), + UseShellExecute = true + }; + Process.Start(psi); + } + catch (Exception ex) + { + Console.WriteLine($"Error opening browser. {ex.Message}"); + Console.WriteLine($"Please manually open this URL: {url}"); + } +} diff --git a/dotnet/samples/Demos/ModelContextProtocolPluginAuth/README.md b/dotnet/samples/Demos/ModelContextProtocolPluginAuth/README.md new file mode 100644 index 000000000000..630e5972b06b --- /dev/null +++ b/dotnet/samples/Demos/ModelContextProtocolPluginAuth/README.md @@ -0,0 +1,144 @@ +# Model Context Protocol Sample + +This example demonstrates how to use tools from a protected Model Context Protocol server with Semantic Kernel. + +MCP is an open protocol that standardizes how applications provide context to LLMs. + +For information on Model Context Protocol (MCP) please refer to the [documentation](https://modelcontextprotocol.io/introduction). + +The sample shows: + +1. How to connect to a protected MCP Server using OAuth 2.0 authentication +1. How to implement a custom OAuth authorization flow with browser-based authentication +1. Retrieve the list of tools the MCP Server makes available +1. Convert the MCP tools to Semantic Kernel functions so they can be added to a Kernel instance +1. Invoke the tools from Semantic Kernel using function calling + +## Installing Prerequisites + +- A self-signed certificate to enable HTTPS use in development, see [dotnet dev-certs](https://learn.microsoft.com/en-us/dotnet/core/tools/dotnet-dev-certs) +- .NET 9.0 or later +- A running TestOAuthServer (for OAuth authentication), see [Start the Test OAuth Server](https://github.com/modelcontextprotocol/csharp-sdk/tree/main/samples/ProtectedMCPClient#step-1-start-the-test-oauth-server) +- A running ProtectedMCPServer (for MCP services), see [Start the Protected MCP Server](https://github.com/modelcontextprotocol/csharp-sdk/tree/main/samples/ProtectedMCPClient#step-2-start-the-protected-mcp-server) + +## Configuring Secrets or Environment Variables + +The example requires credentials to access OpenAI. + +If you have set up those credentials as secrets within Secret Manager or through environment variables for other samples from the solution in which this project is found, they will be re-used. + +### To set your secrets with Secret Manager + +```text +cd dotnet/samples/Demos/ModelContextProtocolPluginAuth + +dotnet user-secrets init + +dotnet user-secrets set "OpenAI:ChatModelId" "..." +dotnet user-secrets set "OpenAI:ApiKey" "..." + "..." +``` + +### To set your secrets with environment variables + +Use these names: + +```text +# OpenAI +OpenAI__ChatModelId +OpenAI__ApiKey +``` + +## Setup and Running + +### Step 1: Start the Test OAuth Server + +First, you need to start the TestOAuthServer which provides OAuth authentication: + +```bash +cd \tests\ModelContextProtocol.TestOAuthServer +dotnet run --framework net9.0 +``` + +The OAuth server will start at `https://localhost:7029` + +### Step 2: Start the Protected MCP Server + +Next, start the ProtectedMCPServer which provides the weather tools: + +```bash +cd \samples\ProtectedMCPServer +dotnet run +``` + +The protected server will start at `http://localhost:7071` + +### Step 3: Run the ModelContextProtocolPluginAuth sample + +Finally, run this client: + +```bash +dotnet run +``` + +## What Happens + +1. The client attempts to connect to the protected MCP server at `http://localhost:7071` +2. The server responds with OAuth metadata indicating authentication is required +3. The client initiates OAuth 2.0 authorization code flow: + - Opens a browser to the authorization URL at the OAuth server + - Starts a local HTTP listener on `http://localhost:1179/callback` to receive the authorization code + - Exchanges the authorization code for an access token +4. The client uses the access token to authenticate with the MCP server +5. The client lists available tools and calls the `GetAlerts` tool for New York state + +The following diagram outlines an example OAuth flow: + +```mermaid +sequenceDiagram + participant Client as Client + participant Server as MCP Server (Resource Server) + participant AuthServer as Authorization Server + + Client->>Server: MCP request without access token + Server-->>Client: HTTP 401 Unauthorized with WWW-Authenticate header + Note over Client: Analyze and delegate tasks + Client->>Server: GET /.well-known/oauth-protected-resource + Server-->>Client: Resource metadata with authorization server URL + Note over Client: Validate RS metadata, build AS metadata URL + Client->>AuthServer: GET /.well-known/oauth-authorization-server + AuthServer-->>Client: Authorization server metadata + Note over Client,AuthServer: OAuth 2.0 authorization flow happens here + Client->>AuthServer: Token request + AuthServer-->>Client: Access token + Client->>Server: MCP request with access token + Server-->>Client: MCP response + Note over Client,Server: MCP communication continues with valid token +``` + +## OAuth Configuration + +The client is configured with: +- **Client ID**: `demo-client` +- **Client Secret**: `demo-secret` +- **Redirect URI**: `http://localhost:1179/callback` +- **OAuth Server**: `https://localhost:7029` +- **Protected Resource**: `http://localhost:7071` + +## Available Tools + +Once authenticated, the client can access weather tools including: +- **GetAlerts**: Get weather alerts for a US state +- **GetForecast**: Get weather forecast for a location (latitude/longitude) + +## Troubleshooting + +- Ensure the ASP.NET Core dev certificate is trusted. + ``` + dotnet dev-certs https --clean + dotnet dev-certs https --trust + ``` +- Ensure all three services are running in the correct order +- Check that ports 7029, 7071, and 1179 are available +- If the browser doesn't open automatically, copy the authorization URL from the console and open it manually +- Make sure to allow the OAuth server's self-signed certificate in your browser \ No newline at end of file diff --git a/dotnet/samples/Demos/OllamaFunctionCalling/OllamaFunctionCalling.csproj b/dotnet/samples/Demos/OllamaFunctionCalling/OllamaFunctionCalling.csproj index bfcc8c0afaf0..e9a03e29dc83 100644 --- a/dotnet/samples/Demos/OllamaFunctionCalling/OllamaFunctionCalling.csproj +++ b/dotnet/samples/Demos/OllamaFunctionCalling/OllamaFunctionCalling.csproj @@ -3,7 +3,7 @@ Exe net8.0 - $(NoWarn);CA2007,CA2208,CS1591,CA1024,IDE0009,IDE0055,IDE0073,IDE0211,VSTHRD111,SKEXP0001,SKEXP0070 + $(NoWarn);CA2007,CA2208,CS1591,CA1024,IDE0009,IDE0055,IDE0073,IDE0211,VSTHRD111,SKEXP0001 diff --git a/dotnet/samples/Demos/OllamaFunctionCalling/Program.cs b/dotnet/samples/Demos/OllamaFunctionCalling/Program.cs index 3d52b6ea45c1..078f372c96cf 100644 --- a/dotnet/samples/Demos/OllamaFunctionCalling/Program.cs +++ b/dotnet/samples/Demos/OllamaFunctionCalling/Program.cs @@ -1,6 +1,4 @@ -#pragma warning disable SKEXP0070 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - -using System; +using System; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Ollama; diff --git a/dotnet/samples/Demos/OnnxSimpleRAG/OnnxSimpleRAG.csproj b/dotnet/samples/Demos/OnnxSimpleRAG/OnnxSimpleRAG.csproj index 7ae69898f692..796a2966697c 100644 --- a/dotnet/samples/Demos/OnnxSimpleRAG/OnnxSimpleRAG.csproj +++ b/dotnet/samples/Demos/OnnxSimpleRAG/OnnxSimpleRAG.csproj @@ -3,7 +3,7 @@ Exe net8.0 - $(NoWarn);CA2007;CS0612;VSTHRD111;SKEXP0070;SKEXP0050;SKEXP0001 + $(NoWarn);CA2007;CS0612;VSTHRD111;SKEXP0050;SKEXP0001 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.AppHost/ProcessFramework.Aspire.AppHost.csproj b/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.AppHost/ProcessFramework.Aspire.AppHost.csproj index 9310b9a042eb..504c120220e9 100644 --- a/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.AppHost/ProcessFramework.Aspire.AppHost.csproj +++ b/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.AppHost/ProcessFramework.Aspire.AppHost.csproj @@ -15,20 +15,15 @@ - - + false - - - + + + \ No newline at end of file diff --git a/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.ProcessOrchestrator/ProcessFramework.Aspire.ProcessOrchestrator.csproj b/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.ProcessOrchestrator/ProcessFramework.Aspire.ProcessOrchestrator.csproj index 7d1d3995191d..7e65e2272553 100644 --- a/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.ProcessOrchestrator/ProcessFramework.Aspire.ProcessOrchestrator.csproj +++ b/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.ProcessOrchestrator/ProcessFramework.Aspire.ProcessOrchestrator.csproj @@ -6,23 +6,19 @@ enable enable - $(NoWarn);CS8618,IDE0009,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0101,SKEXP0110,OPENAI001 + $(NoWarn);CS8618,IDE0009,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0080,SKEXP0101,SKEXP0110,OPENAI001 - - - - + + + \ No newline at end of file diff --git a/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.SummaryAgent/ProcessFramework.Aspire.SummaryAgent.csproj b/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.SummaryAgent/ProcessFramework.Aspire.SummaryAgent.csproj index 187beb78372b..e045b2db37b4 100644 --- a/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.SummaryAgent/ProcessFramework.Aspire.SummaryAgent.csproj +++ b/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.SummaryAgent/ProcessFramework.Aspire.SummaryAgent.csproj @@ -10,16 +10,12 @@ - - - - + + + \ No newline at end of file diff --git a/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.TranslatorAgent/ProcessFramework.Aspire.TranslatorAgent.csproj b/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.TranslatorAgent/ProcessFramework.Aspire.TranslatorAgent.csproj index 59be1e8a4d6a..bb4f5bf1e1da 100644 --- a/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.TranslatorAgent/ProcessFramework.Aspire.TranslatorAgent.csproj +++ b/dotnet/samples/Demos/ProcessFrameworkWithAspire/ProcessFramework.Aspire/ProcessFramework.Aspire.TranslatorAgent/ProcessFramework.Aspire.TranslatorAgent.csproj @@ -10,17 +10,13 @@ - - - - + + + \ No newline at end of file diff --git a/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Grpc/ProcessWithCloudEvents.Grpc.csproj b/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Grpc/ProcessWithCloudEvents.Grpc.csproj index b2d5022ffa34..c8e87e203414 100644 --- a/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Grpc/ProcessWithCloudEvents.Grpc.csproj +++ b/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Grpc/ProcessWithCloudEvents.Grpc.csproj @@ -5,7 +5,7 @@ enable enable - $(NoWarn);CA2007,CS1591,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 + $(NoWarn);CA2007,CS1591,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0080,SKEXP0110 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Processes/ProcessWithCloudEvents.Processes.csproj b/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Processes/ProcessWithCloudEvents.Processes.csproj index 1fafc3012f07..e312860af9a7 100644 --- a/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Processes/ProcessWithCloudEvents.Processes.csproj +++ b/dotnet/samples/Demos/ProcessWithCloudEvents/ProcessWithCloudEvents.Processes/ProcessWithCloudEvents.Processes.csproj @@ -5,7 +5,7 @@ enable enable - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0080,SKEXP0110 diff --git a/dotnet/samples/Demos/ProcessWithDapr/ProcessWithDapr.csproj b/dotnet/samples/Demos/ProcessWithDapr/ProcessWithDapr.csproj index d1bd90408672..2912992ce565 100644 --- a/dotnet/samples/Demos/ProcessWithDapr/ProcessWithDapr.csproj +++ b/dotnet/samples/Demos/ProcessWithDapr/ProcessWithDapr.csproj @@ -5,7 +5,7 @@ enable enable - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0080,SKEXP0110 diff --git a/dotnet/samples/Demos/StructuredDataPlugin/StructuredDataPlugin.csproj b/dotnet/samples/Demos/StructuredDataPlugin/StructuredDataPlugin.csproj index de0ceee44233..97dfceb63f85 100644 --- a/dotnet/samples/Demos/StructuredDataPlugin/StructuredDataPlugin.csproj +++ b/dotnet/samples/Demos/StructuredDataPlugin/StructuredDataPlugin.csproj @@ -5,7 +5,7 @@ net8.0 enable enable - $(NoWarn),VSTHRD111,CA2007,CA5399,SKEXP0050,SKEXP0070 + $(NoWarn),VSTHRD111,CA2007,CA5399,SKEXP0050 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/Demos/TelemetryWithAppInsights/TelemetryWithAppInsights.csproj b/dotnet/samples/Demos/TelemetryWithAppInsights/TelemetryWithAppInsights.csproj index ee804dbc02dc..1aac0e84f709 100644 --- a/dotnet/samples/Demos/TelemetryWithAppInsights/TelemetryWithAppInsights.csproj +++ b/dotnet/samples/Demos/TelemetryWithAppInsights/TelemetryWithAppInsights.csproj @@ -7,7 +7,7 @@ disable false - $(NoWarn);CA1024;CA1050;CA1707;CA2007;CS1591;VSTHRD111,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0001 + $(NoWarn);CA1024;CA1050;CA1707;CA2007;CS1591;VSTHRD111,SKEXP0050,SKEXP0060,SKEXP0001 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/GettingStarted/GettingStarted.csproj b/dotnet/samples/GettingStarted/GettingStarted.csproj index d8f4a6ca8316..0314496a6af6 100644 --- a/dotnet/samples/GettingStarted/GettingStarted.csproj +++ b/dotnet/samples/GettingStarted/GettingStarted.csproj @@ -7,7 +7,7 @@ true false - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0101 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 @@ -46,9 +46,7 @@ - - diff --git a/dotnet/samples/GettingStartedWithAgents/A2A/Step01_A2AAgent.cs b/dotnet/samples/GettingStartedWithAgents/A2A/Step01_A2AAgent.cs new file mode 100644 index 000000000000..a6d7ba601ba0 --- /dev/null +++ b/dotnet/samples/GettingStartedWithAgents/A2A/Step01_A2AAgent.cs @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.A2A; +using SharpA2A.Core; + +namespace GettingStarted.A2A; + +/// +/// This example demonstrates similarity between using +/// and other agent types. +/// +public class Step01_A2AAgent(ITestOutputHelper output) : BaseAgentsTest(output) +{ + [Fact] + public async Task UseA2AAgent() + { + // Create an A2A agent instance + using var httpClient = CreateHttpClient(); + var client = new A2AClient(httpClient); + var cardResolver = new A2ACardResolver(httpClient); + var agentCard = await cardResolver.GetAgentCardAsync(); + Console.WriteLine(JsonSerializer.Serialize(agentCard, s_jsonSerializerOptions)); + var agent = new A2AAgent(client, agentCard); + + // Invoke the A2A agent + await foreach (AgentResponseItem response in agent.InvokeAsync("List the latest invoices for Contoso?")) + { + this.WriteAgentChatMessage(response); + } + } + + [Fact] + public async Task UseA2AAgentStreaming() + { + // Create an A2A agent instance + using var httpClient = CreateHttpClient(); + var client = new A2AClient(httpClient); + var cardResolver = new A2ACardResolver(httpClient); + var agentCard = await cardResolver.GetAgentCardAsync(); + Console.WriteLine(JsonSerializer.Serialize(agentCard, s_jsonSerializerOptions)); + var agent = new A2AAgent(client, agentCard); + + // Invoke the A2A agent + var responseItems = agent.InvokeStreamingAsync("List the latest invoices for Contoso?"); + await WriteAgentStreamMessageAsync(responseItems); + } + + #region private + private bool EnableLogging { get; set; } = false; + + private HttpClient CreateHttpClient() + { + if (this.EnableLogging) + { + var handler = new LoggingHandler(new HttpClientHandler(), this.Output); + return new HttpClient(handler) + { + BaseAddress = TestConfiguration.A2A.AgentUrl + }; + } + + return new HttpClient() + { + BaseAddress = TestConfiguration.A2A.AgentUrl + }; + } + + private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() { WriteIndented = true }; + #endregion +} diff --git a/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj b/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj index a4bdf74f3505..0abbddeba9de 100644 --- a/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj +++ b/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj @@ -9,7 +9,7 @@ true - $(NoWarn);NU1008;CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101,SKEXP0110,OPENAI001 + $(NoWarn);NU1008;CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0101,SKEXP0110,OPENAI001 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 @@ -23,11 +23,11 @@ - + - - + + @@ -43,6 +43,7 @@ + diff --git a/dotnet/samples/GettingStartedWithAgents/OpenAIResponse/Step03_OpenAIResponseAgent_ReasoningModel.cs b/dotnet/samples/GettingStartedWithAgents/OpenAIResponse/Step03_OpenAIResponseAgent_ReasoningModel.cs new file mode 100644 index 000000000000..a88059222efb --- /dev/null +++ b/dotnet/samples/GettingStartedWithAgents/OpenAIResponse/Step03_OpenAIResponseAgent_ReasoningModel.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft. All rights reserved. +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents.OpenAI; +using OpenAI.Responses; + +namespace GettingStarted.OpenAIResponseAgents; + +/// +/// This example demonstrates using . +/// +public class Step03_OpenAIResponseAgent_ReasoningModel(ITestOutputHelper output) : BaseResponsesAgentTest(output, "o4-mini") +{ + [Fact] + public async Task UseOpenAIResponseAgentWithAReasoningModelAsync() + { + // Define the agent + OpenAIResponseAgent agent = new(this.Client) + { + Name = "ResponseAgent", + Instructions = "Answer all queries with a detailed response.", + }; + + // Invoke the agent and output the response + var responseItems = agent.InvokeAsync("Which of the last four Olympic host cities has the highest average temperature?"); + await foreach (ChatMessageContent responseItem in responseItems) + { + WriteAgentChatMessage(responseItem); + } + } + + [Fact] + public async Task UseOpenAIResponseAgentWithAReasoningModelAndSummariesAsync() + { + // Define the agent + OpenAIResponseAgent agent = new(this.Client); + + // ResponseCreationOptions allows you to specify tools for the agent. + OpenAIResponseAgentInvokeOptions invokeOptions = new() + { + ResponseCreationOptions = new() + { + ReasoningOptions = new() + { + ReasoningEffortLevel = ResponseReasoningEffortLevel.High, + // This parameter cannot be used due to a known issue in the OpenAI .NET SDK. + // https://github.com/openai/openai-dotnet/issues/457 + // ReasoningSummaryVerbosity = ResponseReasoningSummaryVerbosity.Detailed, + }, + }, + }; + + // Invoke the agent and output the response + var responseItems = agent.InvokeAsync( + """ + Instructions: + - Given the React component below, change it so that nonfiction books have red + text. + - Return only the code in your reply + - Do not include any additional formatting, such as markdown code blocks + - For formatting, use four space tabs, and do not allow any lines of code to + exceed 80 columns + const books = [ + { title: 'Dune', category: 'fiction', id: 1 }, + { title: 'Frankenstein', category: 'fiction', id: 2 }, + { title: 'Moneyball', category: 'nonfiction', id: 3 }, + ]; + export default function BookList() { + const listItems = books.map(book => +
  • + {book.title} +
  • + ); + return ( +
      {listItems}
    + ); + } + """, options: invokeOptions); + await foreach (ChatMessageContent responseItem in responseItems) + { + WriteAgentChatMessage(responseItem); + } + } +} diff --git a/dotnet/samples/GettingStartedWithAgents/OpenAIResponse/Step04_OpenAIResponseAgent_Tools.cs b/dotnet/samples/GettingStartedWithAgents/OpenAIResponse/Step04_OpenAIResponseAgent_Tools.cs new file mode 100644 index 000000000000..e3d3b6b46f8c --- /dev/null +++ b/dotnet/samples/GettingStartedWithAgents/OpenAIResponse/Step04_OpenAIResponseAgent_Tools.cs @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.ClientModel.Primitives; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents.OpenAI; +using Microsoft.SemanticKernel.ChatCompletion; +using OpenAI.Files; +using OpenAI.Responses; +using OpenAI.VectorStores; +using Plugins; +using Resources; + +namespace GettingStarted.OpenAIResponseAgents; + +/// +/// This example demonstrates how to use tools during a model interaction using . +/// +public class Step04_OpenAIResponseAgent_Tools(ITestOutputHelper output) : BaseResponsesAgentTest(output) +{ + [Fact] + public async Task InvokeAgentWithFunctionToolsAsync() + { + // Define the agent + OpenAIResponseAgent agent = new(this.Client) + { + StoreEnabled = false, + }; + + // Create a plugin that defines the tools to be used by the agent. + KernelPlugin plugin = KernelPluginFactory.CreateFromType(); + var tools = plugin.Select(f => f.ToToolDefinition(plugin.Name)); + agent.Kernel.Plugins.Add(plugin); + + ICollection messages = + [ + new ChatMessageContent(AuthorRole.User, "What is the special soup and its price?"), + new ChatMessageContent(AuthorRole.User, "What is the special drink and its price?"), + ]; + foreach (ChatMessageContent message in messages) + { + WriteAgentChatMessage(message); + } + + // Invoke the agent and output the response + var responseItems = agent.InvokeAsync(messages); + await foreach (ChatMessageContent responseItem in responseItems) + { + WriteAgentChatMessage(responseItem); + } + } + + [Fact] + public async Task InvokeAgentWithWebSearchAsync() + { + // Define the agent + OpenAIResponseAgent agent = new(this.Client) + { + StoreEnabled = false, + }; + + // ResponseCreationOptions allows you to specify tools for the agent. + ResponseCreationOptions creationOptions = new(); + creationOptions.Tools.Add(ResponseTool.CreateWebSearchTool()); + OpenAIResponseAgentInvokeOptions invokeOptions = new() + { + ResponseCreationOptions = creationOptions, + }; + + // Invoke the agent and output the response + var responseItems = agent.InvokeAsync("What was a positive news story from today?", options: invokeOptions); + await foreach (ChatMessageContent responseItem in responseItems) + { + WriteAgentChatMessage(responseItem); + } + } + + [Fact] + public async Task InvokeAgentWithFileSearchAsync() + { + // Upload a file to the OpenAI File API + await using Stream stream = EmbeddedResource.ReadStream("employees.pdf")!; + OpenAIFile file = await this.FileClient.UploadFileAsync(stream, filename: "employees.pdf", purpose: FileUploadPurpose.UserData); + + // Create a vector store for the file + CreateVectorStoreOperation createStoreOp = await this.VectorStoreClient.CreateVectorStoreAsync( + waitUntilCompleted: true, + new VectorStoreCreationOptions() + { + FileIds = { file.Id }, + }); + + // Define the agent + OpenAIResponseAgent agent = new(this.Client) + { + StoreEnabled = false, + }; + + // ResponseCreationOptions allows you to specify tools for the agent. + ResponseCreationOptions creationOptions = new(); + creationOptions.Tools.Add(ResponseTool.CreateFileSearchTool([createStoreOp.VectorStoreId], null)); + OpenAIResponseAgentInvokeOptions invokeOptions = new() + { + ResponseCreationOptions = creationOptions, + }; + + // Invoke the agent and output the response + ICollection messages = + [ + new ChatMessageContent(AuthorRole.User, "Who is the youngest employee?"), + new ChatMessageContent(AuthorRole.User, "Who works in sales?"), + new ChatMessageContent(AuthorRole.User, "I have a customer request, who can help me?"), + ]; + foreach (ChatMessageContent message in messages) + { + WriteAgentChatMessage(message); + } + + // Invoke the agent and output the response + var responseItems = agent.InvokeAsync(messages, options: invokeOptions); + await foreach (ChatMessageContent responseItem in responseItems) + { + WriteAgentChatMessage(responseItem); + } + + // Clean up resources + RequestOptions noThrowOptions = new() { ErrorOptions = ClientErrorBehaviors.NoThrow }; + this.FileClient.DeleteFile(file.Id, noThrowOptions); + this.VectorStoreClient.DeleteVectorStore(createStoreOp.VectorStoreId, noThrowOptions); + } +} diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step01_Concurrent.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step01_Concurrent.cs index 9709d37580db..4bc3987459d0 100644 --- a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step01_Concurrent.cs +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step01_Concurrent.cs @@ -21,12 +21,12 @@ public async Task ConcurrentTaskAsync(bool streamedResponse) { // Define the agents ChatCompletionAgent physicist = - this.CreateAgent( + this.CreateChatCompletionAgent( instructions: "You are an expert in physics. You answer questions from a physics perspective.", name: "Physicist", description: "An expert in physics"); ChatCompletionAgent chemist = - this.CreateAgent( + this.CreateChatCompletionAgent( instructions: "You are an expert in chemistry. You answer questions from a chemistry perspective.", name: "Chemist", description: "An expert in chemistry"); diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step01a_ConcurrentWithStructuredOutput.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step01a_ConcurrentWithStructuredOutput.cs index bcb08bb8a7ff..da8d4a0b5938 100644 --- a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step01a_ConcurrentWithStructuredOutput.cs +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step01a_ConcurrentWithStructuredOutput.cs @@ -25,15 +25,15 @@ public async Task ConcurrentStructuredOutputAsync() { // Define the agents ChatCompletionAgent agent1 = - this.CreateAgent( + this.CreateChatCompletionAgent( instructions: "You are an expert in identifying themes in articles. Given an article, identify the main themes.", description: "An expert in identifying themes in articles"); ChatCompletionAgent agent2 = - this.CreateAgent( + this.CreateChatCompletionAgent( instructions: "You are an expert in sentiment analysis. Given an article, identify the sentiment.", description: "An expert in sentiment analysis"); ChatCompletionAgent agent3 = - this.CreateAgent( + this.CreateChatCompletionAgent( instructions: "You are an expert in entity recognition. Given an article, extract the entities.", description: "An expert in entity recognition"); diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step02_Sequential.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step02_Sequential.cs index 5bf6f55eeff1..6c8501113708 100644 --- a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step02_Sequential.cs +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step02_Sequential.cs @@ -22,7 +22,7 @@ public async Task SequentialTaskAsync(bool streamedResponse) { // Define the agents ChatCompletionAgent analystAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "Analyst", instructions: """ @@ -33,7 +33,7 @@ public async Task SequentialTaskAsync(bool streamedResponse) """, description: "A agent that extracts key concepts from a product description."); ChatCompletionAgent writerAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "copywriter", instructions: """ @@ -43,7 +43,7 @@ Output should be short (around 150 words), output just the copy as a single text """, description: "An agent that writes a marketing copy based on the extracted concepts."); ChatCompletionAgent editorAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "editor", instructions: """ diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step02a_SequentialCancellation.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step02a_SequentialCancellation.cs index 0c55ae7e4299..c67be1678877 100644 --- a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step02a_SequentialCancellation.cs +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step02a_SequentialCancellation.cs @@ -17,7 +17,7 @@ public async Task SequentialCancelledAsync() { // Define the agents ChatCompletionAgent agent = - this.CreateAgent( + this.CreateChatCompletionAgent( """ If the input message is a number, return the number incremented by one. """, diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03_GroupChat.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03_GroupChat.cs index 775f1a73d289..712eab38d175 100644 --- a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03_GroupChat.cs +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03_GroupChat.cs @@ -27,7 +27,7 @@ public async Task GroupChatAsync(bool streamedResponse) { // Define the agents ChatCompletionAgent writer = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "CopyWriter", description: "A copy writer", instructions: @@ -40,7 +40,7 @@ Only provide a single proposal per response. Consider suggestions when refining an idea. """); ChatCompletionAgent editor = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "Reviewer", description: "An editor.", instructions: @@ -73,7 +73,7 @@ Consider suggestions when refining an idea. InProcessRuntime runtime = new(); await runtime.StartAsync(); - string input = "Create a slogon for a new eletric SUV that is affordable and fun to drive."; + string input = "Create a slogan for a new electric SUV that is affordable and fun to drive."; Console.WriteLine($"\n# INPUT: {input}\n"); OrchestrationResult result = await orchestration.InvokeAsync(input, runtime); string text = await result.GetValueAsync(TimeSpan.FromSeconds(ResultTimeoutInSeconds * 3)); diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03a_GroupChatWithHumanInTheLoop.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03a_GroupChatWithHumanInTheLoop.cs index f7c815e37323..ec40e603219c 100644 --- a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03a_GroupChatWithHumanInTheLoop.cs +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03a_GroupChatWithHumanInTheLoop.cs @@ -19,7 +19,7 @@ public async Task GroupChatWithHumanAsync() { // Define the agents ChatCompletionAgent writer = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "CopyWriter", description: "A copy writer", instructions: @@ -32,7 +32,7 @@ Only provide a single proposal per response. Consider suggestions when refining an idea. """); ChatCompletionAgent editor = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "Reviewer", description: "An editor.", instructions: @@ -73,7 +73,7 @@ Consider suggestions when refining an idea. await runtime.StartAsync(); // Run the orchestration - string input = "Create a slogon for a new eletric SUV that is affordable and fun to drive."; + string input = "Create a slogan for a new electric SUV that is affordable and fun to drive."; Console.WriteLine($"\n# INPUT: {input}\n"); OrchestrationResult result = await orchestration.InvokeAsync(input, runtime); string text = await result.GetValueAsync(TimeSpan.FromSeconds(ResultTimeoutInSeconds * 3)); diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03b_GroupChatWithAIManager.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03b_GroupChatWithAIManager.cs index 6586242fc9c7..950b6f8303f7 100644 --- a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03b_GroupChatWithAIManager.cs +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step03b_GroupChatWithAIManager.cs @@ -23,7 +23,7 @@ public async Task GroupChatWithAIManagerAsync() { // Define the agents ChatCompletionAgent farmer = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "Farmer", description: "A rural farmer from Southeast Asia.", instructions: @@ -34,7 +34,7 @@ You value tradition and sustainability. You are in a debate. Feel free to challenge the other participants with respect. """); ChatCompletionAgent developer = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "Developer", description: "An urban software developer from the United States.", instructions: @@ -45,7 +45,7 @@ Your life is fast-paced and technology-driven. You are in a debate. Feel free to challenge the other participants with respect. """); ChatCompletionAgent teacher = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "Teacher", description: "A retired history teacher from Eastern Europe", instructions: @@ -56,7 +56,7 @@ You bring historical and philosophical perspectives to discussions. You are in a debate. Feel free to challenge the other participants with respect. """); ChatCompletionAgent activist = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "Activist", description: "A young activist from South America.", instructions: @@ -66,7 +66,7 @@ You are in a debate. Feel free to challenge the other participants with respect. You are in a debate. Feel free to challenge the other participants with respect. """); ChatCompletionAgent spiritual = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "SpiritualLeader", description: "A spiritual leader from the Middle East.", instructions: @@ -76,7 +76,7 @@ You are in a debate. Feel free to challenge the other participants with respect. You are in a debate. Feel free to challenge the other participants with respect. """); ChatCompletionAgent artist = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "Artist", description: "An artist from Africa.", instructions: @@ -86,7 +86,7 @@ You are in a debate. Feel free to challenge the other participants with respect. You are in a debate. Feel free to challenge the other participants with respect. """); ChatCompletionAgent immigrant = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "Immigrant", description: "An immigrant entrepreneur from Asia living in Canada.", instructions: @@ -97,7 +97,7 @@ You balance trandition with adaption. You are in a debate. Feel free to challenge the other participants with respect. """); ChatCompletionAgent doctor = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "Doctor", description: "A doctor from Scandinavia.", instructions: diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step04_Handoff.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step04_Handoff.cs index d8ad1791b425..7f2a00fb9abd 100644 --- a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step04_Handoff.cs +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step04_Handoff.cs @@ -23,24 +23,24 @@ public async Task OrderSupportAsync(bool streamedResponse) { // Define the agents & tools ChatCompletionAgent triageAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( instructions: "A customer support agent that triages issues.", name: "TriageAgent", description: "Handle customer requests."); ChatCompletionAgent statusAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "OrderStatusAgent", instructions: "Handle order status requests.", description: "A customer support agent that checks order status."); statusAgent.Kernel.Plugins.Add(KernelPluginFactory.CreateFromObject(new OrderStatusPlugin())); ChatCompletionAgent returnAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "OrderReturnAgent", instructions: "Handle order return requests.", description: "A customer support agent that handles order returns."); returnAgent.Kernel.Plugins.Add(KernelPluginFactory.CreateFromObject(new OrderReturnPlugin())); ChatCompletionAgent refundAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "OrderRefundAgent", instructions: "Handle order refund requests.", description: "A customer support agent that handles order refund."); diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step04a_HandoffWithStructuredInput.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step04a_HandoffWithStructuredInput.cs index 9b7bce047f7c..9c56f68cd583 100644 --- a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step04a_HandoffWithStructuredInput.cs +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step04a_HandoffWithStructuredInput.cs @@ -23,18 +23,18 @@ public async Task HandoffStructuredInputAsync() // Define the agents ChatCompletionAgent triageAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( instructions: "Given a GitHub issue, triage it.", name: "TriageAgent", description: "An agent that triages GitHub issues"); ChatCompletionAgent pythonAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( instructions: "You are an agent that handles Python related GitHub issues.", name: "PythonAgent", description: "An agent that handles Python related issues"); pythonAgent.Kernel.Plugins.Add(plugin); ChatCompletionAgent dotnetAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( instructions: "You are an agent that handles .NET related GitHub issues.", name: "DotNetAgent", description: "An agent that handles .NET related issues"); diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step05_Magentic.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step05_Magentic.cs index 98ab4bf21996..c2520a3d5a94 100644 --- a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step05_Magentic.cs +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step05_Magentic.cs @@ -36,7 +36,7 @@ public async Task MagenticTaskAsync(bool streamedResponse) // Define the agents Kernel researchKernel = CreateKernelWithOpenAIChatCompletion(ResearcherModel); ChatCompletionAgent researchAgent = - this.CreateAgent( + this.CreateChatCompletionAgent( name: "ResearchAgent", description: "A helpful assistant with access to web search. Ask it to perform web searches.", instructions: "You are a Researcher. You find information without additional computation or quantitative analysis.", @@ -86,7 +86,7 @@ I am preparing a report on the energy efficiency of different machine learning m """; Console.WriteLine($"\n# INPUT:\n{input}\n"); OrchestrationResult result = await orchestration.InvokeAsync(input, runtime); - string text = await result.GetValueAsync(TimeSpan.FromSeconds(ResultTimeoutInSeconds * 10)); + string text = await result.GetValueAsync(TimeSpan.FromSeconds(ResultTimeoutInSeconds * 20)); Console.WriteLine($"\n# RESULT: {text}"); await runtime.RunUntilIdleAsync(); diff --git a/dotnet/samples/GettingStartedWithAgents/Orchestration/Step06_DifferentAgentTypes.cs b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step06_DifferentAgentTypes.cs new file mode 100644 index 000000000000..99671aec5a85 --- /dev/null +++ b/dotnet/samples/GettingStartedWithAgents/Orchestration/Step06_DifferentAgentTypes.cs @@ -0,0 +1,307 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.Magentic; +using Microsoft.SemanticKernel.Agents.Orchestration; +using Microsoft.SemanticKernel.Agents.Orchestration.Concurrent; +using Microsoft.SemanticKernel.Agents.Orchestration.GroupChat; +using Microsoft.SemanticKernel.Agents.Orchestration.Handoff; +using Microsoft.SemanticKernel.Agents.Orchestration.Sequential; +using Microsoft.SemanticKernel.Agents.Runtime.InProcess; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace GettingStarted.Orchestration; + +/// +/// Demonstrates how to use the with two agents: +/// - A Research agent that can perform web searches +/// - A Coder agent that can run code using the code interpreter +/// +public class Step06_DifferentAgentTypes(ITestOutputHelper output) : BaseOrchestrationTest(output) +{ + [Fact] + public async Task ConcurrentOrchestrationAsync() + { + // Define the agents + Agent physicist = + this.CreateChatCompletionAgent( + instructions: "You are an expert in physics. You answer questions from a physics perspective.", + name: "Physicist", + description: "An expert in physics"); + Agent chemist = + await this.CreateAzureAIAgentAsync( + instructions: "You are an expert in chemistry. You answer questions from a chemistry perspective.", + name: "Chemist", + description: "An expert in chemistry"); + + // Create a monitor to capturing agent responses (via ResponseCallback) + // to display at the end of this sample. (optional) + // NOTE: Create your own callback to capture responses in your application or service. + OrchestrationMonitor monitor = new(); + + // Define the orchestration + ConcurrentOrchestration orchestration = + new(physicist, chemist) + { + LoggerFactory = this.LoggerFactory, + ResponseCallback = monitor.ResponseCallback, + }; + + // Start the runtime + InProcessRuntime runtime = new(); + await runtime.StartAsync(); + + // Run the orchestration + string input = "What is temperature?"; + Console.WriteLine($"\n# INPUT: {input}\n"); + OrchestrationResult result = await orchestration.InvokeAsync(input, runtime); + + string[] output = await result.GetValueAsync(TimeSpan.FromSeconds(ResultTimeoutInSeconds)); + Console.WriteLine($"\n# RESULT:\n{string.Join("\n\n", output.Select(text => $"{text}"))}"); + + await runtime.RunUntilIdleAsync(); + + Console.WriteLine("\n\nORCHESTRATION HISTORY"); + foreach (ChatMessageContent message in monitor.History) + { + this.WriteAgentChatMessage(message); + } + } + + [Fact] + public async Task SequentialOrchestrationAsync() + { + // Define the agents + Agent analystAgent = + this.CreateChatCompletionAgent( + name: "Analyst", + instructions: + """ + You are a marketing analyst. Given a product description, identify: + - Key features + - Target audience + - Unique selling points + """, + description: "A agent that extracts key concepts from a product description."); + Agent writerAgent = + await this.CreateOpenAIAssistantAgentAsync( + name: "copywriter", + instructions: + """ + You are a marketing copywriter. Given a block of text describing features, audience, and USPs, + compose a compelling marketing copy (like a newsletter section) that highlights these points. + Output should be short (around 150 words), output just the copy as a single text block. + """, + description: "An agent that writes a marketing copy based on the extracted concepts."); + Agent editorAgent = + await this.CreateAzureAIAgentAsync( + name: "editor", + instructions: + """ + You are an editor. Given the draft copy, correct grammar, improve clarity, ensure consistent tone, + give format and make it polished. Output the final improved copy as a single text block. + """, + description: "An agent that formats and proofreads the marketing copy."); + + // Create a monitor to capturing agent responses (via ResponseCallback) + // to display at the end of this sample. (optional) + // NOTE: Create your own callback to capture responses in your application or service. + OrchestrationMonitor monitor = new(); + // Define the orchestration + SequentialOrchestration orchestration = + new(analystAgent, writerAgent, editorAgent) + { + LoggerFactory = this.LoggerFactory, + ResponseCallback = monitor.ResponseCallback, + }; + + // Start the runtime + InProcessRuntime runtime = new(); + await runtime.StartAsync(); + + // Run the orchestration + string input = "An eco-friendly stainless steel water bottle that keeps drinks cold for 24 hours"; + Console.WriteLine($"\n# INPUT: {input}\n"); + OrchestrationResult result = await orchestration.InvokeAsync(input, runtime); + string text = await result.GetValueAsync(TimeSpan.FromSeconds(ResultTimeoutInSeconds * 2)); + Console.WriteLine($"\n# RESULT: {text}"); + + await runtime.RunUntilIdleAsync(); + + Console.WriteLine("\n\nORCHESTRATION HISTORY"); + foreach (ChatMessageContent message in monitor.History) + { + this.WriteAgentChatMessage(message); + } + } + + [Fact] + public async Task GroupChatOrchestrationAsync() + { + // Define the agents + Agent writer = + this.CreateChatCompletionAgent( + name: "CopyWriter", + description: "A copy writer", + instructions: + """ + You are a copywriter with ten years of experience and are known for brevity and a dry humor. + The goal is to refine and decide on the single best copy as an expert in the field. + Only provide a single proposal per response. + You're laser focused on the goal at hand. + Don't waste time with chit chat. + Consider suggestions when refining an idea. + """); + Agent editor = + await this.CreateOpenAIAssistantAgentAsync( + name: "Reviewer", + description: "An editor.", + instructions: + """ + You are an art director who has opinions about copywriting born of a love for David Ogilvy. + The goal is to determine if the given copy is acceptable to print. + If so, state that it is approved. + If not, provide insight on how to refine suggested copy without example. + """); + + // Create a monitor to capturing agent responses (via ResponseCallback) + // to display at the end of this sample. (optional) + // NOTE: Create your own callback to capture responses in your application or service. + OrchestrationMonitor monitor = new(); + // Define the orchestration + GroupChatOrchestration orchestration = + new(new RoundRobinGroupChatManager() + { + MaximumInvocationCount = 5 + }, + writer, + editor) + { + LoggerFactory = this.LoggerFactory, + ResponseCallback = monitor.ResponseCallback, + }; + + // Start the runtime + InProcessRuntime runtime = new(); + await runtime.StartAsync(); + + string input = "Create a slogan for a new electric SUV that is affordable and fun to drive."; + Console.WriteLine($"\n# INPUT: {input}\n"); + OrchestrationResult result = await orchestration.InvokeAsync(input, runtime); + string text = await result.GetValueAsync(TimeSpan.FromSeconds(ResultTimeoutInSeconds * 3)); + Console.WriteLine($"\n# RESULT: {text}"); + + await runtime.RunUntilIdleAsync(); + + Console.WriteLine("\n\nORCHESTRATION HISTORY"); + foreach (ChatMessageContent message in monitor.History) + { + this.WriteAgentChatMessage(message); + } + } + + [Fact] + public async Task HandoffOrchestrationAsync() + { + // Define the agents & tools + Agent triageAgent = + this.CreateChatCompletionAgent( + instructions: "A customer support agent that triages issues.", + name: "TriageAgent", + description: "Handle customer requests."); + Agent statusAgent = + this.CreateChatCompletionAgent( + name: "OrderStatusAgent", + instructions: "Handle order status requests.", + description: "A customer support agent that checks order status."); + statusAgent.Kernel.Plugins.Add(KernelPluginFactory.CreateFromObject(new OrderStatusPlugin())); + Agent returnAgent = + this.CreateChatCompletionAgent( + name: "OrderReturnAgent", + instructions: "Handle order return requests.", + description: "A customer support agent that handles order returns."); + returnAgent.Kernel.Plugins.Add(KernelPluginFactory.CreateFromObject(new OrderReturnPlugin())); + Agent refundAgent = + this.CreateChatCompletionAgent( + name: "OrderRefundAgent", + instructions: "Handle order refund requests.", + description: "A customer support agent that handles order refund."); + refundAgent.Kernel.Plugins.Add(KernelPluginFactory.CreateFromObject(new OrderRefundPlugin())); + + // Create a monitor to capturing agent responses (via ResponseCallback) + // to display at the end of this sample. (optional) + // NOTE: Create your own callback to capture responses in your application or service. + OrchestrationMonitor monitor = new(); + // Define user responses for InteractiveCallback (since sample is not interactive) + Queue responses = new(); + string task = "I am a customer that needs help with my orders"; + responses.Enqueue("I'd like to track the status of my order"); + responses.Enqueue("My order ID is 123"); + responses.Enqueue("I want to return another order of mine"); + responses.Enqueue("Order ID 321"); + responses.Enqueue("Broken item"); + responses.Enqueue("No, bye"); + // Define the orchestration + HandoffOrchestration orchestration = + new(OrchestrationHandoffs + .StartWith(triageAgent) + .Add(triageAgent, statusAgent, returnAgent, refundAgent) + .Add(statusAgent, triageAgent, "Transfer to this agent if the issue is not status related") + .Add(returnAgent, triageAgent, "Transfer to this agent if the issue is not return related") + .Add(refundAgent, triageAgent, "Transfer to this agent if the issue is not refund related"), + triageAgent, + statusAgent, + returnAgent, + refundAgent) + { + InteractiveCallback = () => + { + string input = responses.Dequeue(); + Console.WriteLine($"\n# INPUT: {input}\n"); + return ValueTask.FromResult(new ChatMessageContent(AuthorRole.User, input)); + }, + LoggerFactory = this.LoggerFactory, + ResponseCallback = monitor.ResponseCallback, + }; + + // Start the runtime + InProcessRuntime runtime = new(); + await runtime.StartAsync(); + + // Run the orchestration + Console.WriteLine($"\n# INPUT:\n{task}\n"); + OrchestrationResult result = await orchestration.InvokeAsync(task, runtime); + + string text = await result.GetValueAsync(TimeSpan.FromSeconds(ResultTimeoutInSeconds * 10)); + Console.WriteLine($"\n# RESULT: {text}"); + + await runtime.RunUntilIdleAsync(); + + Console.WriteLine("\n\nORCHESTRATION HISTORY"); + foreach (ChatMessageContent message in monitor.History) + { + this.WriteAgentChatMessage(message); + } + } + + #region private + private sealed class OrderStatusPlugin + { + [KernelFunction] + public string CheckOrderStatus(string orderId) => $"Order {orderId} is shipped and will arrive in 2-3 days."; + } + + private sealed class OrderReturnPlugin + { + [KernelFunction] + public string ProcessReturn(string orderId, string reason) => $"Return for order {orderId} has been processed successfully."; + } + + private sealed class OrderRefundPlugin + { + [KernelFunction] + public string ProcessReturn(string orderId, string reason) => $"Refund for order {orderId} has been processed successfully."; + } + #endregion +} diff --git a/dotnet/samples/GettingStartedWithAgents/Step01_Agent.cs b/dotnet/samples/GettingStartedWithAgents/Step01_Agent.cs index e62b869c0af1..80f794e28a3d 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step01_Agent.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step01_Agent.cs @@ -27,7 +27,7 @@ public class Step01_Agent(ITestOutputHelper output) : BaseAgentsTest(output) [InlineData(false)] public async Task UseSingleChatCompletionAgent(bool useChatClient) { - Kernel kernel = this.CreateKernelWithChatCompletion(); + Kernel kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient); // Define the agent ChatCompletionAgent agent = @@ -35,7 +35,7 @@ public async Task UseSingleChatCompletionAgent(bool useChatClient) { Name = ParrotName, Instructions = ParrotInstructions, - Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), + Kernel = kernel }; // Respond to user input diff --git a/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj b/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj index 5e92c4906685..7244d2cb967b 100644 --- a/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj +++ b/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj @@ -10,7 +10,7 @@ - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0081,SKEXP0101,SKEXP0110,OPENAI001 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0080,SKEXP0081,SKEXP0101,SKEXP0110,OPENAI001 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/GettingStartedWithProcesses/Step06/Step06_FoundryAgentProcess.cs b/dotnet/samples/GettingStartedWithProcesses/Step06/Step06_FoundryAgentProcess.cs deleted file mode 100644 index 7465fe5610c9..000000000000 --- a/dotnet/samples/GettingStartedWithProcesses/Step06/Step06_FoundryAgentProcess.cs +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.ClientModel; -using System.Text; -using Azure.AI.Agents.Persistent; -using Azure.Identity; -using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Agents.AzureAI; -using Microsoft.SemanticKernel.Agents.OpenAI; -using OpenAI; - -namespace Step06; -public class Step06_FoundryAgentProcess : BaseTest -{ - public Step06_FoundryAgentProcess(ITestOutputHelper output) : base(output, redirectSystemConsoleOutput: true) - { - this.Client = - this.UseOpenAIConfig ? - OpenAIAssistantAgent.CreateOpenAIClient(new ApiKeyCredential(this.ApiKey ?? throw new ConfigurationNotFoundException("OpenAI:ApiKey"))) : - !string.IsNullOrWhiteSpace(this.ApiKey) ? - OpenAIAssistantAgent.CreateAzureOpenAIClient(new ApiKeyCredential(this.ApiKey), new Uri(this.Endpoint!)) : - OpenAIAssistantAgent.CreateAzureOpenAIClient(new AzureCliCredential(), new Uri(this.Endpoint!)); - } - - protected OpenAIClient Client { get; init; } - - // Target Open AI Services - protected override bool ForceOpenAI => true; - - /// - /// This example demonstrates how to create a process with two agents that can chat with each other. A student agent and a teacher agent are created. - /// The process will keep track of the interaction count between the two agents. - /// - [Fact] - public async Task ProcessWithTwoAgentMathChat() - { - var endpoint = TestConfiguration.AzureAI.Endpoint; - PersistentAgentsClient client = new(endpoint.TrimEnd('/'), new DefaultAzureCredential(), new PersistentAgentsAdministrationClientOptions().WithPolicy(endpoint, "2025-05-15-preview")); - - Azure.Response? studentAgent = null; - Azure.Response? teacherAgent = null; - - try - { - // Create the single agents - studentAgent = await client.Administration.CreateAgentAsync( - model: "gpt-4o", - name: "Student", - instructions: "You are a student that answer question from teacher, when teacher gives you question you answer them." - ); - - teacherAgent = await client.Administration.CreateAgentAsync( - model: "gpt-4o", - name: "Teacher", - instructions: "You are a teacher that create pre-school math question for student and check answer.\nIf the answer is correct, you stop the conversation by saying [COMPLETE].\nIf the answer is wrong, you ask student to fix it." - ); - - // Define the process with a state type - var processBuilder = new FoundryProcessBuilder("two_agent_math_chat"); - - // Create a thread for the student - processBuilder.AddThread("Student", KernelProcessThreadLifetime.Scoped); - processBuilder.AddThread("Teacher", KernelProcessThreadLifetime.Scoped); - - // Add the student - var student = processBuilder.AddStepFromAgent(studentAgent); - - // Add the teacher - var teacher = processBuilder.AddStepFromAgent(teacherAgent); - - /**************************** Orchestrate ***************************/ - - // When the process starts, activate the student agent - processBuilder.OnProcessEnter().SendEventTo( - student, - thread: "_variables_.Student", - messagesIn: ["_variables_.TeacherMessages"], - inputs: new Dictionary { }); - - // When the student agent exits, update the process state to save the student's messages and update interaction counts - processBuilder.OnStepExit(student) - .UpdateProcessState(path: "StudentMessages", operation: StateUpdateOperations.Set, value: "_agent_.messages_out"); - - // When the student agent is finished, send the messages to the teacher agent - processBuilder.OnEvent(student, "_default_") - .SendEventTo(teacher, messagesIn: ["_variables_.StudentMessages"], thread: "Teacher"); - - // When the teacher agent exits with a message containing '[COMPLETE]', update the process state to save the teacher's messages and update interaction counts and emit the `correct_answer` event - processBuilder.OnStepExit(teacher, condition: "jmespath(contains(to_string(_agent_.messages_out), '[COMPLETE]'))") - .EmitEvent( - eventName: "correct_answer", - payload: new Dictionary - { - { "Question", "_variables_.TeacherMessages" }, - { "Answer", "_variables_.StudentMessages" } - }) - .UpdateProcessState(path: "_variables_.TeacherMessages", operation: StateUpdateOperations.Set, value: "_agent_.messages_out"); - - // When the teacher agent exits with a message not containing '[COMPLETE]', update the process state to save the teacher's messages and update interaction counts - processBuilder.OnStepExit(teacher, condition: "_default_") - .UpdateProcessState(path: "_variables_.TeacherMessages", operation: StateUpdateOperations.Set, value: "_agent_.messages_out"); - - // When the teacher agent is finished, send the messages to the student agent - processBuilder.OnEvent(teacher, "_default_", condition: "_default_") - .SendEventTo(student, messagesIn: ["_variables_.TeacherMessages"], thread: "Student"); - - // When the teacher agent emits the `correct_answer` event, stop the process - processBuilder.OnEvent(teacher, "correct_answer") - .StopProcess(); - - // Verify that the process can be built and serialized to json - var processJson = await processBuilder.ToJsonAsync(); - Assert.NotEmpty(processJson); - - var content = await RunWorkflowAsync(client, processBuilder, [new(MessageRole.User, "Go")]); - Assert.NotEmpty(content); - } - finally - { - // Clean up the agents - await client.Administration.DeleteAgentAsync(studentAgent?.Value.Id); - await client.Administration.DeleteAgentAsync(teacherAgent?.Value.Id); - } - } - - private async Task RunWorkflowAsync(PersistentAgentsClient client, FoundryProcessBuilder processBuilder, List? initialMessages = null) where T : class, new() - { - Workflow? workflow = null; - StringBuilder output = new(); - - try - { - // publish the workflow - workflow = await client.Administration.Pipeline.PublishWorkflowAsync(processBuilder); - - // threadId is used to store the thread ID - PersistentAgentThread thread = await client.Threads.CreateThreadAsync(messages: initialMessages ?? []); - - // create run - await foreach (var run in client.Runs.CreateRunStreamingAsync(thread.Id, workflow.Id)) - { - if (run is Azure.AI.Agents.Persistent.MessageContentUpdate contentUpdate) - { - output.Append(contentUpdate.Text); - Console.Write(contentUpdate.Text); - } - else if (run is Azure.AI.Agents.Persistent.RunUpdate runUpdate) - { - if (runUpdate.UpdateKind == Azure.AI.Agents.Persistent.StreamingUpdateReason.RunInProgress && !runUpdate.Value.Id.StartsWith("wf_run", StringComparison.OrdinalIgnoreCase)) - { - Console.WriteLine(); - Console.Write($"{runUpdate.Value.Metadata["x-agent-name"]}> "); - } - } - } - - // delete thread, so we can start over - Console.WriteLine($"\nDeleting thread {thread?.Id}..."); - await client.Threads.DeleteThreadAsync(thread?.Id); - return output.ToString(); - } - finally - { - // // delete workflow - Console.WriteLine($"Deleting workflow {workflow?.Id}..."); - await client.Administration.Pipeline.DeleteWorkflowAsync(workflow!); - } - } - - /// - /// Represents the state of the two-agent math chat process. - /// - public class TwoAgentMathState - { - public List StudentMessages { get; set; } - - public List TeacherMessages { get; set; } - - public StudentState StudentState { get; set; } = new(); - - public int InteractionCount { get; set; } - } - - /// - /// Represents the state of the student agent. - /// - public class StudentState - { - public int InteractionCount { get; set; } - - public string Name { get; set; } - } -} diff --git a/dotnet/samples/GettingStartedWithTextSearch/GettingStartedWithTextSearch.csproj b/dotnet/samples/GettingStartedWithTextSearch/GettingStartedWithTextSearch.csproj index baa7fc9bee3e..17f545356906 100644 --- a/dotnet/samples/GettingStartedWithTextSearch/GettingStartedWithTextSearch.csproj +++ b/dotnet/samples/GettingStartedWithTextSearch/GettingStartedWithTextSearch.csproj @@ -7,7 +7,7 @@ true false - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0101 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 @@ -33,9 +33,7 @@ - -
    diff --git a/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj b/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj index b133579aed1a..19c4e8d63cf9 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj +++ b/dotnet/samples/GettingStartedWithVectorStores/GettingStartedWithVectorStores.csproj @@ -7,7 +7,7 @@ true false - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0101 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/src/Agents/A2A/A2AAgent.cs b/dotnet/src/Agents/A2A/A2AAgent.cs new file mode 100644 index 000000000000..c831a9a89f79 --- /dev/null +++ b/dotnet/src/Agents/A2A/A2AAgent.cs @@ -0,0 +1,247 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.ChatCompletion; +using SharpA2A.Core; + +namespace Microsoft.SemanticKernel.Agents.A2A; + +/// +/// Provides a specialized based on the A2A Protocol. +/// +public sealed class A2AAgent : Agent +{ + /// + /// Initializes a new instance of the class. + /// + /// instance to associate with the agent. + /// instance associated ith the agent. + public A2AAgent(A2AClient client, AgentCard agentCard) + { + Verify.NotNull(client); + Verify.NotNull(agentCard); + + this.Client = client; + this.AgentCard = agentCard; + this.Name = agentCard.Name; + this.Description = agentCard.Description; + } + + /// + /// The associated client. + /// + public A2AClient Client { get; } + + /// + /// The associated agent card. + /// + public AgentCard AgentCard { get; } + + /// + public override async IAsyncEnumerable> InvokeAsync(ICollection messages, AgentThread? thread = null, AgentInvokeOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(messages); + + var agentThread = await this.EnsureThreadExistsWithMessagesAsync( + messages, + thread, + () => new A2AAgentThread(this.Client), + cancellationToken).ConfigureAwait(false); + + // Invoke the agent. + var invokeResults = this.InternalInvokeAsync( + this.AgentCard.Name, + messages, + agentThread, + options ?? new AgentInvokeOptions(), + cancellationToken); + + // Notify the thread of new messages and return them to the caller. + await foreach (var result in invokeResults.ConfigureAwait(false)) + { + await this.NotifyThreadOfNewMessage(agentThread, result, cancellationToken).ConfigureAwait(false); + yield return new(result, agentThread); + } + } + + /// + public override async IAsyncEnumerable> InvokeStreamingAsync(ICollection messages, AgentThread? thread = null, AgentInvokeOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(messages); + + var agentThread = await this.EnsureThreadExistsWithMessagesAsync( + messages, + thread, + () => new A2AAgentThread(this.Client), + cancellationToken).ConfigureAwait(false); + + // Invoke the agent. + var chatMessages = new ChatHistory(); + var invokeResults = this.InternalInvokeStreamingAsync( + messages, + agentThread, + options ?? new AgentInvokeOptions(), + chatMessages, + cancellationToken); + + // Return the chunks to the caller. + await foreach (var result in invokeResults.ConfigureAwait(false)) + { + yield return new(result, agentThread); + } + + // Notify the thread of any new messages that were assembled from the streaming response. + foreach (var chatMessage in chatMessages) + { + await this.NotifyThreadOfNewMessage(agentThread, chatMessage, cancellationToken).ConfigureAwait(false); + + if (options?.OnIntermediateMessage is not null) + { + await options.OnIntermediateMessage(chatMessage).ConfigureAwait(false); + } + } + } + + /// + protected override Task CreateChannelAsync(CancellationToken cancellationToken) + { + throw new NotSupportedException($"{nameof(A2AAgent)} is not for use with {nameof(AgentChat)}."); + } + + /// + protected override IEnumerable GetChannelKeys() + { + throw new NotSupportedException($"{nameof(A2AAgent)} is not for use with {nameof(AgentChat)}."); + } + + /// + protected override Task RestoreChannelAsync(string channelState, CancellationToken cancellationToken) + { + throw new NotSupportedException($"{nameof(A2AAgent)} is not for use with {nameof(AgentChat)}."); + } + + #region private + private async IAsyncEnumerable> InternalInvokeAsync(string name, ICollection messages, A2AAgentThread thread, AgentInvokeOptions options, [EnumeratorCancellation] CancellationToken cancellationToken) + { + Verify.NotNull(messages); + + // Ensure all messages have the correct role. + if (!messages.All(m => m.Role == AuthorRole.User)) + { + throw new ArgumentException($"All messages must have the role {AuthorRole.User}.", nameof(messages)); + } + + // Send all messages to the remote agent in a single request. + await foreach (var result in this.InvokeAgentAsync(messages, thread, options, cancellationToken).ConfigureAwait(false)) + { + await this.NotifyThreadOfNewMessage(thread, result, cancellationToken).ConfigureAwait(false); + yield return new(result, thread); + } + } + + private async IAsyncEnumerable> InvokeAgentAsync(ICollection messages, A2AAgentThread thread, AgentInvokeOptions options, [EnumeratorCancellation] CancellationToken cancellationToken) + { + List parts = []; + foreach (var message in messages) + { + foreach (var item in message.Items) + { + if (item is TextContent textContent) + { + parts.Add(new TextPart + { + Text = textContent.Text ?? string.Empty, + }); + } + else + { + throw new NotSupportedException($"Unsupported content type: {item.GetType().Name}. Only TextContent are supported."); + } + } + } + + var messageSendParams = new MessageSendParams + { + Message = new Message + { + MessageId = Guid.NewGuid().ToString(), + Role = MessageRole.User, + Parts = parts, + } + }; + + A2AResponse response = await this.Client.SendMessageAsync(messageSendParams).ConfigureAwait(false); + if (response is AgentTask agentTask) + { + if (agentTask.Artifacts != null && agentTask.Artifacts.Count > 0) + { + foreach (var artifact in agentTask.Artifacts) + { + foreach (var part in artifact.Parts) + { + if (part is TextPart textPart) + { + yield return new AgentResponseItem(new ChatMessageContent(AuthorRole.Assistant, textPart.Text), thread); + } + } + } + Console.WriteLine(); + } + } + else if (response is Message messageResponse) + { + foreach (var part in messageResponse.Parts) + { + if (part is TextPart textPart) + { + yield return new AgentResponseItem( + new ChatMessageContent( + AuthorRole.Assistant, + textPart.Text), + thread); + } + } + } + else + { + throw new InvalidOperationException("Unexpected response type from A2A client."); + } + } + + private async IAsyncEnumerable> InternalInvokeStreamingAsync(ICollection messages, A2AAgentThread thread, AgentInvokeOptions options, ChatHistory chatMessages, [EnumeratorCancellation] CancellationToken cancellationToken) + { + Verify.NotNull(messages); + + // Ensure all messages have the correct role. + if (messages.Any(m => m.Role != AuthorRole.User)) + { + throw new ArgumentException($"All messages must have the role {AuthorRole.User}.", nameof(messages)); + } + + // Send all messages to the remote agent in a single request. + await foreach (var result in this.InvokeAgentAsync(messages, thread, options, cancellationToken).ConfigureAwait(false)) + { + await this.NotifyThreadOfNewMessage(thread, result, cancellationToken).ConfigureAwait(false); + yield return new(this.ToStreamingAgentResponseItem(result), thread); + } + } + + private AgentResponseItem ToStreamingAgentResponseItem(AgentResponseItem responseItem) + { + var messageContent = new StreamingChatMessageContent( + responseItem.Message.Role, + responseItem.Message.Content, + innerContent: responseItem.Message.InnerContent, + modelId: responseItem.Message.ModelId, + encoding: responseItem.Message.Encoding, + metadata: responseItem.Message.Metadata); + + return new AgentResponseItem(messageContent, responseItem.Thread); + } + #endregion +} diff --git a/dotnet/src/Agents/A2A/A2AAgentThread.cs b/dotnet/src/Agents/A2A/A2AAgentThread.cs new file mode 100644 index 000000000000..3c10a468fcdb --- /dev/null +++ b/dotnet/src/Agents/A2A/A2AAgentThread.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Threading; +using System.Threading.Tasks; +using SharpA2A.Core; + +namespace Microsoft.SemanticKernel.Agents.A2A; + +/// +/// Represents a conversation thread for an A2A agent. +/// +public sealed class A2AAgentThread : AgentThread +{ + /// + /// Initializes a new instance of the class that resumes an existing thread. + /// + /// The agents client to use for interacting with threads. + /// The ID of an existing thread to resume. + public A2AAgentThread(A2AClient client, string? id = null) + { + Verify.NotNull(client); + + this._client = client; + this.Id = id ?? Guid.NewGuid().ToString("N"); + } + + /// + protected override Task CreateInternalAsync(CancellationToken cancellationToken) + { + return Task.FromResult(Guid.NewGuid().ToString("N")); + } + + /// + protected override Task DeleteInternalAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + /// + protected override Task OnNewMessageInternalAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + #region private + private readonly A2AClient _client; + #endregion +} diff --git a/dotnet/src/Agents/A2A/A2AHostAgent.cs b/dotnet/src/Agents/A2A/A2AHostAgent.cs new file mode 100644 index 000000000000..da36fad6e725 --- /dev/null +++ b/dotnet/src/Agents/A2A/A2AHostAgent.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Threading.Tasks; +using SharpA2A.Core; + +namespace Microsoft.SemanticKernel.Agents.A2A; + +/// +/// Host which will attach a to a +/// +public sealed class A2AHostAgent +{ + /// + /// Initializes a new instance of the SemanticKernelTravelAgent + /// + public A2AHostAgent(Agent agent, AgentCard agentCard, TaskManager? taskManager = null) + { + Verify.NotNull(agent); + Verify.NotNull(agentCard); + + this.Agent = agent; + this._agentCard = agentCard; + + this.Attach(taskManager ?? new TaskManager()); + } + + /// + /// The associated + /// + public Agent? Agent { get; private set; } + + /// + /// The associated + /// + public TaskManager? TaskManager => this._taskManager; + + /// + /// Attach the to the provided + /// + /// + public void Attach(TaskManager taskManager) + { + Verify.NotNull(taskManager); + + this._taskManager = taskManager; + taskManager.OnTaskCreated = this.ExecuteAgentTaskAsync; + taskManager.OnTaskUpdated = this.ExecuteAgentTaskAsync; + taskManager.OnAgentCardQuery = this.GetAgentCard; + } + /// + /// Execute the specific + /// + /// + /// + /// + public async Task ExecuteAgentTaskAsync(AgentTask task) + { + Verify.NotNull(task); + Verify.NotNull(this.Agent); + + if (this._taskManager is null) + { + throw new InvalidOperationException("TaskManager must be attached before executing an agent task."); + } + + await this._taskManager.UpdateStatusAsync(task.Id, TaskState.Working).ConfigureAwait(false); + + // Get message from the user + var userMessage = task.History!.Last().Parts.First().AsTextPart().Text; + + // Get the response from the agent + var artifact = new Artifact(); + await foreach (AgentResponseItem response in this.Agent.InvokeAsync(userMessage).ConfigureAwait(false)) + { + var content = response.Message.Content; + artifact.Parts.Add(new TextPart() { Text = content! }); + } + + // Return as artifacts + await this._taskManager.ReturnArtifactAsync(task.Id, artifact).ConfigureAwait(false); + await this._taskManager.UpdateStatusAsync(task.Id, TaskState.Completed).ConfigureAwait(false); + } + + /// + /// Return the associated with this hosted agent. + /// + /// Current URL for the agent +#pragma warning disable CA1054 // URI-like parameters should not be strings + public AgentCard GetAgentCard(string agentUrl) + { + // Ensure the URL is in the correct format + Uri uri = new(agentUrl); + agentUrl = $"{uri.Scheme}://{uri.Host}:{uri.Port}/"; + + this._agentCard.Url = agentUrl; + return this._agentCard; + } +#pragma warning restore CA1054 // URI-like parameters should not be strings + + #region private + private readonly AgentCard _agentCard; + private TaskManager? _taskManager; + #endregion +} diff --git a/dotnet/src/Agents/A2A/Agents.A2A.csproj b/dotnet/src/Agents/A2A/Agents.A2A.csproj new file mode 100644 index 000000000000..9defa8d88691 --- /dev/null +++ b/dotnet/src/Agents/A2A/Agents.A2A.csproj @@ -0,0 +1,44 @@ + + + + + Microsoft.SemanticKernel.Agents.A2A + Microsoft.SemanticKernel.Agents.A2A + net8.0;netstandard2.0 + $(NoWarn);SKEXP0110 + false + alpha + + + + + + + Semantic Kernel Agents - A2A + Defines a concrete Agent based on the A2A Protocol. + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/dotnet/src/Agents/A2A/Extensions/AuthorRoleExtensions.cs b/dotnet/src/Agents/A2A/Extensions/AuthorRoleExtensions.cs new file mode 100644 index 000000000000..211a3cc9494d --- /dev/null +++ b/dotnet/src/Agents/A2A/Extensions/AuthorRoleExtensions.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.SemanticKernel.ChatCompletion; +using SharpA2A.Core; + +namespace Microsoft.SemanticKernel.Agents.A2A; + +/// +/// Extensions for converting between amd . +/// +internal static class AuthorRoleExtensions +{ + public static AuthorRole ToAuthorRole(this MessageRole role) + { + return role switch + { + MessageRole.User => AuthorRole.User, + MessageRole.Agent => AuthorRole.Assistant, + _ => throw new ArgumentOutOfRangeException(nameof(role), role, "Invalid message role") + }; + } + + public static MessageRole ToMessageRole(this AuthorRole role) + { + return role.Label switch + { + "user" => MessageRole.User, + "assistant" => MessageRole.Agent, + _ => throw new ArgumentOutOfRangeException(nameof(role), role, "Invalid author role") + }; + } +} diff --git a/dotnet/src/Agents/Abstractions/Agent.cs b/dotnet/src/Agents/Abstractions/Agent.cs index 5f21f83bb6c0..e35fbc5738a6 100644 --- a/dotnet/src/Agents/Abstractions/Agent.cs +++ b/dotnet/src/Agents/Abstractions/Agent.cs @@ -4,6 +4,7 @@ using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel.Arguments.Extensions; @@ -66,6 +67,17 @@ public abstract class Agent /// public Kernel Kernel { get; init; } = new(); + /// + /// This option forces the agent to clone the original kernel instance during invocation if true. Default is false. + /// + /// + /// implementations that provide instances require the + /// kernel to be cloned during agent invocation, but cloning has the side affect of causing modifications to Kernel + /// Data by plugins to be lost. Cloning is therefore opt-in. + /// + [Experimental("SKEXP0130")] + public bool UseImmutableKernel { get; set; } = false; + /// /// Gets or sets a prompt template based on the agent instructions. /// diff --git a/dotnet/src/Agents/AzureAI/AzureAIAgent.cs b/dotnet/src/Agents/AzureAI/AzureAIAgent.cs index a942eae8ba73..1b252b84eddc 100644 --- a/dotnet/src/Agents/AzureAI/AzureAIAgent.cs +++ b/dotnet/src/Agents/AzureAI/AzureAIAgent.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; @@ -136,11 +137,22 @@ public async IAsyncEnumerable> InvokeAsync () => new AzureAIAgentThread(this.Client), cancellationToken).ConfigureAwait(false); - Kernel kernel = (options?.Kernel ?? this.Kernel).Clone(); + Kernel kernel = this.GetKernel(options); +#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + if (this.UseImmutableKernel) + { + kernel = kernel.Clone(); + } // Get the context contributions from the AIContextProviders. -#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. AIContext providersContext = await azureAIAgentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false); + + // Check for compatibility AIContextProviders and the UseImmutableKernel setting. + if (providersContext.AIFunctions is { Count: > 0 } && !this.UseImmutableKernel) + { + throw new InvalidOperationException("AIContextProviders with AIFunctions are not supported when Agent UseImmutableKernel setting is false."); + } + kernel.Plugins.AddFromAIContext(providersContext, "Tools"); #pragma warning restore SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. @@ -228,11 +240,22 @@ public async IAsyncEnumerable> In () => new AzureAIAgentThread(this.Client), cancellationToken).ConfigureAwait(false); - var kernel = (options?.Kernel ?? this.Kernel).Clone(); + Kernel kernel = this.GetKernel(options); +#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + if (this.UseImmutableKernel) + { + kernel = kernel.Clone(); + } // Get the context contributions from the AIContextProviders. -#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. AIContext providersContext = await azureAIAgentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false); + + // Check for compatibility AIContextProviders and the UseImmutableKernel setting. + if (providersContext.AIFunctions is { Count: > 0 } && !this.UseImmutableKernel) + { + throw new InvalidOperationException("AIContextProviders with AIFunctions are not supported when Agent UseImmutableKernel setting is false."); + } + kernel.Plugins.AddFromAIContext(providersContext, "Tools"); #pragma warning restore SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. diff --git a/dotnet/src/Agents/AzureAI/AzureAIChannel.cs b/dotnet/src/Agents/AzureAI/AzureAIChannel.cs index e1b57d4ad32b..4620cf05f756 100644 --- a/dotnet/src/Agents/AzureAI/AzureAIChannel.cs +++ b/dotnet/src/Agents/AzureAI/AzureAIChannel.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Azure; using Azure.AI.Agents.Persistent; using Microsoft.SemanticKernel.Agents.AzureAI.Internal; using Microsoft.SemanticKernel.Agents.Extensions; @@ -18,9 +20,22 @@ internal sealed class AzureAIChannel(PersistentAgentsClient client, string threa /// protected override async Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken) { + const string ErrorMessage = "The message could not be added to the thread due to an error response from the service."; + foreach (ChatMessageContent message in history) { - await AgentThreadActions.CreateMessageAsync(client, threadId, message, cancellationToken).ConfigureAwait(false); + try + { + await AgentThreadActions.CreateMessageAsync(client, threadId, message, cancellationToken).ConfigureAwait(false); + } + catch (RequestFailedException ex) + { + throw new AgentThreadOperationException(ErrorMessage, ex); + } + catch (AggregateException ex) + { + throw new AgentThreadOperationException(ErrorMessage, ex); + } } } diff --git a/dotnet/src/Agents/AzureAI/Extensions/FoundryWorkflowExtensions.cs b/dotnet/src/Agents/AzureAI/Extensions/FoundryWorkflowExtensions.cs deleted file mode 100644 index 4c02567c5a32..000000000000 --- a/dotnet/src/Agents/AzureAI/Extensions/FoundryWorkflowExtensions.cs +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using Azure.AI.Agents.Persistent; -using Azure.Core; - -namespace Microsoft.SemanticKernel.Agents.AzureAI; - -/// -/// Extensions for configuring the PersistentAgentsAdministrationClientOptions with a routing policy for Foundry Workflows. -/// -public static class FoundryWorkflowExtensions -{ - /// - /// Adds a routing policy to the PersistentAgentsAdministrationClientOptions for Foundry Workflows. - /// - /// - /// - /// - /// - /// - public static PersistentAgentsAdministrationClientOptions WithPolicy(this PersistentAgentsAdministrationClientOptions options, string endpoint, string apiVersion) - { - if (!Uri.TryCreate(endpoint, UriKind.Absolute, out var _endpoint)) - { - throw new ArgumentException("The endpoint must be an absolute URI.", nameof(endpoint)); - } - - options.AddPolicy(new HttpPipelineRoutingPolicy(_endpoint, apiVersion), HttpPipelinePosition.PerCall); - - return options; - } -} diff --git a/dotnet/src/Agents/AzureAI/Internal/AgentMessageFactory.cs b/dotnet/src/Agents/AzureAI/Internal/AgentMessageFactory.cs index fd63784909c2..bb916126cc32 100644 --- a/dotnet/src/Agents/AzureAI/Internal/AgentMessageFactory.cs +++ b/dotnet/src/Agents/AzureAI/Internal/AgentMessageFactory.cs @@ -62,7 +62,13 @@ public static IEnumerable GetMessageContent(ChatMessag { if (content is TextContent textContent) { - yield return new MessageInputTextBlock(content.ToString()); + var text = content.ToString(); + if (string.IsNullOrWhiteSpace(text)) + { + // Message content must be non-empty. + continue; + } + yield return new MessageInputTextBlock(text); } else if (content is ImageContent imageContent) { diff --git a/dotnet/src/Agents/AzureAI/Internal/AgentThreadActions.cs b/dotnet/src/Agents/AzureAI/Internal/AgentThreadActions.cs index 7173ca219923..412642d88204 100644 --- a/dotnet/src/Agents/AzureAI/Internal/AgentThreadActions.cs +++ b/dotnet/src/Agents/AzureAI/Internal/AgentThreadActions.cs @@ -65,10 +65,16 @@ public static async Task CreateMessageAsync(PersistentAgentsClient client, strin return; } + var contentBlocks = AgentMessageFactory.GetMessageContent(message); + if (!contentBlocks.Any()) + { + return; + } + await client.Messages.CreateMessageAsync( threadId, role: message.Role == AuthorRole.User ? MessageRole.User : MessageRole.Agent, - contentBlocks: AgentMessageFactory.GetMessageContent(message), + contentBlocks: contentBlocks, attachments: AgentMessageFactory.GetAttachments(message), metadata: AgentMessageFactory.GetMetadata(message), cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/Agents/Copilot/CopilotStudioAgent.cs b/dotnet/src/Agents/Copilot/CopilotStudioAgent.cs index 57996b6e95a3..f9424e72a7b9 100644 --- a/dotnet/src/Agents/Copilot/CopilotStudioAgent.cs +++ b/dotnet/src/Agents/Copilot/CopilotStudioAgent.cs @@ -66,6 +66,12 @@ public override async IAsyncEnumerable> In await foreach (ChatMessageContent result in invokeResults.ConfigureAwait(false)) { await this.NotifyThreadOfNewMessage(agentThread, result, cancellationToken).ConfigureAwait(false); + + if (options?.OnIntermediateMessage is not null) + { + await options.OnIntermediateMessage(result).ConfigureAwait(false); + } + yield return new(result, agentThread); } } diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 56c8712ab50f..307009fe5099 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -73,11 +73,22 @@ public override async IAsyncEnumerable> In () => new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); - Kernel kernel = (options?.Kernel ?? this.Kernel).Clone(); + Kernel kernel = this.GetKernel(options); +#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + if (this.UseImmutableKernel) + { + kernel = kernel.Clone(); + } // Get the context contributions from the AIContextProviders. -#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. AIContext providersContext = await chatHistoryAgentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false); + + // Check for compatibility AIContextProviders and the UseImmutableKernel setting. + if (providersContext.AIFunctions is { Count: > 0 } && !this.UseImmutableKernel) + { + throw new InvalidOperationException("AIContextProviders with AIFunctions are not supported when Agent UseImmutableKernel setting is false."); + } + kernel.Plugins.AddFromAIContext(providersContext, "Tools"); #pragma warning restore SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. @@ -168,11 +179,22 @@ public override async IAsyncEnumerable new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); - Kernel kernel = (options?.Kernel ?? this.Kernel).Clone(); + Kernel kernel = this.GetKernel(options); +#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + if (this.UseImmutableKernel) + { + kernel = kernel.Clone(); + } // Get the context contributions from the AIContextProviders. -#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. AIContext providersContext = await chatHistoryAgentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false); + + // Check for compatibility AIContextProviders and the UseImmutableKernel setting. + if (providersContext.AIFunctions is { Count: > 0 } && !this.UseImmutableKernel) + { + throw new InvalidOperationException("AIContextProviders with AIFunctions are not supported when Agent UseImmutableKernel setting is false."); + } + kernel.Plugins.AddFromAIContext(providersContext, "Tools"); #pragma warning restore SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. @@ -391,6 +413,7 @@ private async IAsyncEnumerable InternalInvokeStream this.Logger.LogAgentChatServiceInvokedStreamingAgent(nameof(InvokeAsync), this.Id, agentName, serviceType); + int messageIndex = messageCount; AuthorRole? role = null; StringBuilder builder = new(); await foreach (StreamingChatMessageContent message in messages.ConfigureAwait(false)) @@ -401,18 +424,18 @@ private async IAsyncEnumerable InternalInvokeStream builder.Append(message.ToString()); - yield return message; - } + // Capture mutated messages related function calling / tools + for (; messageIndex < chat.Count; messageIndex++) + { + ChatMessageContent chatMessage = chat[messageIndex]; - // Capture mutated messages related function calling / tools - for (int messageIndex = messageCount; messageIndex < chat.Count; messageIndex++) - { - ChatMessageContent message = chat[messageIndex]; + chatMessage.AuthorName = this.Name; - message.AuthorName = this.Name; + await onNewMessage(chatMessage).ConfigureAwait(false); + history.Add(chatMessage); + } - await onNewMessage(message).ConfigureAwait(false); - history.Add(message); + yield return message; } // Do not duplicate terminated function result to history diff --git a/dotnet/src/Agents/Magentic/MagenticManagerActor.cs b/dotnet/src/Agents/Magentic/MagenticManagerActor.cs index bb6cf93b1b20..881dfc365110 100644 --- a/dotnet/src/Agents/Magentic/MagenticManagerActor.cs +++ b/dotnet/src/Agents/Magentic/MagenticManagerActor.cs @@ -8,7 +8,6 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.Agents.Orchestration; -using Microsoft.SemanticKernel.Agents.Orchestration.GroupChat; using Microsoft.SemanticKernel.Agents.Runtime; using Microsoft.SemanticKernel.Agents.Runtime.Core; using Microsoft.SemanticKernel.ChatCompletion; diff --git a/dotnet/src/Agents/Magentic/MagenticMessages.cs b/dotnet/src/Agents/Magentic/MagenticMessages.cs index 388b9305c69e..8ca25cb6ce5e 100644 --- a/dotnet/src/Agents/Magentic/MagenticMessages.cs +++ b/dotnet/src/Agents/Magentic/MagenticMessages.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel.Agents.Magentic; @@ -76,4 +77,9 @@ public sealed class InputTask /// Extension method to convert a to a message. ///
    public static Result AsResultMessage(this ChatMessageContent message) => new() { Message = message }; + + /// + /// Extension method to convert a to a . + /// + public static Result AsResultMessage(this string text) => new() { Message = new(AuthorRole.Assistant, text) }; } diff --git a/dotnet/src/Agents/OpenAI/Extensions/ChatContentMessageExtensions.cs b/dotnet/src/Agents/OpenAI/Extensions/ChatContentMessageExtensions.cs index 3c6db7f5b894..e8f9b8cea96b 100644 --- a/dotnet/src/Agents/OpenAI/Extensions/ChatContentMessageExtensions.cs +++ b/dotnet/src/Agents/OpenAI/Extensions/ChatContentMessageExtensions.cs @@ -44,12 +44,12 @@ public static IEnumerable ToThreadInitializationMes public static ResponseItem ToResponseItem(this ChatMessageContent message) { string content = message.Content ?? string.Empty; - return message.Role.Label switch + return message.Role.Label.ToUpperInvariant() switch { - "system" => ResponseItem.CreateSystemMessageItem(content), - "user" => ResponseItem.CreateUserMessageItem(content), - "developer" => ResponseItem.CreateDeveloperMessageItem(content), - "assistant" => ResponseItem.CreateAssistantMessageItem(content), + "SYSTEM" => ResponseItem.CreateSystemMessageItem(content), + "USER" => ResponseItem.CreateUserMessageItem(content), + "DEVELOPER" => ResponseItem.CreateDeveloperMessageItem(content), + "ASSISTANT" => ResponseItem.CreateAssistantMessageItem(content), _ => throw new NotSupportedException($"Unsupported role {message.Role.Label}. Only system, user, developer or assistant roles are allowed."), }; } diff --git a/dotnet/src/Agents/OpenAI/Extensions/OpenAIResponseExtensions.cs b/dotnet/src/Agents/OpenAI/Extensions/OpenAIResponseExtensions.cs index d923f6bfb023..d65e0f940fff 100644 --- a/dotnet/src/Agents/OpenAI/Extensions/OpenAIResponseExtensions.cs +++ b/dotnet/src/Agents/OpenAI/Extensions/OpenAIResponseExtensions.cs @@ -43,18 +43,25 @@ public static ChatMessageContent ToChatMessageContent(this OpenAIResponse respon ///
    /// The response item to convert. /// A instance. - public static ChatMessageContent ToChatMessageContent(this ResponseItem item) + public static ChatMessageContent? ToChatMessageContent(this ResponseItem item) { if (item is MessageResponseItem messageResponseItem) { var role = messageResponseItem.Role.ToAuthorRole(); return new ChatMessageContent(role, item.ToChatMessageContentItemCollection(), innerContent: messageResponseItem); } + else if (item is ReasoningResponseItem reasoningResponseItem) + { + if (reasoningResponseItem.SummaryTextParts is not null && reasoningResponseItem.SummaryTextParts.Count > 0) + { + return new ChatMessageContent(AuthorRole.Assistant, item.ToChatMessageContentItemCollection(), innerContent: reasoningResponseItem); + } + } else if (item is FunctionCallResponseItem functionCallResponseItem) { return new ChatMessageContent(AuthorRole.Assistant, item.ToChatMessageContentItemCollection(), innerContent: functionCallResponseItem); } - throw new NotSupportedException($"Unsupported response item: {item.GetType()}"); + return null; } /// @@ -68,6 +75,10 @@ public static ChatMessageContentItemCollection ToChatMessageContentItemCollectio { return messageResponseItem.Content.ToChatMessageContentItemCollection(); } + else if (item is ReasoningResponseItem reasoningResponseItem) + { + return reasoningResponseItem.SummaryTextParts.ToChatMessageContentItemCollection(); + } else if (item is FunctionCallResponseItem functionCallResponseItem) { Exception? exception = null; @@ -92,7 +103,7 @@ public static ChatMessageContentItemCollection ToChatMessageContentItemCollectio }; return [functionCallContent]; } - throw new NotImplementedException($"Unsupported response item: {item.GetType()}"); + return []; } /// @@ -183,6 +194,16 @@ private static ChatMessageContentItemCollection ToChatMessageContentItemCollecti } return collection; } + + private static ChatMessageContentItemCollection ToChatMessageContentItemCollection(this IReadOnlyList texts) + { + var collection = new ChatMessageContentItemCollection(); + foreach (var text in texts) + { + collection.Add(new TextContent(text, innerContent: null)); + } + return collection; + } #endregion } diff --git a/dotnet/src/Agents/OpenAI/Internal/AssistantMessageFactory.cs b/dotnet/src/Agents/OpenAI/Internal/AssistantMessageFactory.cs index fadccba9cd93..008e781fafa8 100644 --- a/dotnet/src/Agents/OpenAI/Internal/AssistantMessageFactory.cs +++ b/dotnet/src/Agents/OpenAI/Internal/AssistantMessageFactory.cs @@ -45,7 +45,13 @@ public static IEnumerable GetMessageContents(ChatMessageContent { if (content is TextContent textContent) { - yield return MessageContent.FromText(content.ToString()); + var text = content.ToString(); + if (string.IsNullOrWhiteSpace(text)) + { + // Message content must be non-empty. + continue; + } + yield return MessageContent.FromText(text); } else if (content is ImageContent imageContent) { diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs index 1971b8bc9058..11f813fd3267 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; @@ -138,11 +139,22 @@ public async IAsyncEnumerable> InvokeAsync AdditionalInstructions = options?.AdditionalInstructions, }); - Kernel kernel = (options?.Kernel ?? this.Kernel).Clone(); + Kernel kernel = this.GetKernel(options); +#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + if (this.UseImmutableKernel) + { + kernel = kernel.Clone(); + } // Get the context contributions from the AIContextProviders. -#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. AIContext providersContext = await openAIAssistantAgentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false); + + // Check for compatibility AIContextProviders and the UseImmutableKernel setting. + if (providersContext.AIFunctions is { Count: > 0 } && !this.UseImmutableKernel) + { + throw new InvalidOperationException("AIContextProviders with AIFunctions are not supported when Agent UseImmutableKernel setting is false."); + } + kernel.Plugins.AddFromAIContext(providersContext, "Tools"); #pragma warning restore SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. @@ -226,11 +238,22 @@ public async IAsyncEnumerable> In () => new OpenAIAssistantAgentThread(this.Client), cancellationToken).ConfigureAwait(false); - Kernel kernel = (options?.Kernel ?? this.Kernel).Clone(); + Kernel kernel = this.GetKernel(options); +#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + if (this.UseImmutableKernel) + { + kernel = kernel.Clone(); + } // Get the context contributions from the AIContextProviders. -#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. AIContext providersContext = await openAIAssistantAgentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false); + + // Check for compatibility AIContextProviders and the UseImmutableKernel setting. + if (providersContext.AIFunctions is { Count: > 0 } && !this.UseImmutableKernel) + { + throw new InvalidOperationException("AIContextProviders with AIFunctions are not supported when Agent UseImmutableKernel setting is false."); + } + kernel.Plugins.AddFromAIContext(providersContext, "Tools"); #pragma warning restore SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs index bdb5de57a121..7f5a5194092d 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs @@ -1,4 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.ClientModel; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Threading; @@ -23,9 +25,22 @@ internal sealed class OpenAIAssistantChannel(AssistantClient client, string thre /// protected override async Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken) { + const string ErrorMessage = "The message could not be added to the thread due to an error response from the service."; + foreach (ChatMessageContent message in history) { - await AssistantThreadActions.CreateMessageAsync(this._client, this._threadId, message, cancellationToken).ConfigureAwait(false); + try + { + await AssistantThreadActions.CreateMessageAsync(this._client, this._threadId, message, cancellationToken).ConfigureAwait(false); + } + catch (ClientResultException ex) + { + throw new AgentThreadOperationException(ErrorMessage, ex); + } + catch (AggregateException ex) + { + throw new AgentThreadOperationException(ErrorMessage, ex); + } } } diff --git a/dotnet/src/Agents/OpenAI/OpenAIResponseAgent.cs b/dotnet/src/Agents/OpenAI/OpenAIResponseAgent.cs index 87b0912d01ef..0e1f84d610df 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIResponseAgent.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIResponseAgent.cs @@ -37,7 +37,7 @@ public OpenAIResponseAgent(OpenAIResponseClient client) /// /// Storing of messages is enabled. /// - public bool StoreEnabled { get; init; } = true; + public bool StoreEnabled { get; init; } = false; /// public override async IAsyncEnumerable> InvokeAsync(ICollection messages, AgentThread? thread = null, AgentInvokeOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) @@ -47,7 +47,7 @@ public override async IAsyncEnumerable> In AgentThread agentThread = await this.EnsureThreadExistsWithMessagesAsync(messages, thread, cancellationToken).ConfigureAwait(false); // Get the context contributions from the AIContextProviders. - OpenAIAssistantAgentInvokeOptions extensionsContextOptions = await this.FinalizeInvokeOptionsAsync(messages, options, agentThread, cancellationToken).ConfigureAwait(false); + OpenAIResponseAgentInvokeOptions extensionsContextOptions = await this.FinalizeInvokeOptionsAsync(messages, options, agentThread, cancellationToken).ConfigureAwait(false); // Invoke responses with the updated chat history. ChatHistory chatHistory = [.. messages]; @@ -74,7 +74,7 @@ public override async IAsyncEnumerable EnsureThreadExistsWithMessagesAsync(ICollection< return await this.EnsureThreadExistsWithMessagesAsync(messages, thread, () => new ChatHistoryAgentThread(), cancellationToken).ConfigureAwait(false); } - private async Task FinalizeInvokeOptionsAsync(ICollection messages, AgentInvokeOptions? options, AgentThread agentThread, CancellationToken cancellationToken) + private async Task FinalizeInvokeOptionsAsync(ICollection messages, AgentInvokeOptions? options, AgentThread agentThread, CancellationToken cancellationToken) { - Kernel kernel = this.GetKernel(options).Clone(); + Kernel kernel = this.GetKernel(options); +#pragma warning disable SKEXP0110, SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + if (this.UseImmutableKernel) + { + kernel = kernel.Clone(); + } -#pragma warning disable SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + // Get the AIContextProviders contributions to the kernel. AIContext providersContext = await agentThread.AIContextProviders.ModelInvokingAsync(messages, cancellationToken).ConfigureAwait(false); + + // Check for compatibility AIContextProviders and the UseImmutableKernel setting. + if (providersContext.AIFunctions is { Count: > 0 } && !this.UseImmutableKernel) + { + throw new InvalidOperationException("AIContextProviders with AIFunctions are not supported when Agent UseImmutableKernel setting is false."); + } + kernel.Plugins.AddFromAIContext(providersContext, "Tools"); #pragma warning restore SKEXP0130 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. string mergedAdditionalInstructions = FormatAdditionalInstructions(providersContext, options); - OpenAIAssistantAgentInvokeOptions extensionsContextOptions = + OpenAIResponseAgentInvokeOptions extensionsContextOptions = options is null ? new() { diff --git a/dotnet/src/Agents/OpenAI/OpenAIResponseAgentInvokeOptions.cs b/dotnet/src/Agents/OpenAI/OpenAIResponseAgentInvokeOptions.cs index e11f55c41463..ec84f87cc042 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIResponseAgentInvokeOptions.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIResponseAgentInvokeOptions.cs @@ -24,6 +24,11 @@ public OpenAIResponseAgentInvokeOptions(AgentInvokeOptions options) : base(options) { Verify.NotNull(options); + + if (options is OpenAIResponseAgentInvokeOptions responseAgentInvokeOptions) + { + this.ResponseCreationOptions = responseAgentInvokeOptions.ResponseCreationOptions; + } } /// @@ -34,6 +39,8 @@ public OpenAIResponseAgentInvokeOptions(OpenAIResponseAgentInvokeOptions options : base(options) { Verify.NotNull(options); + + this.ResponseCreationOptions = options.ResponseCreationOptions; } /// diff --git a/dotnet/src/Agents/OpenAI/OpenAIResponseAgentThread.cs b/dotnet/src/Agents/OpenAI/OpenAIResponseAgentThread.cs index b4f424c1001b..3bbaa97d13dd 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIResponseAgentThread.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIResponseAgentThread.cs @@ -111,7 +111,11 @@ public async IAsyncEnumerable GetMessagesAsync([EnumeratorCa var collectionResult = this._client.GetResponseInputItemsAsync(this.ResponseId, default, cancellationToken).ConfigureAwait(false); await foreach (var responseItem in collectionResult) { - yield return responseItem.ToChatMessageContent(); + var messageContent = responseItem.ToChatMessageContent(); + if (messageContent is not null) + { + yield return messageContent; + } } } } diff --git a/dotnet/src/Agents/Orchestration/GroupChat/GroupChatMessages.cs b/dotnet/src/Agents/Orchestration/GroupChat/GroupChatMessages.cs index aaf084b700c9..cd6a6f73f415 100644 --- a/dotnet/src/Agents/Orchestration/GroupChat/GroupChatMessages.cs +++ b/dotnet/src/Agents/Orchestration/GroupChat/GroupChatMessages.cs @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel.Agents.Orchestration.GroupChat; /// /// Common messages used for agent chat patterns. /// -public static class GroupChatMessages +internal static class GroupChatMessages { /// /// An empty message instance as a default. @@ -79,7 +79,7 @@ public sealed class InputTask public static InputTask AsInputTaskMessage(this IEnumerable messages) => new() { Messages = messages }; /// - /// Extension method to convert a to a . + /// Extension method to convert a to a . /// public static Result AsResultMessage(this string text) => new() { Message = new(AuthorRole.Assistant, text) }; } diff --git a/dotnet/src/Agents/Runtime/InProcess.Tests/PublishMessageTests.cs b/dotnet/src/Agents/Runtime/InProcess.Tests/PublishMessageTests.cs index c81a80ba1d86..c73070ff6763 100644 --- a/dotnet/src/Agents/Runtime/InProcess.Tests/PublishMessageTests.cs +++ b/dotnet/src/Agents/Runtime/InProcess.Tests/PublishMessageTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Linq; using System.Threading.Tasks; using FluentAssertions; using Xunit; @@ -56,10 +55,7 @@ public async Task Test_PublishMessage_MultipleFailures() Func publishTask = async () => await fixture.RunPublishTestAsync(new TopicId("TestTopic"), new BasicMessage { Content = "1" }); // What we are really testing here is that a single exception does not prevent sending to the remaining agents - (await publishTask.Should().ThrowAsync()) - .Which.Should().Match( - exception => exception.InnerExceptions.Count == 2 && - exception.InnerExceptions.All(exception => exception is TestException)); + await publishTask.Should().ThrowAsync(); fixture.GetAgentInstances().Values .Should().HaveCount(2) @@ -81,11 +77,7 @@ public async Task Test_PublishMessage_MixedSuccessFailure() Func publicTask = async () => await fixture.RunPublishTestAsync(new TopicId("TestTopic"), new BasicMessage { Content = "1" }); // What we are really testing here is that raising exceptions does not prevent sending to the remaining agents - (await publicTask.Should().ThrowAsync()) - .Which.Should().Match( - exception => exception.InnerExceptions.Count == 2 && - exception.InnerExceptions.All( - exception => exception is TestException)); + await publicTask.Should().ThrowAsync(); fixture.GetAgentInstances().Values .Should().HaveCount(2, "Two ReceiverAgents should have been created") diff --git a/dotnet/src/Agents/Runtime/InProcess/InProcessRuntime.cs b/dotnet/src/Agents/Runtime/InProcess/InProcessRuntime.cs index 93e3ec89144b..2ab31ab3f9ab 100644 --- a/dotnet/src/Agents/Runtime/InProcess/InProcessRuntime.cs +++ b/dotnet/src/Agents/Runtime/InProcess/InProcessRuntime.cs @@ -352,45 +352,41 @@ private async ValueTask PublishMessageServicerAsync(MessageEnvelope envelope, Ca throw new InvalidOperationException("Message must have a topic to be published."); } - List exceptions = []; + List? tasks = null; TopicId topic = envelope.Topic.Value; foreach (ISubscriptionDefinition subscription in this._subscriptions.Values.Where(subscription => subscription.Matches(topic))) { - try - { - deliveryToken.ThrowIfCancellationRequested(); + (tasks ??= []).Add(ProcessSubscriptionAsync(envelope, topic, subscription, deliveryToken)); + } - AgentId? sender = envelope.Sender; + if (tasks is not null) + { + await Task.WhenAll(tasks).ConfigureAwait(false); + } - using CancellationTokenSource combinedSource = CancellationTokenSource.CreateLinkedTokenSource(envelope.Cancellation, deliveryToken); - MessageContext messageContext = new(envelope.MessageId, combinedSource.Token) - { - Sender = sender, - Topic = topic, - IsRpc = false - }; + async Task ProcessSubscriptionAsync(MessageEnvelope envelope, TopicId topic, ISubscriptionDefinition subscription, CancellationToken deliveryToken) + { + deliveryToken.ThrowIfCancellationRequested(); - AgentId agentId = subscription.MapToAgent(topic); - if (!this.DeliverToSelf && sender.HasValue && sender == agentId) - { - continue; - } + AgentId? sender = envelope.Sender; - IHostableAgent agent = await this.EnsureAgentAsync(agentId).ConfigureAwait(false); + using CancellationTokenSource combinedSource = CancellationTokenSource.CreateLinkedTokenSource(envelope.Cancellation, deliveryToken); + MessageContext messageContext = new(envelope.MessageId, combinedSource.Token) + { + Sender = sender, + Topic = topic, + IsRpc = false + }; - // TODO: Cancellation propagation! - await agent.OnMessageAsync(envelope.Message, messageContext).ConfigureAwait(false); - } - catch (Exception ex) when (!ex.IsCriticalException()) + AgentId agentId = subscription.MapToAgent(topic); + if (!this.DeliverToSelf && sender.HasValue && sender == agentId) { - exceptions.Add(ex); + return; } - } - if (exceptions.Count > 0) - { - // TODO: Unwrap TargetInvocationException? - throw new AggregateException("One or more exceptions occurred while processing the message.", exceptions); + IHostableAgent agent = await this.EnsureAgentAsync(agentId).ConfigureAwait(false); + + await agent.OnMessageAsync(envelope.Message, messageContext).ConfigureAwait(false); } } diff --git a/dotnet/src/Agents/UnitTests/A2A/A2AAgentTests.cs b/dotnet/src/Agents/UnitTests/A2A/A2AAgentTests.cs new file mode 100644 index 000000000000..3c18d0b594dc --- /dev/null +++ b/dotnet/src/Agents/UnitTests/A2A/A2AAgentTests.cs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.A2A; +using Xunit; + +namespace SemanticKernel.Agents.UnitTests.A2A; + +/// +/// Tests for the class. +/// +public sealed class A2AAgentTests : BaseA2AClientTest +{ + /// + /// Tests that the constructor verifies parameters and throws when necessary. + /// + [Fact] + public void ConstructorShouldVerifyParams() + { + using var httpClient = new HttpClient(); + + // Arrange & Act & Assert + Assert.Throws(() => new A2AAgent(null!, new())); + Assert.Throws(() => new A2AAgent(new(httpClient), null!)); + } + + [Fact] + public void VerifyConstructor() + { + // Arrange & Act + var agent = new A2AAgent(this.Client, this.CreateAgentCard()); + + // Assert + Assert.NotNull(agent); + Assert.Equal("InvoiceAgent", agent.Name); + Assert.Equal("Handles requests relating to invoices.", agent.Description); + } + + [Fact] + public async Task VerifyInvokeAsync() + { + // Arrange + this.MessageHandlerStub.ResponsesToReturn.Add( + new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StringContent(InvokeResponse, Encoding.UTF8, "application/json") } + ); + var agent = new A2AAgent(this.Client, this.CreateAgentCard()); + + // Act + var responseItems = agent.InvokeAsync("List the latest invoices for Contoso?"); + + // Assert + Assert.NotNull(responseItems); + var items = await responseItems!.ToListAsync>(); + Assert.Single(items); + Assert.StartsWith("Here are the latest invoices for Contoso:", items[0].Message.Content); + } + + [Fact] + public async Task VerifyInvokeStreamingAsync() + { + // Arrange + this.MessageHandlerStub.ResponsesToReturn.Add( + new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StringContent(InvokeResponse, Encoding.UTF8, "application/json") } + ); + var agent = new A2AAgent(this.Client, this.CreateAgentCard()); + + // Act + var responseItems = agent.InvokeStreamingAsync("List the latest invoices for Contoso?"); + + // Assert + Assert.NotNull(responseItems); + var items = await responseItems!.ToListAsync>(); + Assert.Single(items); + Assert.StartsWith("Here are the latest invoices for Contoso:", items[0].Message.Content); + } + + #region private + private const string InvokeResponse = + """ + {"jsonrpc":"2.0","id":"ce7a5ef6-1078-4b6e-ad35-a8bfa6743c5d","result":{"kind":"task","id":"8d328159-ca63-4ce8-b416-4bcf69f9e119","contextId":"496a4a95-392b-4c04-a517-9a043b3f7565","status":{"state":"completed","timestamp":"2025-06-20T09:42:49.4013958Z"},"artifacts":[{"artifactId":"","parts":[{"kind":"text","text":"Here are the latest invoices for Contoso:\n\n1. Invoice ID: INV789, Date: 2025-06-18\n Products: T-Shirts (150 units at $10.00), Hats (200 units at $15.00), Glasses (300 units at $5.00)\n\n2. Invoice ID: INV666, Date: 2025-06-15\n Products: T-Shirts (2500 units at $8.00), Hats (1200 units at $10.00), Glasses (1000 units at $6.00)\n\n3. Invoice ID: INV999, Date: 2025-05-17\n Products: T-Shirts (1400 units at $10.50), Hats (1100 units at $9.00), Glasses (950 units at $12.00)\n\n4. Invoice ID: INV333, Date: 2025-05-13\n Products: T-Shirts (400 units at $11.00), Hats (600 units at $15.00), Glasses (700 units at $5.00)\n\nIf you need more details on any specific invoice, please let me know!"}]}],"history":[{"role":"user","parts":[{"kind":"text","text":"List the latest invoices for Contoso?"}],"messageId":"80a26c0f-2262-4d0f-8e7d-51ac4046173b"}]}} + """; + #endregion +} diff --git a/dotnet/src/Agents/UnitTests/A2A/A2AHostAgentTests.cs b/dotnet/src/Agents/UnitTests/A2A/A2AHostAgentTests.cs new file mode 100644 index 000000000000..b7f838d2764a --- /dev/null +++ b/dotnet/src/Agents/UnitTests/A2A/A2AHostAgentTests.cs @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.A2A; +using Microsoft.SemanticKernel.ChatCompletion; +using SharpA2A.Core; +using Xunit; + +namespace SemanticKernel.Agents.UnitTests.A2A; + +/// +/// Tests for the class. +/// +public sealed class A2AHostAgentTests : BaseA2AClientTest +{ + /// + /// Tests that the constructor verifies parameters and throws when necessary. + /// + [Fact] + public void ConstructorShouldVerifyParams() + { + // Arrange & Act & Assert + Assert.Throws(() => new A2AHostAgent(null!, this.CreateAgentCard())); + Assert.Throws(() => new A2AHostAgent(new MockAgent(), null!)); + } + + [Fact] + public async Task VerifyExecuteAgentTaskAsync() + { + // Arrange + var agent = new MockAgent(); + var taskManager = new TaskManager(); + var hostAgent = new A2AHostAgent(agent, this.CreateAgentCard(), taskManager); + + // Act + var agentTask = await taskManager.CreateTaskAsync(); + agentTask.History = this.CreateUserMessages(["Hello"]); + await hostAgent.ExecuteAgentTaskAsync(agentTask); + + // Assert + Assert.NotNull(agentTask); + Assert.NotNull(agentTask.Artifacts); + Assert.Single(agentTask.Artifacts); + Assert.NotNull(agentTask.Artifacts[0].Parts); + Assert.Single(agentTask.Artifacts[0].Parts); + Assert.Equal("Mock Response", agentTask.Artifacts[0].Parts[0].AsTextPart().Text); + } + + #region private + private List CreateUserMessages(string[] userMessages) + { + var messages = new List(); + + foreach (var userMessage in userMessages) + { + messages.Add(new Message() + { + Role = MessageRole.User, + Parts = [new TextPart() { Text = userMessage }], + }); + } + + return messages; + } + #endregion +} + +internal sealed class MockAgent : Agent +{ + public override async IAsyncEnumerable> InvokeAsync(ICollection messages, AgentThread? thread = null, AgentInvokeOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Delay(100, cancellationToken); + + yield return new AgentResponseItem(new ChatMessageContent(AuthorRole.Assistant, "Mock Response"), thread ?? new MockAgentThread()); + } + + public override async IAsyncEnumerable> InvokeStreamingAsync(ICollection messages, AgentThread? thread = null, AgentInvokeOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Delay(100, cancellationToken); + + yield return new AgentResponseItem(new StreamingChatMessageContent(AuthorRole.Assistant, "Mock Streaming Response"), thread ?? new MockAgentThread()); + } + + protected internal override Task CreateChannelAsync(CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + protected internal override IEnumerable GetChannelKeys() + { + throw new NotImplementedException(); + } + + protected internal override Task RestoreChannelAsync(string channelState, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } +} + +internal sealed class MockAgentThread : AgentThread +{ + protected override Task CreateInternalAsync(CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + protected override Task DeleteInternalAsync(CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + protected override Task OnNewMessageInternalAsync(ChatMessageContent newMessage, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } +} diff --git a/dotnet/src/Agents/UnitTests/A2A/BaseA2AClientTest.cs b/dotnet/src/Agents/UnitTests/A2A/BaseA2AClientTest.cs new file mode 100644 index 000000000000..52fb0620c475 --- /dev/null +++ b/dotnet/src/Agents/UnitTests/A2A/BaseA2AClientTest.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using SharpA2A.Core; + +namespace SemanticKernel.Agents.UnitTests.A2A; +public class BaseA2AClientTest : IDisposable +{ + internal MultipleHttpMessageHandlerStub MessageHandlerStub { get; } + internal HttpClient HttpClient { get; } + internal A2AClient Client { get; } + + internal BaseA2AClientTest() + { + this.MessageHandlerStub = new MultipleHttpMessageHandlerStub(); + this.HttpClient = new HttpClient(this.MessageHandlerStub, disposeHandler: false) + { + BaseAddress = new Uri("http://127.0.0.1/") + }; + this.Client = new A2AClient(this.HttpClient); + } + + /// + public void Dispose() + { + this.MessageHandlerStub.Dispose(); + this.HttpClient.Dispose(); + + GC.SuppressFinalize(this); + } + + protected AgentCard CreateAgentCard() + { + var capabilities = new AgentCapabilities() + { + Streaming = false, + PushNotifications = false, + }; + + var invoiceQuery = new AgentSkill() + { + Id = "id_invoice_agent", + Name = "InvoiceQuery", + Description = "Handles requests relating to invoices.", + Tags = ["invoice", "semantic-kernel"], + Examples = + [ + "List the latest invoices for Contoso.", + ], + }; + + return new AgentCard() + { + Name = "InvoiceAgent", + Description = "Handles requests relating to invoices.", + Url = "http://127.0.0.1/5000", + Version = "1.0.0", + DefaultInputModes = ["text"], + DefaultOutputModes = ["text"], + Capabilities = capabilities, + Skills = [invoiceQuery], + }; + } +} diff --git a/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj b/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj index 8752623526da..e0a5d2938e1e 100644 --- a/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj +++ b/dotnet/src/Agents/UnitTests/Agents.UnitTests.csproj @@ -19,11 +19,11 @@ - + - - + + @@ -39,6 +39,7 @@ + diff --git a/dotnet/src/Agents/UnitTests/Core/ChatCompletionAgentTests.cs b/dotnet/src/Agents/UnitTests/Core/ChatCompletionAgentTests.cs index 1f7d2d1e0fb2..def132c60102 100644 --- a/dotnet/src/Agents/UnitTests/Core/ChatCompletionAgentTests.cs +++ b/dotnet/src/Agents/UnitTests/Core/ChatCompletionAgentTests.cs @@ -176,6 +176,46 @@ public async Task VerifyChatCompletionAgentInvocationAsync() Times.Once); } + /// + /// Verify the invocation and response of . + /// + [Fact] + public async Task VerifyChatCompletionAgentInvocationsCanMutateProvidedKernelAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())).ReturnsAsync([new(AuthorRole.Assistant, "what?")]); + + var kernel = CreateKernel(mockService.Object); + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = kernel, + Arguments = [], + }; + + // Act + AgentResponseItem[] result = await agent.InvokeAsync(Array.Empty() as ICollection).ToArrayAsync(); + + // Assert + Assert.Single(result); + + mockService.Verify( + x => + x.GetChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + kernel, // Use the same kernel instance + It.IsAny()), + Times.Once); + } + /// /// Verify the invocation and response of using . /// @@ -195,7 +235,7 @@ public async Task VerifyChatClientAgentInvocationAsync() { Instructions = "test instructions", Kernel = CreateKernel(mockService.Object), - Arguments = [], + Arguments = new(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), }; // Act @@ -208,7 +248,7 @@ public async Task VerifyChatClientAgentInvocationAsync() x => x.GetResponseAsync( It.IsAny>(), - It.IsAny(), + It.Is(o => GetKernelFromChatOptions(o) == agent.Kernel), It.IsAny()), Times.Once); } @@ -258,6 +298,52 @@ public async Task VerifyChatCompletionAgentStreamingAsync() Times.Once); } + /// + /// Verify the streaming invocation and response of . + /// + [Fact] + public async Task VerifyChatCompletionAgentStreamingCanMutateProvidedKernelAsync() + { + // Arrange + StreamingChatMessageContent[] returnContent = + [ + new(AuthorRole.Assistant, "wh"), + new(AuthorRole.Assistant, "at?"), + ]; + + Mock mockService = new(); + mockService.Setup( + s => s.GetStreamingChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())).Returns(returnContent.ToAsyncEnumerable()); + + var kernel = CreateKernel(mockService.Object); + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = kernel, + Arguments = [], + }; + + // Act + AgentResponseItem[] result = await agent.InvokeStreamingAsync(Array.Empty() as ICollection).ToArrayAsync(); + + // Assert + Assert.Equal(2, result.Length); + + mockService.Verify( + x => + x.GetStreamingChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + kernel, // Use the same kernel instance + It.IsAny()), + Times.Once); + } + /// /// Verify the streaming invocation and response of using . /// @@ -283,7 +369,7 @@ public async Task VerifyChatClientAgentStreamingAsync() { Instructions = "test instructions", Kernel = CreateKernel(mockService.Object), - Arguments = [], + Arguments = new(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), }; // Act @@ -296,7 +382,7 @@ public async Task VerifyChatClientAgentStreamingAsync() x => x.GetStreamingResponseAsync( It.IsAny>(), - It.IsAny(), + It.Is(o => GetKernelFromChatOptions(o) == agent.Kernel), It.IsAny()), Times.Once); } @@ -373,6 +459,414 @@ public void VerifyChatCompletionChannelKeys() Assert.NotEqual(agent3.GetChannelKeys(), agent5.GetChannelKeys()); } + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is false and AIFunctions exist. + /// + [Fact] + public async Task VerifyChatCompletionAgentThrowsWhenUseImmutableKernelFalseWithAIFunctionsAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())).ReturnsAsync([new(AuthorRole.Assistant, "what?")]); + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = CreateKernel(mockService.Object), + Arguments = [], + UseImmutableKernel = false // Explicitly set to false + }; + + var thread = new ChatHistoryAgentThread(); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeAsync(Array.Empty() as ICollection, thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is default (false) and AIFunctions exist. + /// + [Fact] + public async Task VerifyChatCompletionAgentThrowsWhenUseImmutableKernelDefaultWithAIFunctionsAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())).ReturnsAsync([new(AuthorRole.Assistant, "what?")]); + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = CreateKernel(mockService.Object), + Arguments = [] + // UseImmutableKernel not set, should default to false + }; + + var thread = new ChatHistoryAgentThread(); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeAsync(Array.Empty() as ICollection, thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that kernel remains immutable when UseImmutableKernel is true. + /// + [Fact] + public async Task VerifyChatCompletionAgentKernelImmutabilityWhenUseImmutableKernelTrueAsync() + { + // Arrange + Mock mockService = new(); + Kernel capturedKernel = null!; + mockService.Setup( + s => s.GetChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Callback((_, _, kernel, _) => capturedKernel = kernel) + .ReturnsAsync([new(AuthorRole.Assistant, "what?")]); + + var originalKernel = CreateKernel(mockService.Object); + var originalPluginCount = originalKernel.Plugins.Count; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = originalKernel, + Arguments = [], + UseImmutableKernel = true + }; + + var thread = new ChatHistoryAgentThread(); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act + AgentResponseItem[] result = await agent.InvokeAsync(Array.Empty() as ICollection, thread: thread).ToArrayAsync(); + + // Assert + Assert.Single(result); + + // Verify original kernel was not modified + Assert.Equal(originalPluginCount, originalKernel.Plugins.Count); + + // Verify a different kernel instance was used for the service call + Assert.NotSame(originalKernel, capturedKernel); + + // Verify the captured kernel has the additional plugin from AIContext + Assert.True(capturedKernel.Plugins.Count > originalPluginCount); + Assert.Contains(capturedKernel.Plugins, p => p.Name == "Tools"); + } + + /// + /// Verify that mutable kernel behavior works when UseImmutableKernel is false and no AIFunctions exist. + /// + [Fact] + public async Task VerifyChatCompletionAgentMutableKernelWhenUseImmutableKernelFalseNoAIFunctionsAsync() + { + // Arrange + Mock mockService = new(); + Kernel capturedKernel = null!; + mockService.Setup( + s => s.GetChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Callback((_, _, kernel, _) => capturedKernel = kernel) + .ReturnsAsync([new(AuthorRole.Assistant, "what?")]); + + var originalKernel = CreateKernel(mockService.Object); + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [] // Empty AIFunctions list + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = originalKernel, + Arguments = [], + UseImmutableKernel = false + }; + + var thread = new ChatHistoryAgentThread(); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act + AgentResponseItem[] result = await agent.InvokeAsync(Array.Empty() as ICollection, thread: thread).ToArrayAsync(); + + // Assert + Assert.Single(result); + + // Verify the same kernel instance was used (mutable behavior) + Assert.Same(originalKernel, capturedKernel); + } + + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is false and AIFunctions exist (streaming). + /// + [Fact] + public async Task VerifyChatCompletionAgentStreamingThrowsWhenUseImmutableKernelFalseWithAIFunctionsAsync() + { + // Arrange + StreamingChatMessageContent[] returnContent = + [ + new(AuthorRole.Assistant, "wh"), + new(AuthorRole.Assistant, "at?"), + ]; + + Mock mockService = new(); + mockService.Setup( + s => s.GetStreamingChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())).Returns(returnContent.ToAsyncEnumerable()); + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = CreateKernel(mockService.Object), + Arguments = [], + UseImmutableKernel = false // Explicitly set to false + }; + + var thread = new ChatHistoryAgentThread(); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeStreamingAsync(Array.Empty() as ICollection, thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is default (false) and AIFunctions exist (streaming). + /// + [Fact] + public async Task VerifyChatCompletionAgentStreamingThrowsWhenUseImmutableKernelDefaultWithAIFunctionsAsync() + { + // Arrange + StreamingChatMessageContent[] returnContent = + [ + new(AuthorRole.Assistant, "wh"), + new(AuthorRole.Assistant, "at?"), + ]; + + Mock mockService = new(); + mockService.Setup( + s => s.GetStreamingChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())).Returns(returnContent.ToAsyncEnumerable()); + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = CreateKernel(mockService.Object), + Arguments = [] + // UseImmutableKernel not set, should default to false + }; + + var thread = new ChatHistoryAgentThread(); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeStreamingAsync(Array.Empty() as ICollection, thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that kernel remains immutable when UseImmutableKernel is true (streaming). + /// + [Fact] + public async Task VerifyChatCompletionAgentStreamingKernelImmutabilityWhenUseImmutableKernelTrueAsync() + { + // Arrange + StreamingChatMessageContent[] returnContent = + [ + new(AuthorRole.Assistant, "wh"), + new(AuthorRole.Assistant, "at?"), + ]; + + Mock mockService = new(); + Kernel capturedKernel = null!; + mockService.Setup( + s => s.GetStreamingChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Callback((_, _, kernel, _) => capturedKernel = kernel) + .Returns(returnContent.ToAsyncEnumerable()); + + var originalKernel = CreateKernel(mockService.Object); + var originalPluginCount = originalKernel.Plugins.Count; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = originalKernel, + Arguments = [], + UseImmutableKernel = true + }; + + var thread = new ChatHistoryAgentThread(); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act + AgentResponseItem[] result = await agent.InvokeStreamingAsync(Array.Empty() as ICollection, thread: thread).ToArrayAsync(); + + // Assert + Assert.Equal(2, result.Length); + + // Verify original kernel was not modified + Assert.Equal(originalPluginCount, originalKernel.Plugins.Count); + + // Verify a different kernel instance was used for the service call + Assert.NotSame(originalKernel, capturedKernel); + + // Verify the captured kernel has the additional plugin from AIContext + Assert.True(capturedKernel.Plugins.Count > originalPluginCount); + Assert.Contains(capturedKernel.Plugins, p => p.Name == "Tools"); + } + + /// + /// Verify that mutable kernel behavior works when UseImmutableKernel is false and no AIFunctions exist (streaming). + /// + [Fact] + public async Task VerifyChatCompletionAgentStreamingMutableKernelWhenUseImmutableKernelFalseNoAIFunctionsAsync() + { + // Arrange + StreamingChatMessageContent[] returnContent = + [ + new(AuthorRole.Assistant, "wh"), + new(AuthorRole.Assistant, "at?"), + ]; + + Mock mockService = new(); + Kernel capturedKernel = null!; + mockService.Setup( + s => s.GetStreamingChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Callback((_, _, kernel, _) => capturedKernel = kernel) + .Returns(returnContent.ToAsyncEnumerable()); + + var originalKernel = CreateKernel(mockService.Object); + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [] // Empty AIFunctions list + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = originalKernel, + Arguments = [], + UseImmutableKernel = false + }; + + var thread = new ChatHistoryAgentThread(); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act + AgentResponseItem[] result = await agent.InvokeStreamingAsync(Array.Empty() as ICollection, thread: thread).ToArrayAsync(); + + // Assert + Assert.Equal(2, result.Length); + + // Verify the same kernel instance was used (mutable behavior) + Assert.Same(originalKernel, capturedKernel); + } + private static Kernel CreateKernel(IChatCompletionService chatCompletionService) { var builder = Kernel.CreateBuilder(); @@ -386,4 +880,46 @@ private static Kernel CreateKernel(IChatClient chatClient) builder.Services.AddSingleton(chatClient); return builder.Build(); } + + /// + /// Gets the Kernel property from ChatOptions using reflection. + /// + /// The ChatOptions instance to extract Kernel from. + /// The Kernel instance if found; otherwise, null. + private static Kernel? GetKernelFromChatOptions(ChatOptions options) + { + // Use reflection to try to get the Kernel property + var kernelProperty = options.GetType().GetProperty("Kernel", + System.Reflection.BindingFlags.Public | + System.Reflection.BindingFlags.NonPublic | + System.Reflection.BindingFlags.Instance); + + if (kernelProperty != null) + { + return kernelProperty.GetValue(options) as Kernel; + } + + return null; + } + + /// + /// Helper class for testing AIFunction behavior. + /// + private sealed class TestAIFunction : AIFunction + { + public TestAIFunction(string name, string description = "") + { + this.Name = name; + this.Description = description; + } + + public override string Name { get; } + + public override string Description { get; } + + protected override ValueTask InvokeCoreAsync(AIFunctionArguments? arguments = null, CancellationToken cancellationToken = default) + { + return ValueTask.FromResult("Test result"); + } + } } diff --git a/dotnet/src/Agents/UnitTests/Extensions/ResponseItemExtensionsTests.cs b/dotnet/src/Agents/UnitTests/Extensions/ResponseItemExtensionsTests.cs index 5ca2bc65aae0..8683a19f379f 100644 --- a/dotnet/src/Agents/UnitTests/Extensions/ResponseItemExtensionsTests.cs +++ b/dotnet/src/Agents/UnitTests/Extensions/ResponseItemExtensionsTests.cs @@ -35,7 +35,8 @@ public void VerifyToChatMessageContentFromInputText(string creationMethod, strin // Act var messageContent = responseItem.ToChatMessageContent(); - // Assert + // Assert + Assert.NotNull(messageContent); Assert.Equal(new AuthorRole(roleLabel), messageContent.Role); Assert.Single(messageContent.Items); Assert.IsType(messageContent.Items[0]); @@ -53,6 +54,7 @@ public void VerifyToChatMessageContentFromInputImage() var messageContent = responseItem.ToChatMessageContent(); // Assert + Assert.NotNull(messageContent); Assert.Equal(AuthorRole.User, messageContent.Role); Assert.Single(messageContent.Items); Assert.IsType(messageContent.Items[0]); @@ -71,6 +73,7 @@ public void VerifyToChatMessageContentFromInputFile() var messageContent = responseItem.ToChatMessageContent(); // Assert + Assert.NotNull(messageContent); Assert.Equal(AuthorRole.User, messageContent.Role); Assert.Single(messageContent.Items); Assert.IsType(messageContent.Items[0]); @@ -88,9 +91,25 @@ public void VerifyToChatMessageContentFromRefusal() var messageContent = responseItem.ToChatMessageContent(); // Assert + Assert.NotNull(messageContent); Assert.Equal(AuthorRole.User, messageContent.Role); Assert.Single(messageContent.Items); Assert.IsType(messageContent.Items[0]); Assert.Equal("refusal", ((TextContent)messageContent.Items[0]).Text); } + + [Fact] + public void VerifyToChatMessageContentFromReasoning() + { + // Arrange + IEnumerable summaryParts = ["Foo"]; + ReasoningResponseItem responseItem = ResponseItem.CreateReasoningItem(summaryParts); + + // Act + var messageContent = responseItem.ToChatMessageContent(); + + // Assert + Assert.NotNull(messageContent); + Assert.Equal("Foo", messageContent.Content); + } } diff --git a/dotnet/src/Agents/UnitTests/OpenAI/Extensions/OpenAIResponseExtensionsTests.cs b/dotnet/src/Agents/UnitTests/OpenAI/Extensions/OpenAIResponseExtensionsTests.cs index bdcb2837349d..6fa230534427 100644 --- a/dotnet/src/Agents/UnitTests/OpenAI/Extensions/OpenAIResponseExtensionsTests.cs +++ b/dotnet/src/Agents/UnitTests/OpenAI/Extensions/OpenAIResponseExtensionsTests.cs @@ -50,8 +50,8 @@ public void VerifyToChatMessageContentWithResponseItem() ResponseItem functionCall = ResponseItem.CreateFunctionCallItem("callId", "functionName", new BinaryData("{}")); // Act - ChatMessageContent userMessageContent = userMessage.ToChatMessageContent(); - ChatMessageContent functionCallContent = functionCall.ToChatMessageContent(); + ChatMessageContent? userMessageContent = userMessage.ToChatMessageContent(); + ChatMessageContent? functionCallContent = functionCall.ToChatMessageContent(); // Assert Assert.NotNull(userMessageContent); diff --git a/dotnet/src/Agents/UnitTests/OpenAI/OpenAIAssistantAgentTests.cs b/dotnet/src/Agents/UnitTests/OpenAI/OpenAIAssistantAgentTests.cs index b2671b62b9ef..309b4d599dfc 100644 --- a/dotnet/src/Agents/UnitTests/OpenAI/OpenAIAssistantAgentTests.cs +++ b/dotnet/src/Agents/UnitTests/OpenAI/OpenAIAssistantAgentTests.cs @@ -5,11 +5,14 @@ using System.Linq; using System.Net; using System.Net.Http; +using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; using Microsoft.SemanticKernel.Agents.OpenAI; using Microsoft.SemanticKernel.ChatCompletion; +using Moq; using OpenAI.Assistants; using Xunit; @@ -50,7 +53,7 @@ public async Task VerifyOpenAIAssistantAgentGroupChatAsync() // Assert Assert.Single(messages); Assert.Single(messages[0].Items); - Assert.IsType(messages[0].Items[0]); + Assert.IsType(messages[0].Items[0]); // Arrange this.SetupResponse(HttpStatusCode.OK, OpenAIAssistantResponseContent.DeleteThread); @@ -87,7 +90,7 @@ public async Task VerifyOpenAIAssistantAgentInvokeWithThreadAsync() // Assert Assert.Single(messages); Assert.Single(messages[0].Message.Items); - Assert.IsType(messages[0].Message.Items[0]); + Assert.IsType(messages[0].Message.Items[0]); Assert.Equal("Hello, how can I help you?", messages[0].Message.Content); } @@ -122,7 +125,7 @@ public async Task VerifyOpenAIAssistantAgentInvokeMultipleMessagesWithThreadAsyn // Assert Assert.Single(messages); Assert.Single(messages[0].Message.Items); - Assert.IsType(messages[0].Message.Items[0]); + Assert.IsType(messages[0].Message.Items[0]); Assert.Equal("How can I help you?", messages[0].Message.Content); } @@ -194,7 +197,7 @@ public async Task VerifyOpenAIAssistantAgentChatTextMessageWithAnnotationAsync() // Assert Assert.Single(messages); Assert.Equal(2, messages[0].Items.Count); - Assert.NotNull(messages[0].Items.SingleOrDefault(c => c is TextContent)); + Assert.NotNull(messages[0].Items.SingleOrDefault(c => c is Microsoft.SemanticKernel.TextContent)); Assert.NotNull(messages[0].Items.SingleOrDefault(c => c is AnnotationContent)); } @@ -328,7 +331,387 @@ public async Task VerifyOpenAIAssistantAgentWithFunctionCallAsync() // Assert Assert.Single(messages); Assert.Single(messages[0].Items); - Assert.IsType(messages[0].Items[0]); + Assert.IsType(messages[0].Items[0]); + } + + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is false and AIFunctions exist. + /// + [Fact] + public async Task VerifyOpenAIAssistantAgentThrowsWhenUseImmutableKernelFalseWithAIFunctionsAsync() + { + // Arrange + OpenAIAssistantAgent agent = await this.CreateAgentAsync(); + agent.UseImmutableKernel = false; // Explicitly set to false + + // Initialize agent channel + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread, + OpenAIAssistantResponseContent.Run.CreateRun, + OpenAIAssistantResponseContent.Run.CompletedRun, + OpenAIAssistantResponseContent.Run.MessageSteps, + OpenAIAssistantResponseContent.GetTextMessage()); + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIAssistantAgentThread(agent.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hi"), thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is default (false) and AIFunctions exist. + /// + [Fact] + public async Task VerifyOpenAIAssistantAgentThrowsWhenUseImmutableKernelDefaultWithAIFunctionsAsync() + { + // Arrange + OpenAIAssistantAgent agent = await this.CreateAgentAsync(); + // UseImmutableKernel not set, should default to false + + // Initialize agent channel + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread, + OpenAIAssistantResponseContent.Run.CreateRun, + OpenAIAssistantResponseContent.Run.CompletedRun, + OpenAIAssistantResponseContent.Run.MessageSteps, + OpenAIAssistantResponseContent.GetTextMessage()); + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIAssistantAgentThread(agent.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hi"), thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that kernel remains immutable when UseImmutableKernel is true. + /// + [Fact] + public async Task VerifyOpenAIAssistantAgentKernelImmutabilityWhenUseImmutableKernelTrueAsync() + { + // Arrange + OpenAIAssistantAgent agent = await this.CreateAgentAsync(); + agent.UseImmutableKernel = true; + + var originalKernel = agent.Kernel; + var originalPluginCount = originalKernel.Plugins.Count; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIAssistantAgentThread(agent.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread, + // Create message response + OpenAIAssistantResponseContent.GetTextMessage("Hi"), + OpenAIAssistantResponseContent.Run.CreateRun, + OpenAIAssistantResponseContent.Run.CompletedRun, + OpenAIAssistantResponseContent.Run.MessageSteps, + OpenAIAssistantResponseContent.GetTextMessage("Hello, how can I help you?")); + + // Act + AgentResponseItem[] result = await agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hi"), thread: thread).ToArrayAsync(); + + // Assert + Assert.Single(result); + + // Verify original kernel was not modified + Assert.Equal(originalPluginCount, originalKernel.Plugins.Count); + + // The kernel should remain unchanged since UseImmutableKernel=true creates a clone + Assert.Same(originalKernel, agent.Kernel); + } + + /// + /// Verify that mutable kernel behavior works when UseImmutableKernel is false and no AIFunctions exist. + /// + [Fact] + public async Task VerifyOpenAIAssistantAgentMutableKernelWhenUseImmutableKernelFalseNoAIFunctionsAsync() + { + // Arrange + OpenAIAssistantAgent agent = await this.CreateAgentAsync(); + agent.UseImmutableKernel = false; + + var originalKernel = agent.Kernel; + var originalPluginCount = originalKernel.Plugins.Count; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [] // Empty AIFunctions list + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIAssistantAgentThread(agent.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread, + // Create message response + OpenAIAssistantResponseContent.GetTextMessage("Hi"), + OpenAIAssistantResponseContent.Run.CreateRun, + OpenAIAssistantResponseContent.Run.CompletedRun, + OpenAIAssistantResponseContent.Run.MessageSteps, + OpenAIAssistantResponseContent.GetTextMessage("Hello, how can I help you?")); + + // Act + AgentResponseItem[] result = await agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, "Hi"), thread: thread).ToArrayAsync(); + + // Assert + Assert.Single(result); + + // Verify the same kernel instance is still being used (mutable behavior) + Assert.Same(originalKernel, agent.Kernel); + } + + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is false and AIFunctions exist (streaming). + /// + [Fact] + public async Task VerifyOpenAIAssistantAgentStreamingThrowsWhenUseImmutableKernelFalseWithAIFunctionsAsync() + { + // Arrange + OpenAIAssistantAgent agent = await this.CreateAgentAsync(); + agent.UseImmutableKernel = false; // Explicitly set to false + + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread, + // Create message response + OpenAIAssistantResponseContent.GetTextMessage("Hi"), + OpenAIAssistantResponseContent.Streaming.Response( + [ + OpenAIAssistantResponseContent.Streaming.CreateRun("created"), + OpenAIAssistantResponseContent.Streaming.CreateRun("queued"), + OpenAIAssistantResponseContent.Streaming.CreateRun("in_progress"), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("Hello, "), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("how can I "), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("help you?"), + OpenAIAssistantResponseContent.Streaming.CreateRun("completed"), + OpenAIAssistantResponseContent.Streaming.Done + ]), + OpenAIAssistantResponseContent.GetTextMessage("Hello, how can I help you?")); + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIAssistantAgentThread(agent.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeStreamingAsync(new ChatMessageContent(AuthorRole.User, "Hi"), thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is default (false) and AIFunctions exist (streaming). + /// + [Fact] + public async Task VerifyOpenAIAssistantAgentStreamingThrowsWhenUseImmutableKernelDefaultWithAIFunctionsAsync() + { + // Arrange + OpenAIAssistantAgent agent = await this.CreateAgentAsync(); + // UseImmutableKernel not set, should default to false + + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread, + // Create message response + OpenAIAssistantResponseContent.GetTextMessage("Hi"), + OpenAIAssistantResponseContent.Streaming.Response( + [ + OpenAIAssistantResponseContent.Streaming.CreateRun("created"), + OpenAIAssistantResponseContent.Streaming.CreateRun("queued"), + OpenAIAssistantResponseContent.Streaming.CreateRun("in_progress"), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("Hello, "), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("how can I "), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("help you?"), + OpenAIAssistantResponseContent.Streaming.CreateRun("completed"), + OpenAIAssistantResponseContent.Streaming.Done + ]), + OpenAIAssistantResponseContent.GetTextMessage("Hello, how can I help you?")); + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIAssistantAgentThread(agent.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeStreamingAsync(new ChatMessageContent(AuthorRole.User, "Hi"), thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that kernel remains immutable when UseImmutableKernel is true (streaming). + /// + [Fact] + public async Task VerifyOpenAIAssistantAgentStreamingKernelImmutabilityWhenUseImmutableKernelTrueAsync() + { + // Arrange + OpenAIAssistantAgent agent = await this.CreateAgentAsync(); + agent.UseImmutableKernel = true; + + var originalKernel = agent.Kernel; + var originalPluginCount = originalKernel.Plugins.Count; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIAssistantAgentThread(agent.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread, + // Create message response + OpenAIAssistantResponseContent.GetTextMessage("Hi"), + OpenAIAssistantResponseContent.Streaming.Response( + [ + OpenAIAssistantResponseContent.Streaming.CreateRun("created"), + OpenAIAssistantResponseContent.Streaming.CreateRun("queued"), + OpenAIAssistantResponseContent.Streaming.CreateRun("in_progress"), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("Hello, "), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("how can I "), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("help you?"), + OpenAIAssistantResponseContent.Streaming.CreateRun("completed"), + OpenAIAssistantResponseContent.Streaming.Done + ]), + OpenAIAssistantResponseContent.GetTextMessage("Hello, how can I help you?")); + + // Act + AgentResponseItem[] result = await agent.InvokeStreamingAsync(new ChatMessageContent(AuthorRole.User, "Hi"), thread: thread).ToArrayAsync(); + + // Assert + Assert.True(result.Length > 0); + + // Verify original kernel was not modified + Assert.Equal(originalPluginCount, originalKernel.Plugins.Count); + + // The kernel should remain unchanged since UseImmutableKernel=true creates a clone + Assert.Same(originalKernel, agent.Kernel); + } + + /// + /// Verify that mutable kernel behavior works when UseImmutableKernel is false and no AIFunctions exist (streaming). + /// + [Fact] + public async Task VerifyOpenAIAssistantAgentStreamingMutableKernelWhenUseImmutableKernelFalseNoAIFunctionsAsync() + { + // Arrange + OpenAIAssistantAgent agent = await this.CreateAgentAsync(); + agent.UseImmutableKernel = false; + + var originalKernel = agent.Kernel; + var originalPluginCount = originalKernel.Plugins.Count; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [] // Empty AIFunctions list + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIAssistantAgentThread(agent.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + this.SetupResponses( + HttpStatusCode.OK, + OpenAIAssistantResponseContent.CreateThread, + // Create message response + OpenAIAssistantResponseContent.GetTextMessage("Hi"), + OpenAIAssistantResponseContent.Streaming.Response( + [ + OpenAIAssistantResponseContent.Streaming.CreateRun("created"), + OpenAIAssistantResponseContent.Streaming.CreateRun("queued"), + OpenAIAssistantResponseContent.Streaming.CreateRun("in_progress"), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("Hello, "), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("how can I "), + OpenAIAssistantResponseContent.Streaming.DeltaMessage("help you?"), + OpenAIAssistantResponseContent.Streaming.CreateRun("completed"), + OpenAIAssistantResponseContent.Streaming.Done + ]), + OpenAIAssistantResponseContent.GetTextMessage("Hello, how can I help you?")); + + // Act + AgentResponseItem[] result = await agent.InvokeStreamingAsync(new ChatMessageContent(AuthorRole.User, "Hi"), thread: thread).ToArrayAsync(); + + // Assert + Assert.True(result.Length > 0); + + // Verify the same kernel instance is still being used (mutable behavior) + Assert.Same(originalKernel, agent.Kernel); } /// @@ -469,6 +852,27 @@ private sealed class MyPlugin public void MyFunction(int index) { } } + + /// + /// Helper class for testing AIFunction behavior. + /// + private sealed class TestAIFunction : AIFunction + { + public TestAIFunction(string name, string description = "") + { + this.Name = name; + this.Description = description; + } + + public override string Name { get; } + + public override string Description { get; } + + protected override ValueTask InvokeCoreAsync(AIFunctionArguments? arguments = null, CancellationToken cancellationToken = default) + { + return ValueTask.FromResult("Test result"); + } + } } #pragma warning restore CS0419 // Ambiguous reference in cref attribute diff --git a/dotnet/src/Agents/UnitTests/OpenAI/OpenAIResponseAgentTests.cs b/dotnet/src/Agents/UnitTests/OpenAI/OpenAIResponseAgentTests.cs index cb94960ebb14..2eade8a8789c 100644 --- a/dotnet/src/Agents/UnitTests/OpenAI/OpenAIResponseAgentTests.cs +++ b/dotnet/src/Agents/UnitTests/OpenAI/OpenAIResponseAgentTests.cs @@ -1,14 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.ComponentModel; using System.Linq; using System.Net.Http; +using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; using Microsoft.SemanticKernel.Agents.OpenAI; using Microsoft.SemanticKernel.ChatCompletion; +using Moq; using Xunit; namespace SemanticKernel.Agents.UnitTests.OpenAI; @@ -141,6 +145,300 @@ public async Task VerifyInvokeWithFunctionCallingAsync(bool storeEnabled) Assert.Equal("The special soup is Clam Chowder, and it costs $9.99.", items[2].Message.Content); } + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is false and AIFunctions exist. + /// + [Fact] + public async Task VerifyOpenAIResponseAgentThrowsWhenUseImmutableKernelFalseWithAIFunctionsAsync() + { + // Arrange + var agent = new OpenAIResponseAgent(this.Client) + { + UseImmutableKernel = false, // Explicitly set to false + StoreEnabled = true, + }; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIResponseAgentThread(this.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeAsync("Hi", thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is default (false) and AIFunctions exist. + /// + [Fact] + public async Task VerifyOpenAIResponseAgentThrowsWhenUseImmutableKernelDefaultWithAIFunctionsAsync() + { + // Arrange + var agent = new OpenAIResponseAgent(this.Client) + { + StoreEnabled = true, + }; + // UseImmutableKernel not set, should default to false + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIResponseAgentThread(this.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeAsync("Hi", thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that kernel remains immutable when UseImmutableKernel is true. + /// + [Fact] + public async Task VerifyOpenAIResponseAgentKernelImmutabilityWhenUseImmutableKernelTrueAsync() + { + // Arrange + this.MessageHandlerStub.ResponsesToReturn.Add( + new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StringContent(InvokeResponse) } + ); + + var agent = new OpenAIResponseAgent(this.Client) + { + UseImmutableKernel = true, + StoreEnabled = true, + }; + + var originalKernel = agent.Kernel; + var originalPluginCount = originalKernel.Plugins.Count; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIResponseAgentThread(this.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act + var result = await agent.InvokeAsync("Hi", thread: thread).ToArrayAsync(); + + // Assert + Assert.Single(result); + + // Verify original kernel was not modified + Assert.Equal(originalPluginCount, originalKernel.Plugins.Count); + + // The kernel should remain unchanged since UseImmutableKernel=true creates a clone + Assert.Same(originalKernel, agent.Kernel); + } + + /// + /// Verify that mutable kernel behavior works when UseImmutableKernel is false and no AIFunctions exist. + /// + [Fact] + public async Task VerifyOpenAIResponseAgentMutableKernelWhenUseImmutableKernelFalseNoAIFunctionsAsync() + { + // Arrange + this.MessageHandlerStub.ResponsesToReturn.Add( + new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StringContent(InvokeResponse) } + ); + + var agent = new OpenAIResponseAgent(this.Client) + { + UseImmutableKernel = false, + StoreEnabled = true, + }; + + var originalKernel = agent.Kernel; + var originalPluginCount = originalKernel.Plugins.Count; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [] // Empty AIFunctions list + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIResponseAgentThread(this.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act + var result = await agent.InvokeAsync("Hi", thread: thread).ToArrayAsync(); + + // Assert + Assert.Single(result); + + // Verify the same kernel instance is still being used (mutable behavior) + Assert.Same(originalKernel, agent.Kernel); + } + + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is false and AIFunctions exist (streaming). + /// + [Fact] + public async Task VerifyOpenAIResponseAgentStreamingThrowsWhenUseImmutableKernelFalseWithAIFunctionsAsync() + { + // Arrange + var agent = new OpenAIResponseAgent(this.Client) + { + UseImmutableKernel = false, // Explicitly set to false + StoreEnabled = true, + }; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIResponseAgentThread(this.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeStreamingAsync("Hi", thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that InvalidOperationException is thrown when UseImmutableKernel is default (false) and AIFunctions exist (streaming). + /// + [Fact] + public async Task VerifyOpenAIResponseAgentStreamingThrowsWhenUseImmutableKernelDefaultWithAIFunctionsAsync() + { + // Arrange + var agent = new OpenAIResponseAgent(this.Client) + { + StoreEnabled = true, + }; + // UseImmutableKernel not set, should default to false + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIResponseAgentThread(this.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await agent.InvokeStreamingAsync("Hi", thread: thread).ToArrayAsync()); + + Assert.NotNull(exception); + } + + /// + /// Verify that kernel remains immutable when UseImmutableKernel is true (streaming). + /// + [Fact] + public async Task VerifyOpenAIResponseAgentStreamingKernelImmutabilityWhenUseImmutableKernelTrueAsync() + { + // Arrange + this.MessageHandlerStub.ResponsesToReturn.Add( + new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StringContent(InvokeStreamingResponse) } + ); + + var agent = new OpenAIResponseAgent(this.Client) + { + UseImmutableKernel = true, + StoreEnabled = true, + }; + + var originalKernel = agent.Kernel; + var originalPluginCount = originalKernel.Plugins.Count; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [new TestAIFunction("TestFunction", "Test function description")] + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIResponseAgentThread(this.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act + var result = await agent.InvokeStreamingAsync("Hi", thread: thread).ToArrayAsync(); + + // Assert + Assert.True(result.Length > 0); + + // Verify original kernel was not modified + Assert.Equal(originalPluginCount, originalKernel.Plugins.Count); + + // The kernel should remain unchanged since UseImmutableKernel=true creates a clone + Assert.Same(originalKernel, agent.Kernel); + } + + /// + /// Verify that mutable kernel behavior works when UseImmutableKernel is false and no AIFunctions exist (streaming). + /// + [Fact] + public async Task VerifyOpenAIResponseAgentStreamingMutableKernelWhenUseImmutableKernelFalseNoAIFunctionsAsync() + { + // Arrange + this.MessageHandlerStub.ResponsesToReturn.Add( + new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StringContent(InvokeStreamingResponse) } + ); + + var agent = new OpenAIResponseAgent(this.Client) + { + UseImmutableKernel = false, + StoreEnabled = true, + }; + + var originalKernel = agent.Kernel; + var originalPluginCount = originalKernel.Plugins.Count; + + var mockAIContextProvider = new Mock(); + var aiContext = new AIContext + { + AIFunctions = [] // Empty AIFunctions list + }; + mockAIContextProvider.Setup(p => p.ModelInvokingAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(aiContext); + + var thread = new OpenAIResponseAgentThread(this.Client); + thread.AIContextProviders.Add(mockAIContextProvider.Object); + + // Act + var result = await agent.InvokeStreamingAsync("Hi", thread: thread).ToArrayAsync(); + + // Assert + Assert.True(result.Length > 0); + + // Verify the same kernel instance is still being used (mutable behavior) + Assert.Same(originalKernel, agent.Kernel); + } + #region private private const string InvokeResponse = """ @@ -433,6 +731,27 @@ public void MyFunction3(string value, int[] indices) { } } + /// + /// Helper class for testing AIFunction behavior. + /// + private sealed class TestAIFunction : AIFunction + { + public TestAIFunction(string name, string description = "") + { + this.Name = name; + this.Description = description; + } + + public override string Name { get; } + + public override string Description { get; } + + protected override ValueTask InvokeCoreAsync(AIFunctionArguments? arguments = null, CancellationToken cancellationToken = default) + { + return ValueTask.FromResult("Test result"); + } + } + private sealed class MenuPlugin { [KernelFunction, Description("Provides a list of specials from the menu.")] diff --git a/dotnet/src/Agents/Yaml/AgentMetadataTypeConverter.cs b/dotnet/src/Agents/Yaml/AgentMetadataTypeConverter.cs index 2ba910a1c6fe..948632331ff2 100644 --- a/dotnet/src/Agents/Yaml/AgentMetadataTypeConverter.cs +++ b/dotnet/src/Agents/Yaml/AgentMetadataTypeConverter.cs @@ -24,7 +24,7 @@ public bool Accepts(Type type) } /// - public object? ReadYaml(IParser parser, Type type) + public object? ReadYaml(IParser parser, Type type, ObjectDeserializer rootDeserializer) { s_deserializer ??= new DeserializerBuilder() .WithNamingConvention(UnderscoredNamingConvention.Instance) @@ -55,7 +55,7 @@ public bool Accepts(Type type) } /// - public void WriteYaml(IEmitter emitter, object? value, Type type) + public void WriteYaml(IEmitter emitter, object? value, Type type, ObjectSerializer serializer) { throw new NotImplementedException(); } diff --git a/dotnet/src/Agents/Yaml/ModelConfigurationTypeConverter.cs b/dotnet/src/Agents/Yaml/ModelConfigurationTypeConverter.cs index fe3907e194ba..c812e860b96c 100644 --- a/dotnet/src/Agents/Yaml/ModelConfigurationTypeConverter.cs +++ b/dotnet/src/Agents/Yaml/ModelConfigurationTypeConverter.cs @@ -24,7 +24,7 @@ public bool Accepts(Type type) } /// - public object? ReadYaml(IParser parser, Type type) + public object? ReadYaml(IParser parser, Type type, ObjectDeserializer rootDeserializer) { s_deserializer ??= new DeserializerBuilder() .WithNamingConvention(UnderscoredNamingConvention.Instance) @@ -55,7 +55,7 @@ public bool Accepts(Type type) } /// - public void WriteYaml(IEmitter emitter, object? value, Type type) + public void WriteYaml(IEmitter emitter, object? value, Type type, ObjectSerializer serializer) { throw new NotImplementedException(); } diff --git a/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj index 250dd6b7b94f..1149eda36a71 100644 --- a/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj @@ -9,7 +9,7 @@ enable - $(NoWarn);CS1591;CA2007;VSTHRD111;SKEXP0001;SKEXP0070 + $(NoWarn);CS1591;CA2007;VSTHRD111;SKEXP0001 diff --git a/dotnet/src/Connectors/Connectors.Amazon/AssemblyInfo.cs b/dotnet/src/Connectors/Connectors.Amazon/AssemblyInfo.cs deleted file mode 100644 index fe66371dbc58..000000000000 --- a/dotnet/src/Connectors/Connectors.Amazon/AssemblyInfo.cs +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; - -// This assembly is currently experimental. -[assembly: Experimental("SKEXP0070")] diff --git a/dotnet/src/Connectors/Connectors.Amazon/Connectors.Amazon.csproj b/dotnet/src/Connectors/Connectors.Amazon/Connectors.Amazon.csproj index 1f758c47b790..6dda882145b0 100644 --- a/dotnet/src/Connectors/Connectors.Amazon/Connectors.Amazon.csproj +++ b/dotnet/src/Connectors/Connectors.Amazon/Connectors.Amazon.csproj @@ -6,7 +6,7 @@ $(AssemblyName) net8.0;netstandard2.0 alpha - $(NoWarn);SKEXP0001;SKEXP0070 + $(NoWarn);SKEXP0001 diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj index d7e1f65ec24f..765c56caf618 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj @@ -8,7 +8,7 @@ enable disable false - $(NoWarn);CA2007,CA1806,CS1591,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0070 + $(NoWarn);CA2007,CA1806,CS1591,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001 diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/AssemblyInfo.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/AssemblyInfo.cs deleted file mode 100644 index fe66371dbc58..000000000000 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/AssemblyInfo.cs +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; - -// This assembly is currently experimental. -[assembly: Experimental("SKEXP0070")] diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj b/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj index 817994449fc3..02aa5c6e070b 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj @@ -5,7 +5,7 @@ Microsoft.SemanticKernel.Connectors.AzureAIInference $(AssemblyName) net8.0;netstandard2.0 - $(NoWarn);NU5104;SKEXP0001,SKEXP0070 + $(NoWarn);NU5104;SKEXP0001 false beta diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs index ed453295da5b..ad1077504939 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs @@ -1362,12 +1362,12 @@ public async Task FunctionResultsCanBeProvidedToLLMAsManyResultsInOneChatMessage public async Task GetChatMessageContentShouldSendMutatedChatHistoryToLLM() { // Arrange - static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + static Task MutateChatHistory(AutoFunctionInvocationContext context, Func next) { // Remove the function call messages from the chat history to reduce token count. context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. - next(context); + return next(context); } var kernel = new Kernel(); @@ -1433,12 +1433,12 @@ static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + static Task MutateChatHistory(AutoFunctionInvocationContext context, Func next) { // Remove the function call messages from the chat history to reduce token count. context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. - next(context); + return next(context); } var kernel = new Kernel(); diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj index 4468b0001333..294069980ae7 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj @@ -8,7 +8,7 @@ enable disable false - $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0050,SKEXP0070 + $(NoWarn);CA2007,CA1806,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0050 diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs index 913878b8c59d..877b80debf67 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs @@ -528,14 +528,14 @@ public void LabelsFromPromptReturnsAsExpected() var prompt = "prompt-example"; var executionSettings = new GeminiPromptExecutionSettings { - Labels = "Key1:Value1" + Labels = new Dictionary { { "key1", "value1" }, { "key2", "value2" } } }; // Act var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings); // Assert - Assert.NotNull(request.Configuration); + Assert.NotNull(request.Labels); Assert.Equal(executionSettings.Labels, request.Labels); } @@ -569,7 +569,7 @@ public void LabelsFromChatHistoryReturnsAsExpected() chatHistory.AddUserMessage("user-message2"); var executionSettings = new GeminiPromptExecutionSettings { - Labels = "Key1:Value1" + Labels = new Dictionary { { "key1", "value1" }, { "key2", "value2" } } }; // Act diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/GeminiPromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/GeminiPromptExecutionSettingsTests.cs index 5ba6895da18b..f368ac054f88 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/GeminiPromptExecutionSettingsTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/GeminiPromptExecutionSettingsTests.cs @@ -29,7 +29,7 @@ public void ItCreatesGeminiExecutionSettingsWithCorrectDefaults() Assert.Null(executionSettings.AudioTimestamp); Assert.Null(executionSettings.ResponseMimeType); Assert.Null(executionSettings.ResponseSchema); - Assert.Equal(GeminiPromptExecutionSettings.DefaultTextMaxTokens, executionSettings.MaxTokens); + Assert.Null(executionSettings.MaxTokens); } [Fact] diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleAIGeminiChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleAIGeminiChatCompletionServiceTests.cs index 6291fa898b99..06e5361de75e 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleAIGeminiChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleAIGeminiChatCompletionServiceTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.IO; using System.Net.Http; using System.Text; @@ -69,15 +70,13 @@ public async Task RequestCachedContentWorksCorrectlyAsync(string? cachedContent) } } - [Theory] - [InlineData(null)] - [InlineData("key:value")] - [InlineData("")] - public async Task RequestLabelsWorksCorrectlyAsync(string? labels) + [Fact] + public async Task RequestLabelsWorksCorrectlyAsync() { // Arrange string model = "fake-model"; var sut = new GoogleAIGeminiChatCompletionService(model, "key", httpClient: this._httpClient); + var labels = new Dictionary { { "key1", "value1" }, { "key2", "value2" } }; // Act var result = await sut.GetChatMessageContentAsync("my prompt", new GeminiPromptExecutionSettings { Labels = labels }); @@ -87,19 +86,28 @@ public async Task RequestLabelsWorksCorrectlyAsync(string? labels) Assert.NotNull(this._messageHandlerStub.RequestContent); var requestBody = UTF8Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent); - if (labels is not null) - { - Assert.Contains($"\"labels\":\"{labels}\"", requestBody); - } - else - { - // Then no quality is provided, it should not be included in the request body - Assert.DoesNotContain("labels", requestBody); - } + Assert.Contains("\"labels\":{\"key1\":\"value1\",\"key2\":\"value2\"}", requestBody); + } + + [Fact] + public async Task RequestLabelsNullWorksCorrectlyAsync() + { + // Arrange + string model = "fake-model"; + var sut = new GoogleAIGeminiChatCompletionService(model, "key", httpClient: this._httpClient); + + // Act + var result = await sut.GetChatMessageContentAsync("my prompt", new GeminiPromptExecutionSettings { Labels = null }); + + // Assert + Assert.NotNull(result); + Assert.NotNull(this._messageHandlerStub.RequestContent); + + var requestBody = UTF8Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent); + Assert.DoesNotContain("labels", requestBody); } [Theory] - [InlineData(null, false)] [InlineData(0, true)] [InlineData(500, true)] [InlineData(2048, true)] diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/VertexAIGeminiChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/VertexAIGeminiChatCompletionServiceTests.cs index 9140924bc011..179705981186 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/VertexAIGeminiChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/VertexAIGeminiChatCompletionServiceTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.IO; using System.Net.Http; using System.Text; @@ -80,15 +81,13 @@ public async Task RequestCachedContentWorksCorrectlyAsync(string? cachedContent) } } - [Theory] - [InlineData(null)] - [InlineData("key:value")] - [InlineData("")] - public async Task RequestLabelsWorksCorrectlyAsync(string? labels) + [Fact] + public async Task RequestLabelsWorksCorrectlyAsync() { // Arrange string model = "fake-model"; - var sut = new GoogleAIGeminiChatCompletionService(model, "key", httpClient: this._httpClient); + var sut = new VertexAIGeminiChatCompletionService(model, () => new ValueTask("key"), "location", "project", httpClient: this._httpClient); + var labels = new Dictionary { { "key1", "value1" }, { "key2", "value2" } }; // Act var result = await sut.GetChatMessageContentAsync("my prompt", new GeminiPromptExecutionSettings { Labels = labels }); @@ -98,15 +97,25 @@ public async Task RequestLabelsWorksCorrectlyAsync(string? labels) Assert.NotNull(this._messageHandlerStub.RequestContent); var requestBody = UTF8Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent); - if (labels is not null) - { - Assert.Contains($"\"labels\":\"{labels}\"", requestBody); - } - else - { - // Then no quality is provided, it should not be included in the request body - Assert.DoesNotContain("labels", requestBody); - } + Assert.Contains("\"labels\":{\"key1\":\"value1\",\"key2\":\"value2\"}", requestBody); + } + + [Fact] + public async Task RequestLabelsNullWorksCorrectlyAsync() + { + // Arrange + string model = "fake-model"; + var sut = new VertexAIGeminiChatCompletionService(model, () => new ValueTask("key"), "location", "project", httpClient: this._httpClient); + + // Act + var result = await sut.GetChatMessageContentAsync("my prompt", new GeminiPromptExecutionSettings { Labels = null }); + + // Assert + Assert.NotNull(result); + Assert.NotNull(this._messageHandlerStub.RequestContent); + + var requestBody = UTF8Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent); + Assert.DoesNotContain("labels", requestBody); } [Theory] diff --git a/dotnet/src/Connectors/Connectors.Google/AssemblyInfo.cs b/dotnet/src/Connectors/Connectors.Google/AssemblyInfo.cs deleted file mode 100644 index fe66371dbc58..000000000000 --- a/dotnet/src/Connectors/Connectors.Google/AssemblyInfo.cs +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; - -// This assembly is currently experimental. -[assembly: Experimental("SKEXP0070")] diff --git a/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj b/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj index 4d5a3deb9906..3fa32d4d83bf 100644 --- a/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj +++ b/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj @@ -6,7 +6,7 @@ $(AssemblyName) net8.0;netstandard2.0 alpha - $(NoWarn);SKEXP0001,SKEXP0070 + $(NoWarn);SKEXP0001 diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs index 1d69922b94fb..5d4b917ee1e7 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs @@ -48,7 +48,7 @@ internal sealed class GeminiRequest [JsonPropertyName("labels")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] - public string? Labels { get; set; } + public IDictionary? Labels { get; set; } public void AddFunction(GeminiFunction function) { @@ -450,7 +450,12 @@ private static void AddSafetySettings(GeminiPromptExecutionSettings executionSet private static void AddAdditionalBodyFields(GeminiPromptExecutionSettings executionSettings, GeminiRequest request) { request.CachedContent = executionSettings.CachedContent; - request.Labels = executionSettings.Labels; + + if (executionSettings.Labels is not null) + { + request.Labels = executionSettings.Labels; + } + if (executionSettings.ThinkingConfig is not null) { request.Configuration ??= new ConfigurationElement(); diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Utils/GeminiKernelFunctionMetadataExtensions.cs b/dotnet/src/Connectors/Connectors.Google/Extensions/GeminiKernelFunctionMetadataExtensions.cs similarity index 95% rename from dotnet/src/Connectors/Connectors.Google.UnitTests/Utils/GeminiKernelFunctionMetadataExtensions.cs rename to dotnet/src/Connectors/Connectors.Google/Extensions/GeminiKernelFunctionMetadataExtensions.cs index a716c48a2074..58cbdf0d28c0 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Utils/GeminiKernelFunctionMetadataExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Google/Extensions/GeminiKernelFunctionMetadataExtensions.cs @@ -1,10 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; -using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.Google; -namespace SemanticKernel.Connectors.Google.UnitTests; +namespace Microsoft.SemanticKernel; /// /// Extensions for specific to the Gemini connector. diff --git a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs index c4d4514feb5f..5f8bc0874cc2 100644 --- a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs @@ -28,16 +28,11 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings private string? _responseMimeType; private object? _responseSchema; private string? _cachedContent; - private string? _labels; + private IDictionary? _labels; private IList? _safetySettings; private GeminiToolCallBehavior? _toolCallBehavior; private GeminiThinkingConfig? _thinkingConfig; - /// - /// Default max tokens for a text generation. - /// - public static int DefaultTextMaxTokens { get; } = 256; - /// /// Temperature controls the randomness of the completion. /// The higher the temperature, the more random the completion. @@ -152,9 +147,12 @@ public IList? SafetySettings /// Gets or sets the labels. /// /// - /// Metadata that can be added to the API call in the format of key-value pairs. + /// The labels with user-defined metadata for the request. It is used for billing and reporting only. + /// label keys and values can be no longer than 63 characters (Unicode codepoints) and can only contain lowercase letters, numeric characters, underscores, and dashes. International characters are allowed. label values are optional. label keys must start with a letter. /// - public string? Labels + [JsonPropertyName("labels")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public IDictionary? Labels { get => this._labels; set @@ -358,7 +356,7 @@ public static GeminiPromptExecutionSettings FromExecutionSettings(PromptExecutio switch (executionSettings) { case null: - return new GeminiPromptExecutionSettings() { MaxTokens = DefaultTextMaxTokens }; + return new GeminiPromptExecutionSettings(); case GeminiPromptExecutionSettings settings: return settings; } diff --git a/dotnet/src/Connectors/Connectors.Google/GeminiToolCallBehavior.cs b/dotnet/src/Connectors/Connectors.Google/GeminiToolCallBehavior.cs index da25a11f7969..4597a18e1bd7 100644 --- a/dotnet/src/Connectors/Connectors.Google/GeminiToolCallBehavior.cs +++ b/dotnet/src/Connectors/Connectors.Google/GeminiToolCallBehavior.cs @@ -127,50 +127,11 @@ internal override void ConfigureGeminiRequest(Kernel? kernel, GeminiRequest requ // Provide all functions from the kernel. foreach (var functionMetadata in kernel.Plugins.GetFunctionsMetadata()) { - request.AddFunction(FunctionMetadataAsGeminiFunction(functionMetadata)); + request.AddFunction(functionMetadata.ToGeminiFunction()); } } internal override bool AllowAnyRequestedKernelFunction => true; - - /// - /// Convert a to an . - /// - /// The object to convert. - /// An object. - private static GeminiFunction FunctionMetadataAsGeminiFunction(KernelFunctionMetadata metadata) - { - IReadOnlyList metadataParams = metadata.Parameters; - - var openAIParams = new GeminiFunctionParameter[metadataParams.Count]; - for (int i = 0; i < openAIParams.Length; i++) - { - var param = metadataParams[i]; - - openAIParams[i] = new GeminiFunctionParameter( - param.Name, - GetDescription(param), - param.IsRequired, - param.ParameterType, - param.Schema); - } - - return new GeminiFunction( - metadata.PluginName, - metadata.Name, - metadata.Description, - openAIParams, - new GeminiFunctionReturnParameter( - metadata.ReturnParameter.Description, - metadata.ReturnParameter.ParameterType, - metadata.ReturnParameter.Schema)); - - static string GetDescription(KernelParameterMetadata param) - { - string? stringValue = InternalTypeConverter.ConvertToString(param.DefaultValue); - return !string.IsNullOrEmpty(stringValue) ? $"{param.Description} (default value: {stringValue})" : param.Description; - } - } } /// diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/AssemblyInfo.cs b/dotnet/src/Connectors/Connectors.HuggingFace/AssemblyInfo.cs deleted file mode 100644 index fe66371dbc58..000000000000 --- a/dotnet/src/Connectors/Connectors.HuggingFace/AssemblyInfo.cs +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; - -// This assembly is currently experimental. -[assembly: Experimental("SKEXP0070")] diff --git a/dotnet/src/Connectors/Connectors.HuggingFace/Connectors.HuggingFace.csproj b/dotnet/src/Connectors/Connectors.HuggingFace/Connectors.HuggingFace.csproj index 6cc98cd71c16..08f033474349 100644 --- a/dotnet/src/Connectors/Connectors.HuggingFace/Connectors.HuggingFace.csproj +++ b/dotnet/src/Connectors/Connectors.HuggingFace/Connectors.HuggingFace.csproj @@ -5,6 +5,7 @@ Microsoft.SemanticKernel.Connectors.HuggingFace $(AssemblyName) net8.0;netstandard2.0 + $(NoWarn);SKEXP0001 preview diff --git a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Connectors.MistralAI.UnitTests.csproj b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Connectors.MistralAI.UnitTests.csproj index cf453ddf2c45..b1d36c7d035c 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Connectors.MistralAI.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Connectors.MistralAI.UnitTests.csproj @@ -10,15 +10,9 @@ enable disable false - SKEXP0001,SKEXP0070 + SKEXP0001 - - - - - - diff --git a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Services/MistralAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Services/MistralAIChatCompletionServiceTests.cs index f5f8439a1291..76dc8ac1dcd3 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Services/MistralAIChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Services/MistralAIChatCompletionServiceTests.cs @@ -77,17 +77,17 @@ public async Task ValidateGetStreamingChatMessageContentsAsync() public async Task GetChatMessageContentShouldSendMutatedChatHistoryToLLMAsync() { // Arrange - static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + static Task MutateChatHistoryAsync(AutoFunctionInvocationContext context, Func next) { // Remove the function call messages from the chat history to reduce token count. context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. - next(context); + return next(context); } var kernel = new Kernel(); kernel.ImportPluginFromFunctions("WeatherPlugin", [KernelFunctionFactory.CreateFromMethod((string location) => "rainy", "GetWeather")]); - kernel.AutoFunctionInvocationFilters.Add(new AutoFunctionInvocationFilter(MutateChatHistory)); + kernel.AutoFunctionInvocationFilters.Add(new AutoFunctionInvocationFilter(MutateChatHistoryAsync)); var firstResponse = this.GetTestResponseAsBytes("chat_completions_function_call_response.json"); var secondResponse = this.GetTestResponseAsBytes("chat_completions_function_called_response.json"); @@ -149,12 +149,12 @@ static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + static Task MutateChatHistory(AutoFunctionInvocationContext context, Func next) { // Remove the function call messages from the chat history to reduce token count. context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. - next(context); + return next(context); } var kernel = new Kernel(); diff --git a/dotnet/src/Connectors/Connectors.MistralAI/AssemblyInfo.cs b/dotnet/src/Connectors/Connectors.MistralAI/AssemblyInfo.cs deleted file mode 100644 index fe66371dbc58..000000000000 --- a/dotnet/src/Connectors/Connectors.MistralAI/AssemblyInfo.cs +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; - -// This assembly is currently experimental. -[assembly: Experimental("SKEXP0070")] diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs index c5e53e67f85d..e0c3c729bf3f 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs @@ -30,7 +30,8 @@ internal sealed class ChatCompletionRequest public bool Stream { get; set; } = false; [JsonPropertyName("safe_prompt")] - public bool SafePrompt { get; set; } = false; + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public bool? SafePrompt { get; set; } = false; [JsonPropertyName("tools")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Connectors.MistralAI.csproj b/dotnet/src/Connectors/Connectors.MistralAI/Connectors.MistralAI.csproj index 8edcf0ed416e..617bada68806 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Connectors.MistralAI.csproj +++ b/dotnet/src/Connectors/Connectors.MistralAI/Connectors.MistralAI.csproj @@ -6,7 +6,7 @@ $(AssemblyName) net8.0;netstandard2.0 alpha - SKEXP0001,SKEXP0070 + SKEXP0001 diff --git a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs index eb2b6760a11a..424257b88bc9 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs @@ -78,7 +78,7 @@ public int? MaxTokens /// [JsonPropertyName("safe_prompt")] [JsonConverter(typeof(BoolJsonConverter))] - public bool SafePrompt + public bool? SafePrompt { get => this._safePrompt; @@ -338,7 +338,7 @@ public static MistralAIPromptExecutionSettings FromExecutionSettings(PromptExecu private double _temperature = 0.7; private double _topP = 1; private int? _maxTokens; - private bool _safePrompt = false; + private bool? _safePrompt = false; private int? _randomSeed; private string _apiVersion = "v1"; private MistralAIToolCallBehavior? _toolCallBehavior; diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj index e3384e56aebb..7bff3b4e7ccf 100644 --- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj @@ -10,7 +10,7 @@ enable disable false - CA2007,CA1861,VSTHRD111,CS1591,SKEXP0001,SKEXP0070 + CA2007,CA1861,VSTHRD111,CS1591,SKEXP0001 diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs index a1b2e4f64d1d..c571868fd59f 100644 --- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs @@ -700,12 +700,12 @@ public async Task GetChatMessageContentsWithFunctionCallMaximumAutoInvokeAttempt public async Task GetChatMessageContentShouldSendMutatedChatHistoryToLLMAsync() { // Arrange - static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + static Task MutateChatHistory(AutoFunctionInvocationContext context, Func next) { // Remove the function call messages from the chat history to reduce token count. context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. - next(context); + return next(context); } var kernel = new Kernel(); diff --git a/dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs b/dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs deleted file mode 100644 index fe66371dbc58..000000000000 --- a/dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; - -// This assembly is currently experimental. -[assembly: Experimental("SKEXP0070")] diff --git a/dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj b/dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj index 606ee24a0f0d..9d9f4d030c0a 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj +++ b/dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj @@ -5,6 +5,7 @@ Microsoft.SemanticKernel.Connectors.Ollama $(AssemblyName) net8;netstandard2.0 + $(NoWarn);SKEXP0001 alpha diff --git a/dotnet/src/Connectors/Connectors.Onnx.UnitTests/Connectors.Onnx.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Onnx.UnitTests/Connectors.Onnx.UnitTests.csproj index 8762a79479f0..fd60589cc4e5 100644 --- a/dotnet/src/Connectors/Connectors.Onnx.UnitTests/Connectors.Onnx.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Onnx.UnitTests/Connectors.Onnx.UnitTests.csproj @@ -7,7 +7,7 @@ true enable false - $(NoWarn);SKEXP0001;SKEXP0070;CS1591;IDE1006;RCS1261;CA1031;CA1308;CA1861;CA2007;CA2234;VSTHRD111;SYSLIB1222 + $(NoWarn);SKEXP0001;CS1591;IDE1006;RCS1261;CA1031;CA1308;CA1861;CA2007;CA2234;VSTHRD111;SYSLIB1222 diff --git a/dotnet/src/Connectors/Connectors.Onnx.UnitTests/OnnxChatClientExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Onnx.UnitTests/OnnxChatClientExtensionsTests.cs new file mode 100644 index 000000000000..238ee839c324 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Onnx.UnitTests/OnnxChatClientExtensionsTests.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Xunit; + +namespace SemanticKernel.Connectors.Onnx.UnitTests; + +/// +/// Unit tests for and Onnx IChatClient service collection extensions. +/// +public class OnnxChatClientExtensionsTests +{ + [Fact] + public void AddOnnxRuntimeGenAIChatClientToServiceCollection() + { + // Arrange + var collection = new ServiceCollection(); + + // Act + collection.AddOnnxRuntimeGenAIChatClient("modelId"); + + // Assert + var serviceDescriptor = collection.FirstOrDefault(x => x.ServiceType == typeof(IChatClient)); + Assert.NotNull(serviceDescriptor); + Assert.Equal(ServiceLifetime.Singleton, serviceDescriptor.Lifetime); + } + + [Fact] + public void AddOnnxRuntimeGenAIChatClientToKernelBuilder() + { + // Arrange + var collection = new ServiceCollection(); + var kernelBuilder = collection.AddKernel(); + + // Act + kernelBuilder.AddOnnxRuntimeGenAIChatClient("modelPath"); + + // Assert + var serviceDescriptor = collection.FirstOrDefault(x => x.ServiceType == typeof(IChatClient)); + Assert.NotNull(serviceDescriptor); + Assert.Equal(ServiceLifetime.Singleton, serviceDescriptor.Lifetime); + } + + [Fact] + public void AddOnnxRuntimeGenAIChatClientWithServiceId() + { + // Arrange + var collection = new ServiceCollection(); + + // Act + collection.AddOnnxRuntimeGenAIChatClient("modelPath", serviceId: "test-service"); + + // Assert + var serviceDescriptor = collection.FirstOrDefault(x => x.ServiceType == typeof(IChatClient) && x.ServiceKey?.ToString() == "test-service"); + Assert.NotNull(serviceDescriptor); + Assert.Equal(ServiceLifetime.Singleton, serviceDescriptor.Lifetime); + } + + [Fact] + public void AddOnnxRuntimeGenAIChatClientToKernelBuilderWithServiceId() + { + // Arrange + var collection = new ServiceCollection(); + var kernelBuilder = collection.AddKernel(); + + // Act + kernelBuilder.AddOnnxRuntimeGenAIChatClient("modelPath", serviceId: "test-service"); + + // Assert + var serviceDescriptor = collection.FirstOrDefault(x => x.ServiceType == typeof(IChatClient) && x.ServiceKey?.ToString() == "test-service"); + Assert.NotNull(serviceDescriptor); + Assert.Equal(ServiceLifetime.Singleton, serviceDescriptor.Lifetime); + } +} diff --git a/dotnet/src/Connectors/Connectors.Onnx/AssemblyInfo.cs b/dotnet/src/Connectors/Connectors.Onnx/AssemblyInfo.cs deleted file mode 100644 index fe66371dbc58..000000000000 --- a/dotnet/src/Connectors/Connectors.Onnx/AssemblyInfo.cs +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; - -// This assembly is currently experimental. -[assembly: Experimental("SKEXP0070")] diff --git a/dotnet/src/Connectors/Connectors.Onnx/Connectors.Onnx.csproj b/dotnet/src/Connectors/Connectors.Onnx/Connectors.Onnx.csproj index 7abb899f0e66..deb2c228fbe9 100644 --- a/dotnet/src/Connectors/Connectors.Onnx/Connectors.Onnx.csproj +++ b/dotnet/src/Connectors/Connectors.Onnx/Connectors.Onnx.csproj @@ -7,7 +7,7 @@ net8.0;netstandard2.0 true alpha - SYSLIB1222 + $(NoWarn);SKEXP0001;SYSLIB1222 diff --git a/dotnet/src/Connectors/Connectors.Onnx/OnnxKernelBuilderExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.Onnx/OnnxKernelBuilderExtensions.ChatClient.cs new file mode 100644 index 000000000000..7e9329d94903 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Onnx/OnnxKernelBuilderExtensions.ChatClient.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.ML.OnnxRuntimeGenAI; + +namespace Microsoft.SemanticKernel; + +/// Extension methods for . +public static class OnnxChatClientKernelBuilderExtensions +{ + #region Chat Client + + /// + /// Adds an OnnxRuntimeGenAI to the . + /// + /// The instance to augment. + /// The generative AI ONNX model path. + /// The optional options for the chat client. + /// A local identifier for the given AI service + /// The same instance as . + public static IKernelBuilder AddOnnxRuntimeGenAIChatClient( + this IKernelBuilder builder, + string modelPath, + OnnxRuntimeGenAIChatClientOptions? chatClientOptions = null, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddOnnxRuntimeGenAIChatClient( + modelPath, + chatClientOptions, + serviceId); + + return builder; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Onnx/OnnxServiceCollectionExtensions.DependencyInjection.cs b/dotnet/src/Connectors/Connectors.Onnx/OnnxServiceCollectionExtensions.DependencyInjection.cs index 0ea95328d89c..a8dda516b338 100644 --- a/dotnet/src/Connectors/Connectors.Onnx/OnnxServiceCollectionExtensions.DependencyInjection.cs +++ b/dotnet/src/Connectors/Connectors.Onnx/OnnxServiceCollectionExtensions.DependencyInjection.cs @@ -1,7 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.IO; +using System.Text; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.ML.OnnxRuntimeGenAI; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.Onnx; using Microsoft.SemanticKernel.Embeddings; @@ -57,4 +61,56 @@ public static IServiceCollection AddBertOnnxEmbeddingGenerator( serviceId, BertOnnxTextEmbeddingGenerationService.Create(onnxModelStream, vocabStream, options).AsEmbeddingGenerator()); } + + /// + /// Add OnnxRuntimeGenAI Chat Client to the service collection. + /// + /// The service collection. + /// The generative AI ONNX model path. + /// The options for the chat client. + /// The optional service ID. + /// The updated service collection. + public static IServiceCollection AddOnnxRuntimeGenAIChatClient( + this IServiceCollection services, + string modelPath, + OnnxRuntimeGenAIChatClientOptions? chatClientOptions = null, + string? serviceId = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelPath); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + var chatClient = new OnnxRuntimeGenAIChatClient(modelPath, chatClientOptions ?? new OnnxRuntimeGenAIChatClientOptions() + { + PromptFormatter = static (messages, _) => + { + StringBuilder promptBuilder = new(); + foreach (var message in messages) + { + promptBuilder.Append($"<|{message.Role}|>\n{message.Text}"); + } + promptBuilder.Append("<|end|>\n<|assistant|>"); + + return promptBuilder.ToString(); + } + }); + + var builder = chatClient.AsBuilder() + .UseKernelFunctionInvocation(loggerFactory); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } } diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj index cfe78dcc804f..0a7171bbcd0d 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj @@ -7,7 +7,7 @@ true enable false - $(NoWarn);SKEXP0001;SKEXP0070;SKEXP0010;CS1591;IDE1006;RCS1261;CA1031;CA1308;CA1861;CA2007;CA2234;VSTHRD111;CA1812 + $(NoWarn);SKEXP0001;SKEXP0010;CS1591;IDE1006;RCS1261;CA1031;CA1308;CA1861;CA2007;CA2234;VSTHRD111;CA1812 diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/ChatHistoryExtensionsTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/ChatHistoryExtensionsTests.cs index c0c3f0abda4b..961d3724afd8 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/ChatHistoryExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/ChatHistoryExtensionsTests.cs @@ -12,6 +12,7 @@ using Xunit; namespace SemanticKernel.Connectors.OpenAI.UnitTests.Extensions; + public class ChatHistoryExtensionsTests { [Fact] diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs index 695137812a15..b012a0210e6c 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs @@ -1107,12 +1107,12 @@ public async Task GetInvalidResponseThrowsExceptionAndIsCapturedByDiagnosticsAsy public async Task GetChatMessageContentShouldSendMutatedChatHistoryToLLM() { // Arrange - static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + static Task MutateChatHistory(AutoFunctionInvocationContext context, Func next) { // Remove the function call messages from the chat history to reduce token count. context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. - next(context); + return next(context); } var kernel = new Kernel(); @@ -1178,12 +1178,12 @@ static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + static Task MutateChatHistory(AutoFunctionInvocationContext context, Func next) { // Remove the function call messages from the chat history to reduce token count. context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. - next(context); + return next(context); } var kernel = new Kernel(); diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Settings/OpenAIPromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Settings/OpenAIPromptExecutionSettingsTests.cs index 9423ad6fee14..96325d3b7ac7 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Settings/OpenAIPromptExecutionSettingsTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Settings/OpenAIPromptExecutionSettingsTests.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Text.Json; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.OpenAI; using Xunit; @@ -470,6 +471,365 @@ public void ItCannotCreateOpenAIPromptExecutionSettingsWithInvalidBoolValues(obj Assert.Throws(() => OpenAIPromptExecutionSettings.FromExecutionSettings(originalSettings)); } + [Fact] + public void PrepareChatHistoryToRequestAsyncAddsSystemPromptWhenNotPresent() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatSystemPrompt = "You are a helpful assistant." + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Hello"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); // Should return the same instance + Assert.Equal(2, chatHistory.Count); + Assert.Equal(AuthorRole.System, chatHistory[0].Role); + Assert.Equal("You are a helpful assistant.", chatHistory[0].Content); + Assert.Equal(AuthorRole.User, chatHistory[1].Role); + Assert.Equal("Hello", chatHistory[1].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncAddsSystemPromptAtBeginning() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatSystemPrompt = "You are a helpful assistant." + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("First message"); + chatHistory.AddAssistantMessage("First response"); + chatHistory.AddUserMessage("Second message"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Equal(4, chatHistory.Count); + Assert.Equal(AuthorRole.System, chatHistory[0].Role); + Assert.Equal("You are a helpful assistant.", chatHistory[0].Content); + Assert.Equal(AuthorRole.User, chatHistory[1].Role); + Assert.Equal("First message", chatHistory[1].Content); + Assert.Equal(AuthorRole.Assistant, chatHistory[2].Role); + Assert.Equal("First response", chatHistory[2].Content); + Assert.Equal(AuthorRole.User, chatHistory[3].Role); + Assert.Equal("Second message", chatHistory[3].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncDoesNotAddSystemPromptWhenAlreadyPresent() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatSystemPrompt = "You are a helpful assistant." + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddSystemMessage("Existing system message"); + chatHistory.AddUserMessage("Hello"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Equal(2, chatHistory.Count); + Assert.Equal(AuthorRole.System, chatHistory[0].Role); + Assert.Equal("Existing system message", chatHistory[0].Content); // Original system message preserved + Assert.Equal(AuthorRole.User, chatHistory[1].Role); + Assert.Equal("Hello", chatHistory[1].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncAddsDeveloperPromptWhenNotPresent() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatDeveloperPrompt = "Debug mode enabled." + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Hello"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Equal(2, chatHistory.Count); + Assert.Equal(AuthorRole.Developer, chatHistory[0].Role); + Assert.Equal("Debug mode enabled.", chatHistory[0].Content); + Assert.Equal(AuthorRole.User, chatHistory[1].Role); + Assert.Equal("Hello", chatHistory[1].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncDoesNotAddDeveloperPromptWhenAlreadyPresent() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatDeveloperPrompt = "Debug mode enabled." + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddDeveloperMessage("Existing developer message"); + chatHistory.AddUserMessage("Hello"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Equal(2, chatHistory.Count); + Assert.Equal(AuthorRole.Developer, chatHistory[0].Role); + Assert.Equal("Existing developer message", chatHistory[0].Content); // Original developer message preserved + Assert.Equal(AuthorRole.User, chatHistory[1].Role); + Assert.Equal("Hello", chatHistory[1].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncAddsBothSystemAndDeveloperPrompts() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatSystemPrompt = "You are a helpful assistant.", + ChatDeveloperPrompt = "Debug mode enabled." + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Hello"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Equal(3, chatHistory.Count); + Assert.Equal(AuthorRole.System, chatHistory[0].Role); + Assert.Equal("You are a helpful assistant.", chatHistory[0].Content); + Assert.Equal(AuthorRole.Developer, chatHistory[1].Role); + Assert.Equal("Debug mode enabled.", chatHistory[1].Content); + Assert.Equal(AuthorRole.User, chatHistory[2].Role); + Assert.Equal("Hello", chatHistory[2].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncDoesNotAddEmptyOrWhitespacePrompts() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatSystemPrompt = " ", // Whitespace only + ChatDeveloperPrompt = "" // Empty string + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Hello"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Single(chatHistory); // Only the original user message should remain + Assert.Equal(AuthorRole.User, chatHistory[0].Role); + Assert.Equal("Hello", chatHistory[0].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncDoesNotAddNullPrompts() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatSystemPrompt = null, + ChatDeveloperPrompt = null + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Hello"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Single(chatHistory); // Only the original user message should remain + Assert.Equal(AuthorRole.User, chatHistory[0].Role); + Assert.Equal("Hello", chatHistory[0].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncWorksWithEmptyChatHistory() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatSystemPrompt = "You are a helpful assistant.", + ChatDeveloperPrompt = "Debug mode enabled." + }; + + var chatHistory = new ChatHistory(); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Equal(2, chatHistory.Count); + Assert.Equal(AuthorRole.System, chatHistory[0].Role); + Assert.Equal("You are a helpful assistant.", chatHistory[0].Content); + Assert.Equal(AuthorRole.Developer, chatHistory[1].Role); + Assert.Equal("Debug mode enabled.", chatHistory[1].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncPreservesExistingMessageOrder() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatSystemPrompt = "You are a helpful assistant." + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddDeveloperMessage("Existing developer message"); + chatHistory.AddUserMessage("First user message"); + chatHistory.AddAssistantMessage("Assistant response"); + chatHistory.AddUserMessage("Second user message"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Equal(5, chatHistory.Count); + + // System message should be added at the beginning, before existing developer message + Assert.Equal(AuthorRole.System, chatHistory[0].Role); + Assert.Equal("You are a helpful assistant.", chatHistory[0].Content); + Assert.Equal(AuthorRole.Developer, chatHistory[1].Role); + Assert.Equal("Existing developer message", chatHistory[1].Content); + Assert.Equal(AuthorRole.User, chatHistory[2].Role); + Assert.Equal("First user message", chatHistory[2].Content); + Assert.Equal(AuthorRole.Assistant, chatHistory[3].Role); + Assert.Equal("Assistant response", chatHistory[3].Content); + Assert.Equal(AuthorRole.User, chatHistory[4].Role); + Assert.Equal("Second user message", chatHistory[4].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncInsertsSystemBeforeDeveloperWhenBothExist() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatSystemPrompt = "You are a helpful assistant.", + ChatDeveloperPrompt = "Debug mode enabled." + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddDeveloperMessage("Existing developer message"); + chatHistory.AddSystemMessage("Existing system message"); + chatHistory.AddUserMessage("Hello"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Equal(3, chatHistory.Count); // No new messages should be added since both already exist + Assert.Equal(AuthorRole.Developer, chatHistory[0].Role); + Assert.Equal("Existing developer message", chatHistory[0].Content); + Assert.Equal(AuthorRole.System, chatHistory[1].Role); + Assert.Equal("Existing system message", chatHistory[1].Content); + Assert.Equal(AuthorRole.User, chatHistory[2].Role); + Assert.Equal("Hello", chatHistory[2].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncAddsSystemBeforeExistingDeveloper() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatSystemPrompt = "You are a helpful assistant.", + ChatDeveloperPrompt = "Debug mode enabled." + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddDeveloperMessage("Existing developer message"); + chatHistory.AddUserMessage("Hello"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Equal(3, chatHistory.Count); + + // System message should be inserted at the beginning, before existing developer message + Assert.Equal(AuthorRole.System, chatHistory[0].Role); + Assert.Equal("You are a helpful assistant.", chatHistory[0].Content); + Assert.Equal(AuthorRole.Developer, chatHistory[1].Role); + Assert.Equal("Existing developer message", chatHistory[1].Content); + Assert.Equal(AuthorRole.User, chatHistory[2].Role); + Assert.Equal("Hello", chatHistory[2].Content); + } + + [Fact] + public void PrepareChatHistoryToRequestAsyncAddsDeveloperWhenSystemExists() + { + // Arrange + var settings = new TestableOpenAIPromptExecutionSettings + { + ChatDeveloperPrompt = "Debug mode enabled." + }; + + var chatHistory = new ChatHistory(); + chatHistory.AddSystemMessage("Existing system message"); + chatHistory.AddUserMessage("Hello"); + + // Act + var result = settings.TestPrepareChatHistoryToRequest(chatHistory); + + // Assert + Assert.Same(chatHistory, result); + Assert.Equal(3, chatHistory.Count); + + // Developer message should be inserted at the beginning, before existing system message + Assert.Equal(AuthorRole.Developer, chatHistory[0].Role); + Assert.Equal("Debug mode enabled.", chatHistory[0].Content); + Assert.Equal(AuthorRole.System, chatHistory[1].Role); + Assert.Equal("Existing system message", chatHistory[1].Content); + Assert.Equal(AuthorRole.User, chatHistory[2].Role); + Assert.Equal("Hello", chatHistory[2].Content); + } + + /// + /// Test implementation of OpenAIPromptExecutionSettings that exposes the protected PrepareChatHistoryToRequestAsync method. + /// + private sealed class TestableOpenAIPromptExecutionSettings : OpenAIPromptExecutionSettings + { + public ChatHistory TestPrepareChatHistoryToRequest(ChatHistory chatHistory) + { + return base.PrepareChatHistoryForRequest(chatHistory); + } + } + private static void AssertExecutionSettings(OpenAIPromptExecutionSettings executionSettings) { Assert.NotNull(executionSettings); diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Settings/OpenAIPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.OpenAI/Settings/OpenAIPromptExecutionSettings.cs index c212c3d04767..c4c3e6259823 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Settings/OpenAIPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Settings/OpenAIPromptExecutionSettings.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; using Microsoft.SemanticKernel.ChatCompletion; @@ -545,6 +546,23 @@ public static OpenAIPromptExecutionSettings FromExecutionSettings(PromptExecutio }; } + /// + protected override ChatHistory PrepareChatHistoryForRequest(ChatHistory chatHistory) + { + // Inserts system and developer prompts at the beginning of the chat history if they are not already present. + if (!string.IsNullOrWhiteSpace(this.ChatDeveloperPrompt) && !chatHistory.Any(m => m.Role == AuthorRole.Developer)) + { + chatHistory.Insert(0, new ChatMessageContent(AuthorRole.Developer, this.ChatDeveloperPrompt)); + } + + if (!string.IsNullOrWhiteSpace(this.ChatSystemPrompt) && !chatHistory.Any(m => m.Role == AuthorRole.System)) + { + chatHistory.Insert(0, new ChatMessageContent(AuthorRole.System, this.ChatSystemPrompt)); + } + + return chatHistory; + } + #region private ================================================================================ private object? _webSearchOptions; diff --git a/dotnet/src/Experimental/Process.Core/FoundryListenForBuilder.cs b/dotnet/src/Experimental/Process.Core/FoundryListenForBuilder.cs deleted file mode 100644 index 3fe9bd25ed7c..000000000000 --- a/dotnet/src/Experimental/Process.Core/FoundryListenForBuilder.cs +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; -using System.Threading.Tasks; -using Microsoft.SemanticKernel.Process.Internal; - -namespace Microsoft.SemanticKernel; - -/// -/// Builder class for defining Processes that can be exported to Foundry. -/// -[Experimental("SKEXP0081")] -public class FoundryListenForBuilder -{ - private readonly ProcessBuilder _processBuilder; - private readonly ListenForBuilder _listenForBuilder; - - /// - /// Initializes a new instance of the class. - /// - /// The process builder. - public FoundryListenForBuilder(ProcessBuilder processBuilder) - { - this._processBuilder = processBuilder; - this._listenForBuilder = new ListenForBuilder(processBuilder); - } - - /// - /// Listens for an input event. - /// - /// - /// - /// - public FoundryListenForTargetBuilder InputEvent(string eventName, KernelProcessEdgeCondition? condition = null) - { - return new(this._listenForBuilder.InputEvent(eventName, condition)); - } - - /// - /// Defines a message to listen for from a specific process step. - /// - /// - /// - public FoundryListenForTargetBuilder ProcessStart(KernelProcessEdgeCondition? condition = null) - { - return this.InputEvent(ProcessConstants.Declarative.OnEnterEvent, condition); - } - - /// - /// Defines a message to listen for from a specific process step. - /// - /// The type of the message. - /// The process step from which the message originates. - /// Condition that must be met for the message to be processed - /// A builder for defining the target of the message. - public FoundryListenForTargetBuilder Message(string messageType, ProcessStepBuilder from, string? condition = null) - { - KernelProcessEdgeCondition? edgeCondition = null; - if (!string.IsNullOrWhiteSpace(condition)) - { - edgeCondition = new KernelProcessEdgeCondition( - (e, s) => - { - var wrapper = new DeclarativeConditionContentWrapper - { - State = s, - Event = e.Data - }; - - var result = JMESPathConditionEvaluator.EvaluateCondition(wrapper, condition); - return Task.FromResult(result); - }, condition); - } - - return new(this._listenForBuilder.Message(messageType, from, edgeCondition)); - } - - /// - /// Defines a message to listen for from a specific process step. - /// - /// The process step from which the message originates. - /// Condition that must be met for the message to be processed - /// A builder for defining the target of the message. - public FoundryListenForTargetBuilder ResultFrom(ProcessStepBuilder from, string? condition = null) - { - KernelProcessEdgeCondition? edgeCondition = null; - if (!string.IsNullOrWhiteSpace(condition)) - { - edgeCondition = new KernelProcessEdgeCondition( - (e, s) => - { - var wrapper = new DeclarativeConditionContentWrapper - { - State = s, - Event = e.Data - }; - - var result = JMESPathConditionEvaluator.EvaluateCondition(wrapper, condition); - return Task.FromResult(result); - }, condition); - } - - return new(this._listenForBuilder.OnResult(from, edgeCondition)); - } - - /// - /// Listen for the OnEnter event from a specific process step. - /// - /// The process step from which the message originates. - /// Condition that must be met for the message to be processed - /// A builder for defining the target of the message. - public FoundryListenForTargetBuilder OnEnter(ProcessStepBuilder from, string? condition = null) - { - return this.Message(ProcessConstants.Declarative.OnEnterEvent, from, condition); - } - - /// - /// Listen for the OnEnter event from a specific process step. - /// - /// The process step from which the message originates. - /// Condition that must be met for the message to be processed - /// A builder for defining the target of the message. - public FoundryListenForTargetBuilder OnExit(ProcessStepBuilder from, string? condition = null) - { - return this.Message(ProcessConstants.Declarative.OnExitEvent, from, condition); - } -} diff --git a/dotnet/src/Experimental/Process.Core/FoundryListenForTargetBuilder.cs b/dotnet/src/Experimental/Process.Core/FoundryListenForTargetBuilder.cs deleted file mode 100644 index 3f5642e1df6b..000000000000 --- a/dotnet/src/Experimental/Process.Core/FoundryListenForTargetBuilder.cs +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; - -namespace Microsoft.SemanticKernel; - -/// -/// Builder class for defining targets to listen for in a process. -/// -[Experimental("SKEXP0081")] -public class FoundryListenForTargetBuilder -{ - private readonly ListenForTargetBuilder _listenForTargetBuilder; - - internal FoundryListenForTargetBuilder(ListenForTargetBuilder listenForTargetBuilder) - { - this._listenForTargetBuilder = listenForTargetBuilder; - } - - /// - /// Initializes a new instance of the class. - /// - /// The list of message sources. - /// The process builder. - /// The group ID for the message sources. - public FoundryListenForTargetBuilder(List messageSources, ProcessBuilder processBuilder, KernelProcessEdgeGroupBuilder? edgeGroup = null) - { - this._listenForTargetBuilder = new ListenForTargetBuilder(messageSources, processBuilder, edgeGroup); - } - - /// - /// Signals that the output of the source step should be sent to the specified target when the associated event fires. - /// - /// The output target. - /// The thread to send the event to. - /// The inputs to the target. - /// The messages to be sent to the target. - /// A fresh builder instance for fluid definition - public ProcessStepEdgeBuilder SendEventTo(ProcessAgentBuilder target, string? thread = null, Dictionary? inputs = null, List? messagesIn = null) where TProcessState : class, new() - { - return this._listenForTargetBuilder.SendEventTo_Internal(new ProcessAgentInvokeTargetBuilder(target, thread, messagesIn ?? [], inputs ?? [])); - } - - /// - /// Signals that the specified event should be emitted. - /// - /// - /// - /// - public FoundryListenForTargetBuilder EmitEvent(string eventName, Dictionary? payload = null) - { - return new(this._listenForTargetBuilder.EmitEvent(eventName, payload)); - } - - /// - /// Signals that the specified state variable should be updated in the process state. - /// - /// - /// - /// - /// - public FoundryListenForTargetBuilder UpdateProcessState(string path, StateUpdateOperations operation, object? value) - { - return new(this._listenForTargetBuilder.UpdateProcessState(path, operation, value)); - } - - /// - /// Signals that the process should be stopped. - /// - public void StopProcess(string? thread = null, Dictionary? inputs = null, List? messagesIn = null) - { - var target = new ProcessAgentInvokeTargetBuilder(EndStep.Instance, thread, messagesIn ?? [], inputs ?? []); - this._listenForTargetBuilder.SendEventTo_Internal(target); - } -} diff --git a/dotnet/src/Experimental/Process.Core/FoundryMessageSourceBuilder.cs b/dotnet/src/Experimental/Process.Core/FoundryMessageSourceBuilder.cs deleted file mode 100644 index 25cf713711ae..000000000000 --- a/dotnet/src/Experimental/Process.Core/FoundryMessageSourceBuilder.cs +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Diagnostics.CodeAnalysis; -using System.Threading.Tasks; -using Microsoft.SemanticKernel.Process.Internal; - -namespace Microsoft.SemanticKernel; - -/// -/// Builder class for defining message sources in a Foundry process. -/// -[Experimental("SKEXP0081")] -public class FoundryMessageSourceBuilder -{ - /// - /// Initializes a new instance of the class. - /// - /// The meassage type - /// The source step builder - /// Condition that must be met for the message to be processed - public FoundryMessageSourceBuilder(string messageType, ProcessStepBuilder source, string? condition) - { - this.MessageType = messageType; - this.Source = source; - this.Condition = condition; - } - - /// - /// The message type - /// - public string MessageType { get; } - - /// - /// The source step builder. - /// - public ProcessStepBuilder Source { get; } - - /// - /// The condition that must be met for the message to be processed. - /// - public string? Condition { get; } - - /// - /// Builds the message source. - /// - /// - internal MessageSourceBuilder Build() - { - KernelProcessEdgeCondition? edgeCondition = null; - if (!string.IsNullOrWhiteSpace(this.Condition)) - { - edgeCondition = new KernelProcessEdgeCondition( - (e, s) => - { - var wrapper = new DeclarativeConditionContentWrapper - { - State = s, - Event = e.Data - }; - - var result = JMESPathConditionEvaluator.EvaluateCondition(wrapper, this.Condition); - return Task.FromResult(result); - }); - } - return new MessageSourceBuilder(this.MessageType, this.Source, edgeCondition); - } -} diff --git a/dotnet/src/Experimental/Process.Core/FoundryProcessBuilder.cs b/dotnet/src/Experimental/Process.Core/FoundryProcessBuilder.cs deleted file mode 100644 index ea80d4d060b6..000000000000 --- a/dotnet/src/Experimental/Process.Core/FoundryProcessBuilder.cs +++ /dev/null @@ -1,268 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.Net.Http; -using System.Text.Json; -using System.Text.Json.Serialization; -using System.Threading; -using System.Threading.Tasks; -using Azure.AI.Agents.Persistent; -using Azure.Core; -using Azure.Identity; -using Microsoft.SemanticKernel.Agents; -using Microsoft.SemanticKernel.Agents.AzureAI; -using Microsoft.SemanticKernel.Process.Models; - -namespace Microsoft.SemanticKernel; - -/// -/// A builder for creating a process that can be deployed to Azure Foundry. -/// -[Experimental("SKEXP0081")] -public class FoundryProcessBuilder where TProcessState : class, new() -{ - private readonly ProcessBuilder _processBuilder; - private static readonly string[] s_scopes = ["https://management.azure.com/"]; - - /// - /// Initializes a new instance of the class. - /// - /// The name of the process. This is required. - /// The description of the Process. - public FoundryProcessBuilder(string id, string? description = null) - { - this._processBuilder = new ProcessBuilder(id, description, processBuilder: null, typeof(TProcessState)); - } - - /// - /// Adds an to the process. - /// - /// The name of the thread. - /// The policy that determines the lifetime of the - /// - public ProcessBuilder AddThread(string threadName, KernelProcessThreadLifetime threadPolicy = KernelProcessThreadLifetime.Scoped) - { - return this._processBuilder.AddThread(threadName, threadPolicy); - } - - /// - /// Adds a step to the process from a declarative agent. - /// - /// The - /// The unique Id of the step. If not provided, the name of the step Type will be used. - /// Aliases that have been used by previous versions of the step, used for supporting backward compatibility when reading old version Process States - /// Specifies the thread reference to be used by the agent. If not provided, the agent will create a new thread for each invocation. - /// Specifies the human-in-the-loop mode for the agent. If not provided, the default is . - public ProcessAgentBuilder AddStepFromAgent(AgentDefinition agentDefinition, string? stepId = null, IReadOnlyList? aliases = null, string? defaultThread = null, HITLMode humanInLoopMode = HITLMode.Never) - { - Verify.NotNull(agentDefinition); - if (agentDefinition.Type != AzureAIAgentFactory.AzureAIAgentType) - { - throw new ArgumentException($"The agent type '{agentDefinition.Type}' is not supported. Only '{AzureAIAgentFactory.AzureAIAgentType}' is supported."); - } - - return this._processBuilder.AddStepFromAgent(agentDefinition, stepId, aliases, defaultThread, humanInLoopMode); - } - - /// - /// Adds a step to the process from a . - /// - /// The - /// The unique Id of the step. If not provided, the name of the step Type will be used. - /// Aliases that have been used by previous versions of the step, used for supporting backward compatibility when reading old version Process States - /// Specifies the thread reference to be used by the agent. If not provided, the agent will create a new thread for each invocation. - /// Specifies the human-in-the-loop mode for the agent. If not provided, the default is . - public ProcessAgentBuilder AddStepFromAgent(PersistentAgent persistentAgent, string? stepId = null, IReadOnlyList? aliases = null, string? defaultThread = null, HITLMode humanInLoopMode = HITLMode.Never) - { - Verify.NotNull(persistentAgent); - - var agentDefinition = new AgentDefinition - { - Id = persistentAgent.Id, - Type = AzureAIAgentFactory.AzureAIAgentType, - Name = persistentAgent.Name, - Description = persistentAgent.Description - }; - - return this._processBuilder.AddStepFromAgent(agentDefinition, stepId, aliases, defaultThread, humanInLoopMode); - } - - /// - /// Adds a step to the process from a declarative agent. - /// - /// Id of the step. If not provided, the Id will come from the agent Id. - /// The - /// Specifies the thread reference to be used by the agent. If not provided, the agent will create a new thread for each invocation. - /// Specifies the human-in-the-loop mode for the agent. If not provided, the default is . - /// - /// - /// - public ProcessAgentBuilder AddStepFromAgentProxy(string stepId, AgentDefinition agentDefinition, string? threadName = null, HITLMode humanInLoopMode = HITLMode.Never, IReadOnlyList? aliases = null) // TODO: Is there a better way to model this? - { - Verify.NotNullOrWhiteSpace(stepId); - Verify.NotNull(agentDefinition); - if (agentDefinition.Type != AzureAIAgentFactory.AzureAIAgentType) - { - throw new ArgumentException($"The agent type '{agentDefinition.Type}' is not supported. Only '{AzureAIAgentFactory.AzureAIAgentType}' is supported."); - } - - return this._processBuilder.AddStepFromAgentProxy(agentDefinition, threadName, stepId, humanInLoopMode, aliases); - } - - /// - /// Provides an instance of for defining an input edge to a process. - /// - /// The Id of the external event. - /// An instance of - internal ProcessEdgeBuilder OnInputEvent(string eventId) - { - return this._processBuilder.OnInputEvent(eventId); - } - - /// - /// Creates a instance to define a listener for incoming messages. - /// - /// The process step from which the message originates. - /// The name of the event to listen for. - /// An optional condition using JMESPath syntax. - /// - public FoundryListenForTargetBuilder OnEvent(ProcessStepBuilder step, string eventName, string? condition = null) - { - Verify.NotNull(step); - Verify.NotNullOrWhiteSpace(eventName); - return new FoundryListenForBuilder(this._processBuilder).Message(eventName, step, condition); - } - - /// - /// Creates a instance to define a listener for when the process step is entered. - /// - /// - /// - /// - public FoundryListenForTargetBuilder OnStepEnter(ProcessStepBuilder step, string? condition = null) - { - Verify.NotNull(step); - return new FoundryListenForBuilder(this._processBuilder).OnEnter(step, condition); - } - - /// - /// Creates a instance to define a listener for when the process step is exited. - /// - /// - /// - /// - public FoundryListenForTargetBuilder OnStepExit(ProcessStepBuilder step, string? condition = null) - { - Verify.NotNull(step); - return new FoundryListenForBuilder(this._processBuilder).OnExit(step, condition); - } - - /// - /// Creates a instance to define a listener for when the process starts. - /// - /// - public FoundryListenForTargetBuilder OnProcessEnter() - { - return new FoundryListenForBuilder(this._processBuilder).ProcessStart(); - } - - /// - /// Builds the process. - /// - /// An instance of - /// - public KernelProcess Build(KernelProcessStateMetadata? stateMetadata = null) - { - return this._processBuilder.Build(stateMetadata); - } - - /// - /// Deploys the process to Azure Foundry. - /// - /// Th workflow endpoint to deploy to. - /// The credential to use. - /// - /// - public async Task DeployToFoundryAsync(string endpoint, TokenCredential? credential = null, CancellationToken cancellationToken = default) - { - // Build the process - var process = this.Build(); - - // Serialize and deploy - using var httpClient = new HttpClient(); - if (credential != null) - { - var token = await credential.GetTokenAsync(new TokenRequestContext(s_scopes), cancellationToken).ConfigureAwait(false); - httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {token.Token}"); - } - else - { - var token = await new DefaultAzureCredential().GetTokenAsync(new TokenRequestContext(s_scopes), cancellationToken).ConfigureAwait(false); - httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {token.Token}"); - } - - var workflow = await WorkflowBuilder.BuildWorkflow(process).ConfigureAwait(false); - string json = WorkflowSerializer.SerializeToJson(workflow); - using var content = new StringContent(json, System.Text.Encoding.UTF8, "application/json"); - var response = await httpClient.PostAsync(new Uri($"{endpoint}/agents?api-version=2025-05-01-preview"), content, cancellationToken).ConfigureAwait(false); - - if (!response.IsSuccessStatusCode) - { - var errorContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - throw new KernelException($"Failed to deploy process. Response: {errorContent}"); - } - - var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - var foundryWorkflow = JsonSerializer.Deserialize(responseContent); - return foundryWorkflow?.Id ?? throw new KernelException("Failed to parse the response from Foundry."); - } - - /// - /// Serializes the process to JSON. - /// - public async Task ToJsonAsync() - { - var process = this.Build(); - var workflow = await WorkflowBuilder.BuildWorkflow(process).ConfigureAwait(false); - return WorkflowSerializer.SerializeToJson(workflow); - } - - /// - /// Serializes the process to YAML. - /// - public async Task ToYamlAsync() - { - var process = this.Build(); - var workflow = await WorkflowBuilder.BuildWorkflow(process).ConfigureAwait(false); - return WorkflowSerializer.SerializeToYaml(workflow); - } - - private class FoundryWorkflow - { - [JsonPropertyName("id")] - public string? Id { get; set; } - } -} - -/// -/// A builder for creating a process that can be deployed to Azure Foundry. -/// -public class FoundryProcessBuilder : FoundryProcessBuilder -{ - /// - /// Initializes a new instance of the class. - /// - /// - public FoundryProcessBuilder(string id) : base(id) - { - } -} - -/// -/// A default process state for the . -/// -public class FoundryProcessDefaultState -{ -} diff --git a/dotnet/src/Experimental/Process.Core/FoundryWorkflowExtensions.cs b/dotnet/src/Experimental/Process.Core/FoundryWorkflowExtensions.cs deleted file mode 100644 index 15ca43c0c83a..000000000000 --- a/dotnet/src/Experimental/Process.Core/FoundryWorkflowExtensions.cs +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.ClientModel.Primitives; -using System.IO; -using System.Text; -using System.Text.Json; -using System.Threading.Tasks; -using Azure.Core; -using Azure.Core.Pipeline; - -namespace Microsoft.SemanticKernel.Agents.AzureAI; - -/// -/// Extensions for managing Foundry Workflows -/// -public static class FoundryWorkflowExtensions -{ - /// - /// Publishes a workflow using a and a . - /// - /// The process state type. - /// The client pipeline. - /// The process builder. - /// The published . - public static async Task PublishWorkflowAsync(this ClientPipeline pipeline, FoundryProcessBuilder process) where T : class, new() - { - // Send the request - using var message = pipeline.CreateMessage(); - var payload = await process.ToJsonAsync().ConfigureAwait(false); - message.Request.Method = "POST"; - message.Request.Uri = new Uri("https://localhost/agents"); - message.Request.Content = System.ClientModel.BinaryContent.Create(new MemoryStream(Encoding.UTF8.GetBytes(payload))); - message.Request.Headers.Add("Content-Type", "application/json"); - - await pipeline.SendAsync(message).ConfigureAwait(false); - - if (message.Response?.Status < 200 || message.Response?.Status >= 300) - { - var errorContent = await message.Response.Content.AsJsonAsync().ConfigureAwait(false); - - throw new KernelException($"Error publishing workflow: {errorContent}"); - } - - var responseJson = await message.Response!.Content.AsJsonAsync().ConfigureAwait(false) ?? string.Empty; - - using var doc = JsonDocument.Parse(responseJson); - var workflowId = doc.RootElement.GetProperty("id").GetString() ?? string.Empty; - - return new Workflow() { Id = workflowId }; - } - - /// - /// Publishes a workflow using an and a . - /// - /// The process state type. - /// The HTTP pipeline. - /// The process builder. - /// The published . - public static async Task PublishWorkflowAsync(this HttpPipeline pipeline, FoundryProcessBuilder process) where T : class, new() - { - // Send the request - using var message = pipeline.CreateMessage(); - message.Request.Method = RequestMethod.Post; - message.Request.Uri.Reset(new Uri("https://localhost/agents")); - message.Request.Content = RequestContent.Create(new MemoryStream(Encoding.UTF8.GetBytes(await process.ToJsonAsync().ConfigureAwait(false)))); - message.Request.Headers.Add("Content-Type", "application/json"); - - await pipeline.SendAsync(message, default).ConfigureAwait(false); - - if (message.Response?.Status < 200 || message.Response?.Status >= 300) - { - var errorContent = await message.Response.Content.AsJsonAsync().ConfigureAwait(false); - - throw new KernelException($"Error publishing workflow: {errorContent}"); - } - - var responseJson = await message.Response!.Content.AsJsonAsync().ConfigureAwait(false) ?? string.Empty; - - using var doc = JsonDocument.Parse(responseJson); - var workflowId = doc.RootElement.GetProperty("id").GetString() ?? string.Empty; - - Console.WriteLine($"Creating workflow {workflowId}..."); - - return new Workflow() { Id = workflowId }; - } - - /// - /// Deletes a workflow using a . - /// - /// The client pipeline. - /// The workflow to delete. - public static async Task DeleteWorkflowAsync(this ClientPipeline pipeline, Workflow workflow) - { - // Send the request - using var message = pipeline.CreateMessage(); - message.Request.Method = "DELETE"; - message.Request.Uri = new Uri($"https://localhost/agents/{workflow.Id}"); - - await pipeline.SendAsync(message).ConfigureAwait(false); - - if (message.Response?.Status < 200 || message.Response?.Status >= 300) - { - throw new KernelException($"Failed to delete workflow: {message.Response?.Status} {message.Response?.ReasonPhrase}"); - } - } - - /// - /// Deletes a workflow using an . - /// - /// The HTTP pipeline. - /// The workflow to delete. - public static async Task DeleteWorkflowAsync(this HttpPipeline pipeline, Workflow workflow) - { - // Send the request - using var message = pipeline.CreateMessage(); - message.Request.Method = RequestMethod.Delete; - message.Request.Uri.Reset(new Uri($"https://localhost/agents/{workflow.Id}")); - - await pipeline.SendAsync(message, default).ConfigureAwait(false); - - if (message.Response?.Status < 200 || message.Response?.Status >= 300) - { - throw new KernelException($"Failed to delete workflow: {message.Response?.Status} {message.Response?.ReasonPhrase}"); - } - } - - /// - /// Reads the as a JSON string asynchronously. - /// - /// The binary data. - /// The JSON string. - public static async Task AsJsonAsync(this BinaryData data) - { - if (data == null || data.Length == 0) - { - return string.Empty; - } - - using var reader = new StreamReader(data.ToStream(), Encoding.UTF8); - - return await reader.ReadToEndAsync().ConfigureAwait(false); - } -} diff --git a/dotnet/src/Experimental/Process.Core/ProcessAgentBuilder.cs b/dotnet/src/Experimental/Process.Core/ProcessAgentBuilder.cs index 91e295cdff08..3a4b39f82751 100644 --- a/dotnet/src/Experimental/Process.Core/ProcessAgentBuilder.cs +++ b/dotnet/src/Experimental/Process.Core/ProcessAgentBuilder.cs @@ -240,7 +240,7 @@ internal ProcessFunctionTargetBuilder GetInvokeAgentFunctionTargetBuilder() /// /// Builder for a process step that represents an agent. /// -public class ProcessAgentBuilder : ProcessAgentBuilder +public class ProcessAgentBuilder : ProcessAgentBuilder { /// /// Creates a new instance of the class. diff --git a/dotnet/src/Experimental/Process.Core/ProcessExporter.cs b/dotnet/src/Experimental/Process.Core/ProcessExporter.cs deleted file mode 100644 index bd95513a5923..000000000000 --- a/dotnet/src/Experimental/Process.Core/ProcessExporter.cs +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Linq; -using System.Text.Json; -using YamlDotNet.Serialization; -using YamlDotNet.Serialization.NamingConventions; - -namespace Microsoft.SemanticKernel.Process; - -/// -/// Export a process to a string representation. -/// -public sealed class ProcessExporter -{ - /// - /// Export a process to a string representation. - /// - /// - /// - public static string ExportProcess(KernelProcess process) - { - Verify.NotNull(process); - - Workflow workflow = new() - { - Name = process.State.Name, - Description = process.Description, - FormatVersion = "1.0", - WorkflowVersion = process.State.Version, - Nodes = [.. process.Steps.Select(step => GetNodeFromStep(step))], - // Orchestration - // Suggested Inputs - // Variables - // Schema - // Error handling - }; - - return ""; - } - - private static Node GetNodeFromStep(KernelProcessStepInfo stepInfo) - { - Verify.NotNull(stepInfo); - - if (stepInfo is KernelProcess) - { - throw new KernelException("Processes that contain a subprocess are not currently exportable."); - } - else if (stepInfo is KernelProcessAgentStep agentStep) - { - var agentNode = new Node() - { - Id = agentStep.State.Id ?? throw new KernelException("All steps must have an Id."), - Description = agentStep.Description, - Type = "agent", - Inputs = agentStep.Inputs.ToDictionary((kvp) => kvp.Key, (kvp) => - { - var value = kvp.Value; - var schema = KernelJsonSchemaBuilder.Build(value); - var schemaJson = JsonSerializer.Serialize(schema.RootElement); - - var deserializer = new DeserializerBuilder() - .WithNamingConvention(UnderscoredNamingConvention.Instance) - .IgnoreUnmatchedProperties() - .Build(); - - var yamlSchema = deserializer.Deserialize(schemaJson); - if (yamlSchema is null) - { - throw new KernelException("Failed to deserialize schema."); - } - - return yamlSchema; - }), - OnComplete = null, // TODO: OnComplete, - OnError = null // TODO: OnError - }; - } - else if (stepInfo is KernelProcessMap mapStep) - { - throw new KernelException("Processes that contain a map step are not currently exportable."); - } - else if (stepInfo is KernelProcessProxy proxyStep) - { - throw new KernelException("Processes that contain a proxy step are not currently exportable."); - } - else - { - throw new KernelException("Processes that contain non Foundry-Agent step are not currently exportable."); - } - - return new Node(); - } -} diff --git a/dotnet/src/Experimental/Process.Core/WorkflowSerializer.cs b/dotnet/src/Experimental/Process.Core/WorkflowSerializer.cs index b43fc4211f42..18634a10cf01 100644 --- a/dotnet/src/Experimental/Process.Core/WorkflowSerializer.cs +++ b/dotnet/src/Experimental/Process.Core/WorkflowSerializer.cs @@ -152,13 +152,13 @@ internal class SnakeCaseEnumConverter : IYamlTypeConverter { public bool Accepts(Type type) => type.IsEnum; - public object ReadYaml(IParser parser, Type type) + public object ReadYaml(IParser parser, Type type, ObjectDeserializer rootDeserializer) { var value = parser.Consume().Value; return Enum.Parse(type, value.Replace("_", ""), true); } - public void WriteYaml(IEmitter emitter, object? value, Type type) + public void WriteYaml(IEmitter emitter, object? value, Type type, ObjectSerializer serializer) { var enumValue = value?.ToString(); if (enumValue == null) diff --git a/dotnet/src/Experimental/Process.IntegrationTestHost.Dapr/Process.IntegrationTestHost.Dapr.csproj b/dotnet/src/Experimental/Process.IntegrationTestHost.Dapr/Process.IntegrationTestHost.Dapr.csproj index 9ce38a616a68..c5d8ba4378cb 100644 --- a/dotnet/src/Experimental/Process.IntegrationTestHost.Dapr/Process.IntegrationTestHost.Dapr.csproj +++ b/dotnet/src/Experimental/Process.IntegrationTestHost.Dapr/Process.IntegrationTestHost.Dapr.csproj @@ -7,7 +7,7 @@ enable enable false - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0080,SKEXP0110 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 true diff --git a/dotnet/src/Experimental/Process.IntegrationTestRunner.Dapr/Process.IntegrationTestRunner.Dapr.csproj b/dotnet/src/Experimental/Process.IntegrationTestRunner.Dapr/Process.IntegrationTestRunner.Dapr.csproj index 7f28b56abfe9..8029ae2ff601 100644 --- a/dotnet/src/Experimental/Process.IntegrationTestRunner.Dapr/Process.IntegrationTestRunner.Dapr.csproj +++ b/dotnet/src/Experimental/Process.IntegrationTestRunner.Dapr/Process.IntegrationTestRunner.Dapr.csproj @@ -7,7 +7,7 @@ enable enable false - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0080,SKEXP0110 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 true diff --git a/dotnet/src/Experimental/Process.UnitTests/Process.UnitTests.csproj b/dotnet/src/Experimental/Process.UnitTests/Process.UnitTests.csproj index 2d53676bcbb6..566e3f5559aa 100644 --- a/dotnet/src/Experimental/Process.UnitTests/Process.UnitTests.csproj +++ b/dotnet/src/Experimental/Process.UnitTests/Process.UnitTests.csproj @@ -11,28 +11,6 @@ $(NoWarn);CA2007,CA1812,CA1861,CA1063,VSTHRD111,SKEXP0001,SKEXP0050,SKEXP0080,SKEXP0110;OPENAI001,CA1024 - - - - - - - - - - Always - - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - - - diff --git a/dotnet/src/Experimental/Process.UnitTests/ProcessSerializationTests.cs b/dotnet/src/Experimental/Process.UnitTests/ProcessSerializationTests.cs deleted file mode 100644 index a073c47b2a13..000000000000 --- a/dotnet/src/Experimental/Process.UnitTests/ProcessSerializationTests.cs +++ /dev/null @@ -1,231 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.IO; -using System.Linq; -using System.Reflection; -using System.Threading.Tasks; -using Microsoft.SemanticKernel.Process.UnitTests.Steps; -using Xunit; - -namespace Microsoft.SemanticKernel.Process.UnitTests; - -/// -/// Unit testing of . -/// -public class ProcessSerializationTests -{ - /// - /// Verify initialization of . - /// - [Fact(Skip = "More work left to do.")] - public async Task KernelProcessFromYamlWorksAsync() - { - // Arrange - var yaml = this.ReadResource("workflow1.yaml"); - - // Act - var process = await ProcessBuilder.LoadFromYamlAsync(yaml); - - // Assert - Assert.NotNull(process); - } - - /// - /// Verify initialization of from a YAML file that contains only a .NET workflow. - /// - /// - [Fact] - public async Task KernelProcessFromDotnetOnlyWorkflow1YamlAsync() - { - // Arrange - var yaml = this.ReadResource("dotnetOnlyWorkflow1.yaml"); - - // Act - var process = await ProcessBuilder.LoadFromYamlAsync(yaml); - - // Assert - Assert.NotNull(process); - - var stepKickoff = process.Steps.FirstOrDefault(s => s.State.Id == "kickoff"); - var stepA = process.Steps.FirstOrDefault(s => s.State.Id == "a_step"); - var stepB = process.Steps.FirstOrDefault(s => s.State.Id == "b_step"); - var stepC = process.Steps.FirstOrDefault(s => s.State.Id == "c_step"); - - Assert.NotNull(stepKickoff); - Assert.NotNull(stepA); - Assert.NotNull(stepB); - Assert.NotNull(stepC); - - // kickoff step has outgoing edge to aStep and bStep on event startAStep - Assert.Single(stepKickoff.Edges); - var kickoffStartEdges = stepKickoff.Edges["kickoff.StartARequested"]; - Assert.Equal(2, kickoffStartEdges.Count); - Assert.Contains(kickoffStartEdges, e => (e.OutputTarget as KernelProcessFunctionTarget)!.StepId == "a_step"); - Assert.Contains(kickoffStartEdges, e => (e.OutputTarget as KernelProcessFunctionTarget)!.StepId == "b_step"); - - // aStep and bStep have grouped outgoing edges to cStep on event aStepDone and bStepDone - Assert.Single(stepA.Edges); - var aStepDoneEdges = stepA.Edges["a_step.AStepDone"]; - Assert.Single(aStepDoneEdges); - var aStepDoneEdge = aStepDoneEdges.First(); - Assert.Equal("c_step", (aStepDoneEdge.OutputTarget as KernelProcessFunctionTarget)!.StepId); - Assert.NotEmpty(aStepDoneEdge.GroupId ?? ""); - - Assert.Single(stepB.Edges); - var bStepDoneEdges = stepB.Edges["b_step.BStepDone"]; - Assert.Single(bStepDoneEdges); - var bStepDoneEdge = bStepDoneEdges.First(); - Assert.Equal("c_step", (bStepDoneEdge.OutputTarget as KernelProcessFunctionTarget)!.StepId); - Assert.NotEmpty(bStepDoneEdge.GroupId ?? ""); - - // cStep has outgoing edge to kickoff step on event cStepDone and one to end the process on event exitRequested - Assert.Equal(2, stepC.Edges.Count); - var cStepDoneEdges = stepC.Edges["c_step.CStepDone"]; - Assert.Single(cStepDoneEdges); - var cStepDoneEdge = cStepDoneEdges.First(); - Assert.Equal("kickoff", (cStepDoneEdge.OutputTarget as KernelProcessFunctionTarget)!.StepId); - Assert.Null(cStepDoneEdge.GroupId); - - var exitRequestedEdges = stepC.Edges["Microsoft.SemanticKernel.Process.EndStep"]; - Assert.Single(exitRequestedEdges); - var exitRequestedEdge = exitRequestedEdges.First(); - Assert.Equal("Microsoft.SemanticKernel.Process.EndStep", (exitRequestedEdge.OutputTarget as KernelProcessFunctionTarget)!.StepId); - - // edges to cStep are in the same group - Assert.Equal(aStepDoneEdge.GroupId, bStepDoneEdge.GroupId); - } - - /// - /// Verify initialization of from a YAML file that contains foundry_agents - /// - /// - [Fact] - public async Task KernelProcessFromScenario1YamlAsync() - { - // Arrange - var yaml = this.ReadResource("scenario1.yaml"); - // Act - var process = await ProcessBuilder.LoadFromYamlAsync(yaml); - // Assert - Assert.NotNull(process); - } - - /// - /// Verify that the process can be serialized to YAML and deserialized back to a workflow. - /// - /// - [Fact] - public async Task ProcessToWorkflowWorksAsync() - { - var process = this.GetProcess(); - var workflow = await WorkflowBuilder.BuildWorkflow(process); - string yaml = WorkflowSerializer.SerializeToYaml(workflow); - - Assert.NotNull(workflow); - } - - /// - /// Verify initialization of from a YAML file that contains references to C# class and chat completion agent. - /// - [Fact] - public async Task KernelProcessFromCombinedWorkflowYamlAsync() - { - // Arrange - var yaml = this.ReadResource("combined-workflow.yaml"); - - // Act - var process = await ProcessBuilder.LoadFromYamlAsync(yaml); - - // Assert - Assert.NotNull(process); - Assert.Contains(process.Steps, step => step.State.Id == "GetProductInfo"); - Assert.Contains(process.Steps, step => step.State.Id == "Summarize"); - } - - private KernelProcess GetProcess() - { - // Create the process builder. - ProcessBuilder processBuilder = new("ProcessWithDapr"); - - // Add some steps to the process. - var kickoffStep = processBuilder.AddStepFromType(); - var myAStep = processBuilder.AddStepFromType(); - var myBStep = processBuilder.AddStepFromType(); - - // ########## Configuring initial state on steps in a process ########### - // For demonstration purposes, we add the CStep and configure its initial state with a CurrentCycle of 1. - // Initializing state in a step can be useful for when you need a step to start out with a predetermines - // configuration that is not easily accomplished with dependency injection. - var myCStep = processBuilder.AddStepFromType(initialState: new() { CurrentCycle = 1 }); - - // Setup the input event that can trigger the process to run and specify which step and function it should be routed to. - processBuilder - .OnInputEvent(CommonEvents.StartProcess) - .SendEventTo(new ProcessFunctionTargetBuilder(kickoffStep)); - - // When the kickoff step is finished, trigger both AStep and BStep. - kickoffStep - .OnEvent(CommonEvents.StartARequested) - .SendEventTo(new ProcessFunctionTargetBuilder(myAStep)) - .SendEventTo(new ProcessFunctionTargetBuilder(myBStep)); - - processBuilder - .ListenFor() - .AllOf(new() - { - new(messageType: CommonEvents.AStepDone, source: myAStep), - new(messageType: CommonEvents.BStepDone, source: myBStep) - }) - .SendEventTo(new ProcessStepTargetBuilder(myCStep, inputMapping: (inputEvents) => - { - // Map the input events to the CStep's input parameters. - // In this case, we are mapping the output of AStep to the first input parameter of CStep - // and the output of BStep to the second input parameter of CStep. - return new() - { - { "astepdata", inputEvents[$"aStep.{CommonEvents.AStepDone}"] }, - { "bstepdata", inputEvents[$"bStep.{CommonEvents.BStepDone}"] } - }; - })); - - // When CStep has finished without requesting an exit, activate the Kickoff step to start again. - myCStep - .OnEvent(CommonEvents.CStepDone) - .SendEventTo(new ProcessFunctionTargetBuilder(kickoffStep)); - - // When the CStep has finished by requesting an exit, stop the process. - myCStep - .OnEvent(CommonEvents.ExitRequested) - .StopProcess(); - - var process = processBuilder.Build(); - return process; - } - - private string ReadResource(string name) - { - // Get the current assembly - Assembly assembly = Assembly.GetExecutingAssembly(); - - // Specify the resource name - string resourceName = $"SemanticKernel.Process.UnitTests.Resources.{name}"; - - // Get the resource stream - using (Stream? resourceStream = assembly.GetManifestResourceStream(resourceName)) - { - if (resourceStream != null) - { - using (StreamReader reader = new(resourceStream)) - { - string content = reader.ReadToEnd(); - return content; - } - } - else - { - throw new InvalidOperationException($"Resource {resourceName} not found in assembly {assembly.FullName}"); - } - } - } -} diff --git a/dotnet/src/Experimental/Process.UnitTests/Resources/combined-workflow.yaml b/dotnet/src/Experimental/Process.UnitTests/Resources/combined-workflow.yaml deleted file mode 100644 index 5934688bab80..000000000000 --- a/dotnet/src/Experimental/Process.UnitTests/Resources/combined-workflow.yaml +++ /dev/null @@ -1,44 +0,0 @@ -workflow: - id: combined_workflow - name: ProductSummarization - inputs: - events: - cloud_events: - - type: input_message_received - data_schema: - type: string - nodes: - - id: GetProductInfo - type: dotnet - description: Gets product information - agent: - type: SemanticKernel.Process.UnitTests.Steps.ProductInfoProvider, SemanticKernel.Process.UnitTests - on_complete: - - on_condition: - type: default - emits: - - event_type: GetProductInfo.OnResult - - id: Summarize - type: declarative - description: Summarizes the information - agent: - type: chat_completion_agent - name: SummarizationAgent - description: Summarizes the information - instructions: Summarize the provided information in 3 sentences - on_complete: - - on_condition: - type: default - emits: - - event_type: ProcessCompleted - orchestration: - - listen_for: - event: input_message_received - from: _workflow_ - then: - - node: GetProductInfo - - listen_for: - from: GetProductInfo - event: GetProductInfo.OnResult - then: - - node: Summarize diff --git a/dotnet/src/Experimental/Process.UnitTests/Resources/dotnetOnlyWorkflow1.yaml b/dotnet/src/Experimental/Process.UnitTests/Resources/dotnetOnlyWorkflow1.yaml deleted file mode 100644 index 4b7466b80038..000000000000 --- a/dotnet/src/Experimental/Process.UnitTests/Resources/dotnetOnlyWorkflow1.yaml +++ /dev/null @@ -1,205 +0,0 @@ -id: dotnetOnlyWorkflow1 -format_version: "1.0" # The version of the declarative spec being used to define this workflow. -workflow_version: "1.5" # The version of the workflow itself. -name: report_generation_pipeline -description: "A workflow that generates and publishes a report on a given topic." -suggested_inputs: - events: - - type: "research_requested" - payload: - topic: "Create a report on AI agents at Microsoft." - -# Input that the workflow supports. -# The way the events get sent to the workflow may differ depending on the platform. Some platforms may support sending events directly to the workflow, -# while others may require using a chat completion interface similar to how local tool calls work. -inputs: # The structured inputs supported by the workflow. - events: - cloud_events: - - type: "StartRequested" - data_schema: - type: string - - type: "StartARequested" - data_schema: - type: string - -# Schemas for the data types used in the workflow. These can be defined inline or referenced from an external schema. -schemas: - research_data: - type: object - properties: - summary: { type: string } - articles: { type: array, items: { type: string } } - required: [summary, articles] - - draft: - type: object - properties: - content: { type: string } - word_count: { type: integer } - required: [content, word_count] - - report_feedback: - type: object - properties: - passed: { type: boolean } - content: { type: string } - feedback: { type: string } - required: [passed, content, feedback] - - report: - type: object - properties: - content: { type: string } - approval_reason: { type: string } - required: [content, approval_reason] - -# The nodes that make up the workflow. A node is a wrapper around something that can be invoked such as code, an agent, a tool, etc. -nodes: - - id: kickoff - type: dotnet # dotnet | python - version: "1.0" - description: "Kickoff the workflow" - agent: - type: "Microsoft.SemanticKernel.Process.UnitTests.Steps.KickoffStep, SemanticKernel.Process.UnitTests" - id: kickoff_agent - inputs: - input: - type: string - agent_input_mapping: - topic: "inputs.input" - on_complete: - - on_condition: - type: Eval - expression: "results.articles.length > '0'" - emits: - - event_type: data_fetched - schema: - $ref: "#/workflow/schemas/research_data" - payload: "$agent.outputs.results" - - on_condition: - type: default - emits: - - event_type: data_fetch_no_results - - - id: a_step - type: dotnet - version: "1.0" - description: "A step" - inputs: - research_data: - schema: - $ref: "#/workflow/schemas/research_data" - last_feedback: - type: string - agent: - type: "Microsoft.SemanticKernel.Process.UnitTests.Steps.AStep, SemanticKernel.Process.UnitTests" - id: a_step_agent - on_complete: - - on_condition: - type: default - emits: - - event_type: draft_created - schema: - $ref: "#/workflow/schemas/draft" - payload: "$agent.outputs.draft" - - - id: b_step - type: dotnet - version: "1.0" - description: "B Step" - agent: - type: "Microsoft.SemanticKernel.Process.UnitTests.Steps.BStep, SemanticKernel.Process.UnitTests" - id: b_step_agent - on_complete: - - on_condition: - type: eval - expression: "report_feedback.passed == 'true'" - emits: - - event_type: report_approved - schema: - $ref: "#/workflow/schemas/report" - payload: - object: - content: "$agent.outputs.report_feedback.content" - approval_reason: "$agent.outputs.report_feedback.feedback" - - on_condition: - type: default - emits: - - event_type: report_rejected - schema: - $ref: "#/workflow/schemas/report_feedback" - payload: "$agent.outputs.report_feedback" - updates: - - variable: revision_count - operation: increment - value: 1 - - variable: last_feedback - operation: set - value: "$agent.outputs.report_feedback.feedback" - - - id: c_step - type: dotnet - version: "1.0" - description: "C Step" - agent: - type: "Microsoft.SemanticKernel.Process.UnitTests.Steps.CStep, SemanticKernel.Process.UnitTests" - id: c_step_agent - # inputs: - # type: string - agent_input_mapping: - event_payload: "$.inputs.report" - event_type: "human_approval_request" - on_complete: - - on_condition: - type: default - emits: - - event_type: human_approved - schema: - $ref: "#/workflow/schemas/report" - updates: - - variable: approved_report - operation: set - value: "$agent.outputs.report" - - variable: $workflow.thread - operation: set - value: "$agent.outputs.report" - -# The orchestration of the workflow. This defines the sequence of events and actions that make up the workflow. -orchestration: - - - listen_for: - event: "StartRequested" - from: _workflow_ - then: - - node: kickoff - - - listen_for: - event: "StartARequested" - from: kickoff - then: - - node: a_step - - node: b_step - - - listen_for: - all_of: - - event: "AStepDone" - from: a_step - - event: "BStepDone" - from: b_step - then: - - node: c_step - inputs: - aStepData: a_step.AStepDone - bStepData: b_step.BStepDone - - - listen_for: - event: "CStepDone" - from: c_step - then: - - node: kickoff - - - listen_for: - event: "ExitRequested" - from: c_step - then: - - node: End diff --git a/dotnet/src/Experimental/Process.UnitTests/Resources/scenario1.yaml b/dotnet/src/Experimental/Process.UnitTests/Resources/scenario1.yaml deleted file mode 100644 index 509ea58c479f..000000000000 --- a/dotnet/src/Experimental/Process.UnitTests/Resources/scenario1.yaml +++ /dev/null @@ -1,88 +0,0 @@ -id: two_agent_math_chat -format_version: "1.0" -name: student_teacher_chat -description: - A workflow that has student and teacher that does question answering - about math -inputs: - messages: - events: - cloud_events: - - type: "input_message_received" - data_schema: - type: string -variables: {} -schemas: {} -nodes: - - id: Student - type: declarative - version: "1.0" - description: Solves problem - agent: - type: foundry_agent - id: "{{student.id}}" - name: "{{student.name}}" - human_in_loop_mode: onNoMessage - stream_output: true - inputs: - Question: - type: messages - on_invoke: - on_error: - on_complete: - - on_condition: - type: default - emits: - - event_type: Answer - schema: - type: messages - - id: Teacher - type: declarative - version: "1.0" - description: Giving the problem - agent: - type: foundry_agent - id: "{{teacher.id}}" - name: "{{teacher.name}}" - human_in_loop_mode: never - stream_output: true - inputs: - Answer: - type: messages - on_invoke: - on_error: - on_complete: - - on_condition: - type: default - emits: - - event_type: Question - schema: - type: messages - - id: End - type: declarative - version: "1.0" - description: Terminal State - -orchestration: - - listen_for: - event: input_message_received - from: _workflow_ - then: - - node: Student - - listen_for: - event: Answer - from: Student - then: - - node: Teacher - - listen_for: - event: Question - from: Teacher - condition: Question.NotContains('[COMPLETE]') - then: - - node: Student - - listen_for: - event: Question - from: Teacher - condition: Question.Contains('[COMPLETE]') - then: - - node: End diff --git a/dotnet/src/Experimental/Process.UnitTests/Resources/workflow1.yaml b/dotnet/src/Experimental/Process.UnitTests/Resources/workflow1.yaml deleted file mode 100644 index 86591dbb46e3..000000000000 --- a/dotnet/src/Experimental/Process.UnitTests/Resources/workflow1.yaml +++ /dev/null @@ -1,396 +0,0 @@ -workflow: - format_version: "1.0" # The version of the declarative spec being used to define this workflow. - workflow_version: "1.5" # The version of the workflow itself. - name: report_generation_pipeline - description: "A workflow that generates and publishes a report on a given topic." - suggested_inputs: - events: - - type: "research_requested" - payload: - topic: "Create a report on AI agents at Microsoft." - - # Input that the workflow supports. - # The way the events get sent to the workflow may differ depending on the platform. Some platforms may support sending events directly to the workflow, - # while others may require using a chat completion interface similar to how local tool calls work. - inputs: # The structured inputs supported by the workflow. - events: - cloud_events: - - type: "research_requested" - data_schema: - type: string - filters: # optional filters on cloud event attributes - - filter: "$.source == 'my_input_source'" - - # Variables used by the agents in the workflow. Variables can be defined as read-only or mutable. - # Read-only variables are initialized with a default value and cannot be modified during the workflow execution. - variables: - max_retries: - type: integer # defaults to mutable: false - default: 3 - scope: "workflow" - report_length_threshold: - type: integer - default: 500 - research_history: - type: "chat_history" # defaults to scope: "run", should it be thread? - is_mutable: true - acls: - - node: "researcher" - access: "read" - drafting_history: - type: "chat_history" - is_mutable: true - research_memory: - type: "memory" - is_mutable: true - drafting_memory: - type: "memory" - is_mutable: true - drafting_whiteboard: - type: "whiteboard" - is_mutable: true - revision_count: - type: integer - default: 0 - is_mutable: true - last_feedback: - type: string - default: "" - is_mutable: true - approved_report: - type: "string" - is_mutable: true - - # Schemas for the data types used in the workflow. These can be defined inline or referenced from an external schema. - schemas: - research_data: - type: object - properties: - summary: { type: string } - articles: { type: array, items: { type: string } } - required: [summary, articles] - - draft: - type: object - properties: - content: { type: string } - word_count: { type: integer } - required: [content, word_count] - - report_feedback: - type: object - properties: - passed: { type: boolean } - content: { type: string } - feedback: { type: string } - required: [passed, content, feedback] - - report: - type: object - properties: - content: { type: string } - approval_reason: { type: string } - required: [content, approval_reason] - - # The nodes that make up the workflow. A node is a wrapper around something that can be invoked such as code, an agent, a tool, etc. - nodes: - - id: fetch_data - type: declarative - version: "1.0" - description: "Fetches relevant research data on the given topic." - agent: - type: foundry_agent - id: research_agent - # name: research_agent - # description: "Find the most relevant articles and summarize key points ${{topic}}." - # inputs: - # topic: - # type: string - # outputs: - # results: - # $ref: "#/workflow/schemas/research_data" - inputs: - input: - type: string - agent_input_mapping: - topic: "inputs.input" - on_invoke: # mvp? - emits: - updates: - on_error: # mvp? - emits: - updates: - on_complete: - - on_condition: - type: state - expression: "$agent.outputs.results.articles.length > 0" # json path or something standard, look at Azure pipelines, GH, etc. - emits: - - event_type: data_fetched - schema: - $ref: "#/workflow/schemas/research_data" - payload: "$agent.outputs.results" - - on_condition: - type: default - emits: - - event_type: data_fetch_no_results - - - id: draft_report - type: declarative - version: "1.0" - description: "Generates a draft report based on the research data." - inputs: - research_data: - schema: - $ref: "#/workflow/schemas/research_data" - last_feedback: - type: string - agent: - type: foundry_agent - id: report_drafter - # name: generate_draft - # prompt: "Create a well-structured draft based on the given research data." - # inputs: - # research_data: - # type: object - # $ref: "#/workflow/schemas/research_data" - # last_feedback: - # type: string - # outputs: - # draft: - # $ref: "#/workflow/schemas/draft" - on_invoke: # mvp? - emits: - updates: - on_error: # mvp? - emits: - updates: - on_complete: - - on_condition: - type: default - emits: - - event_type: draft_created - schema: - $ref: "#/workflow/schemas/draft" - payload: "$agent.outputs.draft" - - - id: proofread_report - type: declarative - version: "1.0" - description: "Proofreads the draft report for grammar, clarity, and factual accuracy." - agent: - type: foundry_agent - id: proofreader - # The agent is already deployed to Foundry and is only referenced here by Id. - # The definition looks like this: - # name: proofreader - # prompt: "Review the draft for grammar, clarity, and factual accuracy." - # inputs: - # draft: - # $ref: "#/workflow/schemas/draft" - # output: - # report_feedback: - # $ref: "#/workflow/schemas/report_feedback" - # tools: - # ... - inputs: - draft: - schema: - $ref: "#/workflow/schemas/draft" - agent_input_mapping: - draft: "$.inputs.draft" # Should only be needed when mapping is not 1:1. - - on_invoke: # mvp? - emits: - updates: - on_error: # mvp? - emits: - updates: - on_complete: - - on_condition: - type: state # need to support structured and unstructured evaluation - expression: "$agent.outputs.report_feedback.passed == true" # json path or something standard - emits: - - event_type: report_approved - schema: - $ref: "#/workflow/schemas/report" - payload: - object: - content: "$agent.outputs.report_feedback.content" - approval_reason: "$agent.outputs.report_feedback.feedback" - - on_condition: - type: default - emits: - - event_type: report_rejected - schema: - $ref: "#/workflow/schemas/report_feedback" - payload: "$agent.outputs.report_feedback" - updates: # discuss more with AK - - variable: $.variables.revision_count - operation: increment - value: 1 - - variable: last_feedback - operation: set - value: "$agent.outputs.report_feedback.feedback" - - - id: human_review - type: declarative - version: "1.0" - description: "Human reviewer for the final report." - # Could be a pre-built agent template from the Foundry Catalog. The behavior is to - # yield an event to the workflow and wait for a response before resuming. - agent: - type: foundry_agent - id: built_in/yield_event - # name: yield_event - # inputs: - # event_type: - # type: string - # event_payload: - # type: object - # output: - # event_response: - # type: object - inputs: - report: - schema: - $ref: "#/workflow/schemas/report" - agent_input_mapping: - event_payload: "$.inputs.report" - event_type: "human_approval_request" - on_invoke: # mvp? - emits: - updates: - on_error: # mvp? - emits: - updates: - on_complete: - - on_condition: - type: default - emits: - - event_type: human_approved - schema: - $ref: "#/workflow/schemas/report" - updates: - - variable: approved_report - operation: set - value: "$agent.outputs.report" - - variable: $workflow.thread - operation: append - value: "$agent.outputs.report" - - - id: error_handler - type: declarative - version: "1.0" - description: "Handles errors that occur during the workflow." - agent: - type: foundry_agent - id: built_in/yield_event - # name: yield_event - # inputs: - # event_type: - # type: string - # event_payload: - # type: object - # output: - # event_response: - # type: object - # schema: - # $ref: "#/workflow/schemas/human_approval_response" - inputs: - error_details: - type: object - error_type: - type: string - agent_input_mapping: - event_payload: "$.inputs.error" - event_type: "$.inputs.error_type" - on_invoke: # mvp? - emits: - updates: - on_error: # mvp? - emits: - updates: - on_complete: - - on_condition: - type: default - emits: - - event_type: human_approved - schema: - $ref: "#/workflow/schemas/report" - updates: # append to the main thread to "output" the answer - - variable: approved_report - operation: set - value: "$agent.outputs.report" - - # The orchestration of the workflow. This defines the sequence of events and actions that make up the workflow. - orchestration: - - - listen_for: - event: "research_requested" - from: $.workflow - then: - - node: fetch_data - inputs: - input: $.event.payload - last_feedback: "" - - - listen_for: - event: "data_fetched" - from: fetch_data - then: - - node: draft_report - inputs: - research_data: $.event.payload - - - listen_for: - event: "draft_created" - from: draft_report - then: - - node: proofread_report - inputs: - draft: $.event.payload - - - listen_for: - event: "report_approved" - from: proofread_report - then: - - node: human_review - inputs: # input mapping for different entry points - report: $.event.payload - - - listen_for: - all_of: # Want to also support any_of - AK needs to figure out implementation - - event: "report_approved" - from: proofread_report - - event: "human_approved" - from: human_review - then: - - node: publish_report - inputs: - report: $.event.payload - - # The compatibility matrix for the workflow. This defines the compatibility of the workflow with different versions of itself. - upgrade: - - from_versions: - min_version: "0.1" - max_version_exclusive: "1.0" - strategy: "not_compatible" - - from_versions: - min_version: "1.0" - max_version_exclusive: "*" - strategy: "backward_compatible" - - # The error handling for the workflow. This defines how errors are handled at different levels of the workflow. - error_handling: - on_error: - - listen_for: - event: "*_failed" - then: - - node: error_handler - inputs: - error_details: $.event.payload - error_type: "unknown_error" - default: - - node: logging_service - inputs: - error_details: $.event.payload diff --git a/dotnet/src/Functions/Functions.Yaml/PromptExecutionSettingsTypeConverter.cs b/dotnet/src/Functions/Functions.Yaml/PromptExecutionSettingsTypeConverter.cs index 12431752d6e5..410ad504bd30 100644 --- a/dotnet/src/Functions/Functions.Yaml/PromptExecutionSettingsTypeConverter.cs +++ b/dotnet/src/Functions/Functions.Yaml/PromptExecutionSettingsTypeConverter.cs @@ -25,7 +25,7 @@ public bool Accepts(Type type) } /// - public object? ReadYaml(IParser parser, Type type) + public object? ReadYaml(IParser parser, Type type, ObjectDeserializer rootDeserializer) { s_deserializer ??= new DeserializerBuilder() .WithNamingConvention(UnderscoredNamingConvention.Instance) @@ -58,7 +58,7 @@ public bool Accepts(Type type) } /// - public void WriteYaml(IEmitter emitter, object? value, Type type) + public void WriteYaml(IEmitter emitter, object? value, Type type, ObjectSerializer serializer) { throw new NotImplementedException(); } diff --git a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIResponseAgentFixture.cs b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIResponseAgentFixture.cs index 08ca4dc1f60a..5a14cd765374 100644 --- a/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIResponseAgentFixture.cs +++ b/dotnet/src/IntegrationTests/Agents/CommonInterfaceConformance/OpenAIResponseAgentFixture.cs @@ -101,6 +101,7 @@ public override async Task InitializeAsync() { Name = "HelpfulAssistant", Instructions = "You are a helpful assistant.", + StoreEnabled = true, Kernel = kernel }; this._thread = new OpenAIResponseAgentThread(this._responseClient); diff --git a/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockTextGenerationTests.cs b/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockTextGenerationTests.cs index aad9a3eb3c2b..9c2f8beaf7c6 100644 --- a/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockTextGenerationTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockTextGenerationTests.cs @@ -15,8 +15,6 @@ public class BedrockTextGenerationTests [InlineData("cohere.command-r-v1:0")] [InlineData("cohere.command-r-plus-v1:0")] [InlineData("ai21.jamba-instruct-v1:0")] - [InlineData("ai21.j2-ultra-v1")] - [InlineData("ai21.j2-mid-v1")] [InlineData("meta.llama3-70b-instruct-v1:0")] [InlineData("meta.llama3-8b-instruct-v1:0")] [InlineData("mistral.mistral-7b-instruct-v0:2")] diff --git a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_AutoFunctionChoiceBehaviorTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_AutoFunctionChoiceBehaviorTests.cs index 8a80403095f3..396d1766c322 100644 --- a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_AutoFunctionChoiceBehaviorTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_AutoFunctionChoiceBehaviorTests.cs @@ -6,6 +6,7 @@ using System.ComponentModel; using System.Globalization; using System.Linq; +using System.Net.Http; using System.Text; using System.Threading.Tasks; using Azure.Identity; @@ -14,13 +15,13 @@ using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using SemanticKernel.IntegrationTests.TestSettings; -using xRetry; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.AzureOpenAI; -public sealed class AzureOpenAIChatClientAutoFunctionChoiceBehaviorTests : BaseIntegrationTest +public sealed class AzureOpenAIChatClientAutoFunctionChoiceBehaviorTests : BaseIntegrationTest, IDisposable { + private HttpClient? _httpClient; private readonly Kernel _kernel; private readonly FakeFunctionFilter _autoFunctionInvocationFilter; private readonly IChatClient _chatClient; @@ -34,7 +35,7 @@ public AzureOpenAIChatClientAutoFunctionChoiceBehaviorTests() this._chatClient = this._kernel.GetRequiredService(); } - [RetryFact] + [Fact] public async Task SpecifiedInCodeInstructsConnectorToInvokeKernelFunctionAutomaticallyAsync() { // Arrange @@ -65,7 +66,7 @@ public async Task SpecifiedInCodeInstructsConnectorToInvokeKernelFunctionAutomat Assert.Contains("GetCurrentDate", invokedFunctions); } - [RetryFact] + [Fact] public async Task SpecifiedInPromptInstructsConnectorToInvokeKernelFunctionAutomaticallyAsync() { // Arrange @@ -443,6 +444,7 @@ public async Task SpecifiedInCodeInstructsAIModelToCallFunctionInParallelOrSeque private Kernel InitializeKernel() { + this._httpClient ??= new() { Timeout = TimeSpan.FromSeconds(100) }; var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); Assert.NotNull(azureOpenAIConfiguration); Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName); @@ -454,11 +456,18 @@ private Kernel InitializeKernel() deploymentName: azureOpenAIConfiguration.ChatDeploymentName, modelId: azureOpenAIConfiguration.ChatModelId, endpoint: azureOpenAIConfiguration.Endpoint, - credentials: new AzureCliCredential()); + credentials: new AzureCliCredential(), + httpClient: this._httpClient); return kernelBuilder.Build(); } + public void Dispose() + { + this._httpClient?.Dispose(); + this._chatClient?.Dispose(); + } + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) diff --git a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_NoneFunctionChoiceBehaviorTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_NoneFunctionChoiceBehaviorTests.cs index bc16c0c18119..82b11f31707c 100644 --- a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_NoneFunctionChoiceBehaviorTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_NoneFunctionChoiceBehaviorTests.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.ComponentModel; using System.Globalization; +using System.Net.Http; using System.Text; using System.Threading.Tasks; using Azure.Identity; @@ -16,8 +17,9 @@ namespace SemanticKernel.IntegrationTests.Connectors.AzureOpenAI; -public sealed class AzureOpenAIChatClientNoneFunctionChoiceBehaviorTests : BaseIntegrationTest +public sealed class AzureOpenAIChatClientNoneFunctionChoiceBehaviorTests : BaseIntegrationTest, IDisposable { + private HttpClient? _httpClient; private readonly Kernel _kernel; private readonly FakeFunctionFilter _autoFunctionInvocationFilter; private readonly IChatClient _chatClient; @@ -181,8 +183,15 @@ public async Task SpecifiedInPromptInstructsConnectorNotToInvokeKernelFunctionFo Assert.Empty(invokedFunctions); } + public void Dispose() + { + this._httpClient?.Dispose(); + this._chatClient?.Dispose(); + } + private Kernel InitializeKernel() { + this._httpClient ??= new() { Timeout = TimeSpan.FromSeconds(100) }; var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); Assert.NotNull(azureOpenAIConfiguration); Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName); @@ -194,7 +203,8 @@ private Kernel InitializeKernel() deploymentName: azureOpenAIConfiguration.ChatDeploymentName, modelId: azureOpenAIConfiguration.ChatModelId, endpoint: azureOpenAIConfiguration.Endpoint, - credentials: new AzureCliCredential()); + credentials: new AzureCliCredential(), + httpClient: this._httpClient); return kernelBuilder.Build(); } diff --git a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_RequiredFunctionChoiceBehaviorTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_RequiredFunctionChoiceBehaviorTests.cs index aedc8e5bf540..93da2c0b8f1b 100644 --- a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_RequiredFunctionChoiceBehaviorTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClient_RequiredFunctionChoiceBehaviorTests.cs @@ -5,6 +5,7 @@ using System.ComponentModel; using System.Globalization; using System.Linq; +using System.Net.Http; using System.Threading.Tasks; using Azure.Identity; using Microsoft.Extensions.AI; @@ -12,13 +13,13 @@ using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using SemanticKernel.IntegrationTests.TestSettings; -using xRetry; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.AzureOpenAI; -public sealed class AzureOpenAIChatClientRequiredFunctionChoiceBehaviorTests : BaseIntegrationTest +public sealed class AzureOpenAIChatClientRequiredFunctionChoiceBehaviorTests : BaseIntegrationTest, IDisposable { + private HttpClient? _httpClient; private readonly Kernel _kernel; private readonly FakeFunctionFilter _autoFunctionInvocationFilter; private readonly IChatClient _chatClient; @@ -32,7 +33,7 @@ public AzureOpenAIChatClientRequiredFunctionChoiceBehaviorTests() this._chatClient = this._kernel.GetRequiredService(); } - [RetryFact] + [Fact] public async Task SpecifiedInCodeInstructsConnectorToInvokeKernelFunctionAutomaticallyAsync() { // Arrange @@ -63,7 +64,7 @@ public async Task SpecifiedInCodeInstructsConnectorToInvokeKernelFunctionAutomat Assert.Contains("GetCurrentDate", invokedFunctions); } - [RetryFact(Skip = "For manual verification only")] + [Fact] public async Task SpecifiedInPromptInstructsConnectorToInvokeKernelFunctionAutomaticallyAsync() { // Arrange @@ -352,8 +353,15 @@ public async Task SpecifiedInCodeInstructsConnectorToInvokeNonKernelFunctionManu Assert.Empty(invokedFunctions); } + public void Dispose() + { + this._httpClient?.Dispose(); + this._chatClient?.Dispose(); + } + private Kernel InitializeKernel() { + this._httpClient ??= new() { Timeout = TimeSpan.FromSeconds(100) }; var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); Assert.NotNull(azureOpenAIConfiguration); Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName); @@ -365,7 +373,8 @@ private Kernel InitializeKernel() deploymentName: azureOpenAIConfiguration.ChatDeploymentName, modelId: azureOpenAIConfiguration.ChatModelId, endpoint: azureOpenAIConfiguration.Endpoint, - credentials: new AzureCliCredential()); + credentials: new AzureCliCredential(), + httpClient: this._httpClient); return kernelBuilder.Build(); } diff --git a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs index 61d78f743e29..b2c043291c86 100644 --- a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_AutoFunctionChoiceBehaviorTests.cs @@ -6,6 +6,7 @@ using System.ComponentModel; using System.Globalization; using System.Linq; +using System.Net.Http; using System.Text; using System.Threading.Tasks; using Azure.Identity; @@ -15,13 +16,13 @@ using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using Microsoft.SemanticKernel.Connectors.OpenAI; using SemanticKernel.IntegrationTests.TestSettings; -using xRetry; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.AzureOpenAI; -public sealed class AzureOpenAIAutoFunctionChoiceBehaviorTests : BaseIntegrationTest +public sealed class AzureOpenAIAutoFunctionChoiceBehaviorTests : BaseIntegrationTest, IDisposable { + private HttpClient? _httpClient; private readonly Kernel _kernel; private readonly FakeFunctionFilter _autoFunctionInvocationFilter; private readonly IChatCompletionService _chatCompletionService; @@ -35,7 +36,7 @@ public AzureOpenAIAutoFunctionChoiceBehaviorTests() this._chatCompletionService = this._kernel.GetRequiredService(); } - [RetryFact] + [Fact] public async Task SpecifiedInCodeInstructsConnectorToInvokeKernelFunctionAutomaticallyAsync() { // Arrange @@ -63,7 +64,7 @@ public async Task SpecifiedInCodeInstructsConnectorToInvokeKernelFunctionAutomat Assert.Contains("GetCurrentDate", invokedFunctions); } - [RetryFact] + [Fact] public async Task SpecifiedInPromptInstructsConnectorToInvokeKernelFunctionAutomaticallyAsync() { // Arrange @@ -396,8 +397,14 @@ public async Task SpecifiedInCodeInstructsAIModelToCallFunctionInParallelOrSeque } } + public void Dispose() + { + this._httpClient?.Dispose(); + } + private Kernel InitializeKernel() { + this._httpClient ??= new() { Timeout = TimeSpan.FromSeconds(100) }; var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); Assert.NotNull(azureOpenAIConfiguration); Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName); @@ -409,7 +416,8 @@ private Kernel InitializeKernel() deploymentName: azureOpenAIConfiguration.ChatDeploymentName, modelId: azureOpenAIConfiguration.ChatModelId, endpoint: azureOpenAIConfiguration.Endpoint, - credentials: new AzureCliCredential()); + credentials: new AzureCliCredential(), + httpClient: this._httpClient); return kernelBuilder.Build(); } diff --git a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NoneFunctionChoiceBehaviorTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NoneFunctionChoiceBehaviorTests.cs index c6285185b8c2..2560c1045106 100644 --- a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NoneFunctionChoiceBehaviorTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NoneFunctionChoiceBehaviorTests.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.ComponentModel; using System.Globalization; +using System.Net.Http; using System.Text; using System.Threading.Tasks; using Azure.Identity; @@ -15,8 +16,9 @@ namespace SemanticKernel.IntegrationTests.Connectors.AzureOpenAI; -public sealed class AzureOpenAINoneFunctionChoiceBehaviorTests : BaseIntegrationTest +public sealed class AzureOpenAINoneFunctionChoiceBehaviorTests : BaseIntegrationTest, IDisposable { + private HttpClient? _httpClient; private readonly Kernel _kernel; private readonly FakeFunctionFilter _autoFunctionInvocationFilter; @@ -160,8 +162,14 @@ public async Task SpecifiedInPromptInstructsConnectorNotToInvokeKernelFunctionFo Assert.Empty(invokedFunctions); } + public void Dispose() + { + this._httpClient?.Dispose(); + } + private Kernel InitializeKernel() { + this._httpClient ??= new() { Timeout = TimeSpan.FromSeconds(100) }; var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); Assert.NotNull(azureOpenAIConfiguration); Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName); @@ -173,7 +181,8 @@ private Kernel InitializeKernel() deploymentName: azureOpenAIConfiguration.ChatDeploymentName, modelId: azureOpenAIConfiguration.ChatModelId, endpoint: azureOpenAIConfiguration.Endpoint, - credentials: new AzureCliCredential()); + credentials: new AzureCliCredential(), + httpClient: this._httpClient); return kernelBuilder.Build(); } diff --git a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_RequiredFunctionChoiceBehaviorTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_RequiredFunctionChoiceBehaviorTests.cs index 6d6c6373ce9d..49c82b60f62b 100644 --- a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_RequiredFunctionChoiceBehaviorTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_RequiredFunctionChoiceBehaviorTests.cs @@ -5,6 +5,7 @@ using System.ComponentModel; using System.Globalization; using System.Linq; +using System.Net.Http; using System.Threading.Tasks; using Azure.Identity; using Microsoft.Extensions.Configuration; @@ -13,13 +14,13 @@ using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using Microsoft.SemanticKernel.Connectors.OpenAI; using SemanticKernel.IntegrationTests.TestSettings; -using xRetry; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.AzureOpenAI; -public sealed class AzureOpenAIRequiredFunctionChoiceBehaviorTests : BaseIntegrationTest +public sealed class AzureOpenAIRequiredFunctionChoiceBehaviorTests : BaseIntegrationTest, IDisposable { + private HttpClient? _httpClient; private readonly Kernel _kernel; private readonly FakeFunctionFilter _autoFunctionInvocationFilter; private readonly IChatCompletionService _chatCompletionService; @@ -74,7 +75,7 @@ public AzureOpenAIRequiredFunctionChoiceBehaviorTests() // Assert.Contains("GetCurrentDate", invokedFunctions); //} - [RetryFact] + [Fact] public async Task SpecifiedInCodeInstructsConnectorToInvokeKernelFunctionAutomaticallyAsync() { // Arrange @@ -102,7 +103,7 @@ public async Task SpecifiedInCodeInstructsConnectorToInvokeKernelFunctionAutomat Assert.Contains("GetCurrentDate", invokedFunctions); } - [RetryFact] + [Fact] public async Task SpecifiedInPromptInstructsConnectorToInvokeKernelFunctionAutomaticallyAsync() { // Arrange @@ -400,8 +401,14 @@ public async Task SpecifiedInCodeInstructsConnectorToInvokeNonKernelFunctionManu Assert.Empty(invokedFunctions); } + public void Dispose() + { + this._httpClient?.Dispose(); + } + private Kernel InitializeKernel() { + this._httpClient ??= new() { Timeout = TimeSpan.FromSeconds(100) }; var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); Assert.NotNull(azureOpenAIConfiguration); Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName); @@ -413,7 +420,8 @@ private Kernel InitializeKernel() deploymentName: azureOpenAIConfiguration.ChatDeploymentName, modelId: azureOpenAIConfiguration.ChatModelId, endpoint: azureOpenAIConfiguration.Endpoint, - credentials: new AzureCliCredential()); + credentials: new AzureCliCredential(), + httpClient: this._httpClient); return kernelBuilder.Build(); } diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs index 90984d8edb07..7645d9cf107e 100644 --- a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.IO; using System.Linq; using System.Net.Http; @@ -600,4 +601,34 @@ public async Task GoogleAIChatReturnsResponseWorksWithThinkingBudgetAsync() Assert.NotNull(streamResponses[0].Content); Assert.NotNull(responses[0].Content); } + + [RetryTheory(Skip = "This test is for manual verification.")] + [InlineData(ServiceType.VertexAI)] // GoogleAI does not support labels yet + public async Task GoogleAIChatReturnsResponseWorksWithLabelsAsync(ServiceType serviceType) + { + // Arrange + ChatHistory chatHistory = []; + chatHistory.AddUserMessage("Hello, I'm Brandon, how are you?"); + chatHistory.AddAssistantMessage("I'm doing well, thanks for asking."); + chatHistory.AddUserMessage("Call me by my name and expand this abbreviation: LLM"); + + var sut = this.GetChatService(serviceType); + + var settings = new GeminiPromptExecutionSettings + { + Labels = new Dictionary() + { + ["label1"] = "value1", + ["label2"] = "value2" + } + }; + + // Act + var streamResponses = await sut.GetStreamingChatMessageContentsAsync(chatHistory, settings).ToListAsync(); + var responses = await sut.GetChatMessageContentsAsync(chatHistory, settings); + + // Assert + Assert.NotNull(streamResponses[0].Content); + Assert.NotNull(responses[0].Content); + } } diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaChatClientIntegrationTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaChatClientIntegrationTests.cs index 2a9d577c397e..9f7d8a192a48 100644 --- a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaChatClientIntegrationTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaChatClientIntegrationTests.cs @@ -32,7 +32,7 @@ public OllamaChatClientIntegrationTests(ITestOutputHelper output) .Build(); } - [Theory(Skip = "This test is for manual verification.")] + [Theory(Skip = "For manual verification only")] [InlineData("phi3")] [InlineData("llama3.2")] public async Task OllamaChatClientBasicUsageAsync(string modelId) @@ -56,7 +56,7 @@ public async Task OllamaChatClientBasicUsageAsync(string modelId) this._output.WriteLine($"Response: {response.Text}"); } - [Theory(Skip = "This test is for manual verification.")] + [Theory(Skip = "For manual verification only")] [InlineData("phi3")] [InlineData("llama3.2")] public async Task OllamaChatClientStreamingUsageAsync(string modelId) @@ -87,7 +87,7 @@ public async Task OllamaChatClientStreamingUsageAsync(string modelId) this._output.WriteLine($"Complete response: {responseText}"); } - [Theory(Skip = "This test is for manual verification.")] + [Theory(Skip = "For manual verification only")] [InlineData("phi3")] public async Task OllamaChatClientWithOptionsAsync(string modelId) { @@ -116,7 +116,7 @@ public async Task OllamaChatClientWithOptionsAsync(string modelId) this._output.WriteLine($"Response: {response.Text}"); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = "For manual verification only")] public async Task OllamaChatClientServiceCollectionIntegrationAsync() { // Arrange @@ -146,7 +146,7 @@ public async Task OllamaChatClientServiceCollectionIntegrationAsync() this._output.WriteLine($"Response: {response.Text}"); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = "For manual verification only")] public async Task OllamaChatClientKernelBuilderIntegrationAsync() { // Arrange @@ -173,7 +173,7 @@ public async Task OllamaChatClientKernelBuilderIntegrationAsync() this._output.WriteLine($"Response: {response.Text}"); } - [Fact] + [Fact(Skip = "For manual verification only")] public void OllamaChatClientMetadataTest() { // Arrange @@ -190,7 +190,7 @@ public void OllamaChatClientMetadataTest() Assert.Equal(modelId, metadata.DefaultModelId); } - [Fact(Skip = "This test is for manual verification.")] + [Fact(Skip = "For manual verification only")] public async Task OllamaChatClientWithKernelFunctionInvocationAsync() { // Arrange diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaChatCompletion_FunctionCallingTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaChatCompletion_FunctionCallingTests.cs index d358d82ba712..7605454ee708 100644 --- a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaChatCompletion_FunctionCallingTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaChatCompletion_FunctionCallingTests.cs @@ -7,6 +7,7 @@ using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using SemanticKernel.IntegrationTests.TestSettings; +using xRetry; using Xunit; using ChatMessageContent = Microsoft.SemanticKernel.ChatMessageContent; @@ -142,7 +143,7 @@ public async Task ConnectorAgnosticFunctionCallingModelClassesCanBeUsedForManual Assert.Contains("rain", messageContent.Content, StringComparison.InvariantCultureIgnoreCase); } - [Fact(Skip = "For manual verification only")] + [RetryFact(Skip = "For manual verification only")] public async Task ConnectorAgnosticFunctionCallingModelClassesCanPassFunctionExceptionToConnectorAsync() { // Arrange @@ -158,7 +159,6 @@ public async Task ConnectorAgnosticFunctionCallingModelClassesCanPassFunctionExc // Act var messageContent = await completionService.GetChatMessageContentAsync(chatHistory, settings, kernel); - var functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray(); while (functionCalls.Length != 0) @@ -215,7 +215,7 @@ public async Task ConnectorAgnosticFunctionCallingModelClassesSupportSimulatedFu Assert.Contains("tornado", messageContent.Content, StringComparison.InvariantCultureIgnoreCase); } - [Fact(Skip = "For manual verification only")] + [RetryFact(Skip = "For manual verification only")] public async Task ConnectorAgnosticFunctionCallingModelClassesCanBeUsedForAutoFunctionCallingAsync() { // Arrange @@ -236,7 +236,7 @@ public async Task ConnectorAgnosticFunctionCallingModelClassesCanBeUsedForAutoFu Assert.Equal(AuthorRole.User, userMessage.Role); // LLM requested the functions to call. - var getParallelFunctionCallRequestMessage = chatHistory[1]; + var getParallelFunctionCallRequestMessage = chatHistory.First(m => m.Items.Any(i => i is FunctionCallContent)); Assert.Equal(AuthorRole.Assistant, getParallelFunctionCallRequestMessage.Role); // Parallel Function Calls in the same request @@ -250,16 +250,16 @@ public async Task ConnectorAgnosticFunctionCallingModelClassesCanBeUsedForAutoFu getWeatherForCityFunctionCallRequest = functionCalls[0]; // Connector invoked the Get_Weather_For_City function and added result to chat history. - getWeatherForCityFunctionCallResultMessage = chatHistory[2]; + getWeatherForCityFunctionCallResultMessage = chatHistory.First(m => m.Items.Any(i => i is FunctionResultContent)); - Assert.Equal("HelperFunctions-Get_Weather_For_City", getWeatherForCityFunctionCallRequest.FunctionName); + Assert.Equal("HelperFunctions_Get_Weather_For_City", getWeatherForCityFunctionCallRequest.FunctionName); Assert.NotNull(getWeatherForCityFunctionCallRequest.Id); Assert.Equal(AuthorRole.Tool, getWeatherForCityFunctionCallResultMessage.Role); Assert.Single(getWeatherForCityFunctionCallResultMessage.Items.OfType()); // Current function calling model adds TextContent item representing the result of the function call. var getWeatherForCityFunctionCallResult = getWeatherForCityFunctionCallResultMessage.Items.OfType().Single(); - Assert.Equal("HelperFunctions-Get_Weather_For_City", getWeatherForCityFunctionCallResult.FunctionName); + Assert.Equal("HelperFunctions_Get_Weather_For_City", getWeatherForCityFunctionCallResult.FunctionName); Assert.Equal(getWeatherForCityFunctionCallRequest.Id, getWeatherForCityFunctionCallResult.CallId); Assert.NotNull(getWeatherForCityFunctionCallResult.Result); } @@ -318,12 +318,12 @@ private Kernel CreateAndInitializeKernel(bool importHelperPlugin = false) Assert.NotNull(config); Assert.NotNull(config.Endpoint); - Assert.NotNull(config.ModelId); + Assert.NotNull(config.ModelId ?? "llama3.2"); var kernelBuilder = base.CreateKernelBuilder(); kernelBuilder.AddOllamaChatCompletion( - modelId: config.ModelId, + modelId: config.ModelId ?? "llama3.2", endpoint: new Uri(config.Endpoint)); var kernel = kernelBuilder.Build(); @@ -332,7 +332,7 @@ private Kernel CreateAndInitializeKernel(bool importHelperPlugin = false) { kernel.ImportPluginFromFunctions("HelperFunctions", [ - kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "GetCurrentUtcTime", "Retrieves the current time in UTC."), + kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "Get_Current_Utc_Time", "Retrieves the current time in UTC."), kernel.CreateFunctionFromMethod((string cityName) => { return cityName switch diff --git a/dotnet/src/IntegrationTests/Connectors/Onnx/OnnxRuntimeGenAIChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/Onnx/OnnxRuntimeGenAIChatClientTests.cs new file mode 100644 index 000000000000..9b5b374fe4bf --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Onnx/OnnxRuntimeGenAIChatClientTests.cs @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft. All rights reserved. + +#pragma warning disable SKEXP0010 + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Onnx; + +public class OnnxRuntimeGenAIChatClientTests : BaseIntegrationTest +{ + [Fact(Skip = "For manual verification only")] + public async Task ItCanUseKernelInvokeAsyncWithChatClientAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernelWithChatClient(); + + var func = kernel.CreateFunctionFromPrompt("List the two planets after '{{$input}}', excluding moons, using bullet points."); + + // Act + var result = await func.InvokeAsync(kernel, new() { ["input"] = "Jupiter" }); + + // Assert + Assert.NotNull(result); + Assert.Contains("Saturn", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Uranus", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItCanUseKernelInvokeStreamingAsyncWithChatClientAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernelWithChatClient(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + StringBuilder fullResult = new(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + await foreach (var content in kernel.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { ["input"] = prompt })) + { + fullResult.Append(content); + } + + // Assert + Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItCanUseServiceGetResponseAsync() + { + using var chatClient = CreateChatClient(); + + var messages = new List + { + new(ChatRole.User, "Where is the most famous fish market in Seattle, Washington, USA?") + }; + + var response = await chatClient.GetResponseAsync(messages); + + // Assert + Assert.NotNull(response); + Assert.Contains("Pike Place", response.Text, StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItCanUseServiceGetStreamingResponseAsync() + { + using var chatClient = CreateChatClient(); + + var messages = new List + { + new(ChatRole.User, "Where is the most famous fish market in Seattle, Washington, USA?") + }; + + StringBuilder fullResult = new(); + + await foreach (var update in chatClient.GetStreamingResponseAsync(messages)) + { + fullResult.Append(update.Text); + } + + // Assert + Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } + + private static IChatClient CreateChatClient() + { + Assert.NotNull(Configuration.ModelPath); + Assert.NotNull(Configuration.ModelId); + + var services = new ServiceCollection(); + services.AddOnnxRuntimeGenAIChatClient(Configuration.ModelId); + + var serviceProvider = services.BuildServiceProvider(); + return serviceProvider.GetRequiredService(); + } + + #region internals + + private Kernel CreateAndInitializeKernelWithChatClient(HttpClient? httpClient = null) + { + Assert.NotNull(Configuration.ModelPath); + Assert.NotNull(Configuration.ModelId); + + var kernelBuilder = base.CreateKernelBuilder(); + + kernelBuilder.AddOnnxRuntimeGenAIChatClient( + modelPath: Configuration.ModelPath, + serviceId: Configuration.ServiceId); + + return kernelBuilder.Build(); + } + + private static OnnxConfiguration Configuration => new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build() + .GetRequiredSection("Onnx") + .Get()!; + + #endregion +} diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatClientTests.cs index 3351f2a78996..4f253afa22c3 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatClientTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatClientTests.cs @@ -148,7 +148,7 @@ public async Task CompletionWithDifferentLineEndingsAsync(string lineEnding) Assert.Contains("John", actual.GetValue(), StringComparison.OrdinalIgnoreCase); } - [Fact(Skip = "Currently not supported - Chat System Prompt is not surfacing as a system message level")] + [Fact] public async Task ChatSystemPromptIsNotIgnoredAsync() { // Arrange diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletion_NonStreamingTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletion_NonStreamingTests.cs index 735db93b23cf..e17b5609644b 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletion_NonStreamingTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletion_NonStreamingTests.cs @@ -163,7 +163,7 @@ public async Task ChatCompletionWithWebSearchAsync() Assert.NotEmpty(chatCompletion.Annotations); } - [Fact] + [Fact(Skip = "For manual verification only")] public async Task ChatCompletionWithAudioInputAndOutputAsync() { // Arrange diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index 1e8f5dff0062..0c223043dcbb 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -5,7 +5,7 @@ net8.0 true false - $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0080,SKEXP0110,SKEXP0130,OPENAI001,MEVD9000 + $(NoWarn);CA2007,CA1861,VSTHRD111,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0080,SKEXP0110,SKEXP0130,OPENAI001,MEVD9000 b7762d10-e29b-4bb1-8b74-b6d69a667dd4 diff --git a/dotnet/src/IntegrationTests/Plugins/ContextualFunctionProviderTests.cs b/dotnet/src/IntegrationTests/Plugins/ContextualFunctionProviderTests.cs index d25588fecd6b..d28fc150aeaf 100644 --- a/dotnet/src/IntegrationTests/Plugins/ContextualFunctionProviderTests.cs +++ b/dotnet/src/IntegrationTests/Plugins/ContextualFunctionProviderTests.cs @@ -66,6 +66,7 @@ void OnModelInvokingAsync(ICollection newMessages, AIContext contex Name = "ReviewGuru", Instructions = "You are a friendly assistant that summarizes key points and sentiments from customer reviews.", Kernel = this._kernel, + UseImmutableKernel = true, // Usage of immutable kernel is required for the context provider feature. Arguments = new(new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new FunctionChoiceBehaviorOptions { RetainArgumentTypes = true }) }) }; @@ -106,6 +107,7 @@ void OnModelInvokingAsync(ICollection newMessages, AIContext contex Instructions = "You are a helpful assistant that helps with Azure resource management. " + "Avoid including the phrase like 'If you need further assistance or have any additional tasks, feel free to let me know!' in any responses.", Kernel = this._kernel, + UseImmutableKernel = true, // Usage of immutable kernel is required for the context provider feature. Arguments = new(new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new FunctionChoiceBehaviorOptions { RetainArgumentTypes = true }) }) }; diff --git a/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseOrchestrationTest.cs b/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseOrchestrationTest.cs index c2d3c6852ee2..0abc5de29af0 100644 --- a/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseOrchestrationTest.cs +++ b/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseOrchestrationTest.cs @@ -1,12 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. +using System.ClientModel; using System.Text; using System.Text.Json; +using Azure.AI.Agents.Persistent; +using Azure.Identity; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.AzureAI; +using Microsoft.SemanticKernel.Agents.OpenAI; using Microsoft.SemanticKernel.ChatCompletion; +using OpenAI.Assistants; /// /// Base class for samples that demonstrate the usage of host agents @@ -20,7 +26,7 @@ public abstract class BaseOrchestrationTest(ITestOutputHelper output) : BaseAgen protected new ILoggerFactory LoggerFactory => this.EnableLogging ? base.LoggerFactory : NullLoggerFactory.Instance; - protected ChatCompletionAgent CreateAgent(string instructions, string? description = null, string? name = null, Kernel? kernel = null) + protected ChatCompletionAgent CreateChatCompletionAgent(string instructions, string? description = null, string? name = null, Kernel? kernel = null) { return new ChatCompletionAgent @@ -32,6 +38,51 @@ protected ChatCompletionAgent CreateAgent(string instructions, string? descripti }; } + protected async Task CreateOpenAIAssistantAgentAsync(string instructions, string? description = null, string? name = null, Kernel? kernel = null) + { + var client = + this.UseOpenAIConfig ? + OpenAIAssistantAgent.CreateOpenAIClient(new ApiKeyCredential(this.ApiKey ?? throw new ConfigurationNotFoundException("OpenAI:ApiKey"))) : + !string.IsNullOrWhiteSpace(this.ApiKey) ? + OpenAIAssistantAgent.CreateAzureOpenAIClient(new ApiKeyCredential(this.ApiKey), new Uri(this.Endpoint!)) : + OpenAIAssistantAgent.CreateAzureOpenAIClient(new AzureCliCredential(), new Uri(this.Endpoint!)); + + var assistantClient = client.GetAssistantClient(); + + var createOptions = new AssistantCreationOptions + { + Name = name, + Description = description, + Instructions = instructions, + }; + + Assistant definition = await assistantClient.CreateAssistantAsync(this.Model, createOptions); + return new OpenAIAssistantAgent( + definition, + assistantClient) + { + Kernel = kernel ?? new Kernel(), + }; + } + + protected async Task CreateAzureAIAgentAsync(string instructions, string? description = null, string? name = null, Kernel? kernel = null, IEnumerable? tools = null) + { + var agentsClient = AzureAIAgent.CreateAgentsClient(TestConfiguration.AzureAI.Endpoint, new AzureCliCredential()); + + PersistentAgent definition = await agentsClient.Administration.CreateAgentAsync( + TestConfiguration.AzureAI.ChatModelId, + name, + description, + instructions, + tools); + + return + new(definition, agentsClient) + { + Kernel = kernel ?? new Kernel(), + }; + } + protected static void WriteResponse(ChatMessageContent response) { if (!string.IsNullOrEmpty(response.Content)) diff --git a/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseResponsesAgentTest.cs b/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseResponsesAgentTest.cs index ff760c89f4fe..b7d56e732c62 100644 --- a/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseResponsesAgentTest.cs +++ b/dotnet/src/InternalUtilities/samples/AgentUtilities/BaseResponsesAgentTest.cs @@ -4,14 +4,16 @@ using System.ClientModel.Primitives; using Microsoft.SemanticKernel.Agents.OpenAI; using OpenAI; +using OpenAI.Files; using OpenAI.Responses; +using OpenAI.VectorStores; /// /// Base class for samples that demonstrate the usage of . /// public abstract class BaseResponsesAgentTest : BaseAgentsTest { - protected BaseResponsesAgentTest(ITestOutputHelper output) : base(output) + protected BaseResponsesAgentTest(ITestOutputHelper output, string? model = null) : base(output) { var options = new OpenAIClientOptions(); if (this.EnableLogging) @@ -25,10 +27,16 @@ protected BaseResponsesAgentTest(ITestOutputHelper output) : base(output) }); } - this.Client = new(model: TestConfiguration.OpenAI.ModelId, credential: new ApiKeyCredential(TestConfiguration.OpenAI.ApiKey), options: options); + this.Client = new(model: model ?? TestConfiguration.OpenAI.ChatModelId, credential: new ApiKeyCredential(TestConfiguration.OpenAI.ApiKey), options: options); + this.FileClient = new OpenAIFileClient(TestConfiguration.OpenAI.ApiKey); + this.VectorStoreClient = new VectorStoreClient(TestConfiguration.OpenAI.ApiKey); } - protected bool EnableLogging { get; set; } = false; + protected OpenAIFileClient FileClient { get; set; } + + protected VectorStoreClient VectorStoreClient { get; set; } + + protected bool EnableLogging { get; set; } = true; /// protected override OpenAIResponseClient Client { get; } diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs index 95bf9f64561f..2918cc97d3ea 100644 --- a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs +++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs @@ -42,7 +42,6 @@ public static void Initialize(IConfigurationRoot configRoot) public static RedisConfig Redis => LoadSection(); public static JiraConfig Jira => LoadSection(); public static ChromaConfig Chroma => LoadSection(); - public static KustoConfig Kusto => LoadSection(); public static MongoDBConfig MongoDB => LoadSection(); public static ChatGPTRetrievalPluginConfig ChatGPTRetrievalPlugin => LoadSection(); public static MsGraphConfiguration MSGraph => LoadSection(); @@ -54,6 +53,7 @@ public static void Initialize(IConfigurationRoot configRoot) public static CrewAIConfig CrewAI => LoadSection(); public static BedrockConfig Bedrock => LoadSection(); public static BedrockAgentConfig BedrockAgent => LoadSection(); + public static A2AConfig A2A => LoadSection(); public static Mem0Config Mem0 => LoadSection(); public static IConfigurationSection GetSection(string caller) @@ -224,11 +224,6 @@ public class ChromaConfig public string Endpoint { get; set; } } - public class KustoConfig - { - public string ConnectionString { get; set; } - } - public class MongoDBConfig { public string ConnectionString { get; set; } @@ -359,6 +354,11 @@ public class BedrockAgentConfig public string? KnowledgeBaseId { get; set; } } + public class A2AConfig + { + public Uri AgentUrl { get; set; } = new Uri("http://localhost:5000"); + } + public class Mem0Config { public string? BaseAddress { get; set; } diff --git a/dotnet/src/Plugins/Plugins.AI.UnitTests/Plugins.AI.UnitTests.csproj b/dotnet/src/Plugins/Plugins.AI.UnitTests/Plugins.AI.UnitTests.csproj index 00d08ca13f1a..eb13c8f88562 100644 --- a/dotnet/src/Plugins/Plugins.AI.UnitTests/Plugins.AI.UnitTests.csproj +++ b/dotnet/src/Plugins/Plugins.AI.UnitTests/Plugins.AI.UnitTests.csproj @@ -26,7 +26,6 @@ all - diff --git a/dotnet/src/Plugins/Plugins.AI/Plugins.AI.csproj b/dotnet/src/Plugins/Plugins.AI/Plugins.AI.csproj index 472d0d6b3c2f..c7aa68fc3f4a 100644 --- a/dotnet/src/Plugins/Plugins.AI/Plugins.AI.csproj +++ b/dotnet/src/Plugins/Plugins.AI/Plugins.AI.csproj @@ -22,7 +22,6 @@ - diff --git a/dotnet/src/Plugins/Plugins.Core/Plugins.Core.csproj b/dotnet/src/Plugins/Plugins.Core/Plugins.Core.csproj index 65b889ffaa86..fbe9c9b45ced 100644 --- a/dotnet/src/Plugins/Plugins.Core/Plugins.Core.csproj +++ b/dotnet/src/Plugins/Plugins.Core/Plugins.Core.csproj @@ -22,7 +22,6 @@ - diff --git a/dotnet/src/Plugins/Plugins.UnitTests/Plugins.UnitTests.csproj b/dotnet/src/Plugins/Plugins.UnitTests/Plugins.UnitTests.csproj index d71c37008091..c1661fedddaf 100644 --- a/dotnet/src/Plugins/Plugins.UnitTests/Plugins.UnitTests.csproj +++ b/dotnet/src/Plugins/Plugins.UnitTests/Plugins.UnitTests.csproj @@ -26,7 +26,6 @@ all - diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs index 833be1812c60..2f17f396b358 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs @@ -32,13 +32,19 @@ internal static Task GetResponseAsync( var chatOptions = executionSettings.ToChatOptions(kernel); // Try to parse the text as a chat history - if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt)) + if (!ChatPromptParser.TryParse(prompt, out ChatHistory? chatHistory)) { - var messageList = chatHistoryFromPrompt.ToChatMessageList(); - return chatClient.GetResponseAsync(messageList, chatOptions, cancellationToken); + chatHistory = [new ChatMessageContent(AuthorRole.User, prompt)]; } - return chatClient.GetResponseAsync(prompt, chatOptions, cancellationToken); + // Check if the execution settings is present and attempt to prepare the chat history for the request + if (executionSettings is not null) + { + chatHistory = executionSettings.ChatClientPrepareChatHistoryForRequest(chatHistory); + } + + var messageList = chatHistory.ToChatMessageList(); + return chatClient.GetResponseAsync(messageList, chatOptions, cancellationToken); } /// Get ChatClient streaming response for the prompt, settings and kernel. @@ -58,14 +64,21 @@ internal static IAsyncEnumerable GetStreamingResponseAsync( var chatOptions = executionSettings.ToChatOptions(kernel); // Try to parse the text as a chat history - if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt)) + if (!ChatPromptParser.TryParse(prompt, out ChatHistory? chatHistory)) { - var messageList = chatHistoryFromPrompt.ToChatMessageList(); - return chatClient.GetStreamingResponseAsync(messageList, chatOptions, cancellationToken); + chatHistory = [new ChatMessageContent(AuthorRole.User, prompt)]; } + // Check if the execution settings is present and attempt to prepare the chat history for the request + if (executionSettings is not null) + { + chatHistory = executionSettings.ChatClientPrepareChatHistoryForRequest(chatHistory); + } + + var messageList = chatHistory.ToChatMessageList(); + // Otherwise, use the prompt as the chat user message - return chatClient.GetStreamingResponseAsync(prompt, chatOptions, cancellationToken); + return chatClient.GetStreamingResponseAsync(messageList, chatOptions, cancellationToken); } /// Creates an for the specified . diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs index 561838ceb660..f9198c46c4c4 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Linq; using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.ChatCompletion; @@ -30,8 +31,14 @@ internal static ChatMessageContent ToChatMessageContent(this ChatMessage message Microsoft.Extensions.AI.UriContent uc when uc.HasTopLevelMediaType("audio") => new Microsoft.SemanticKernel.AudioContent(uc.Uri), Microsoft.Extensions.AI.DataContent dc => new Microsoft.SemanticKernel.BinaryContent(dc.Uri), Microsoft.Extensions.AI.UriContent uc => new Microsoft.SemanticKernel.BinaryContent(uc.Uri), - Microsoft.Extensions.AI.FunctionCallContent fcc => new Microsoft.SemanticKernel.FunctionCallContent(fcc.Name, null, fcc.CallId, fcc.Arguments is not null ? new(fcc.Arguments) : null), - Microsoft.Extensions.AI.FunctionResultContent frc => new Microsoft.SemanticKernel.FunctionResultContent(callId: frc.CallId, result: frc.Result), + Microsoft.Extensions.AI.FunctionCallContent fcc => new Microsoft.SemanticKernel.FunctionCallContent( + functionName: fcc.Name, + id: fcc.CallId, + arguments: fcc.Arguments is not null ? new(fcc.Arguments) : null), + Microsoft.Extensions.AI.FunctionResultContent frc => new Microsoft.SemanticKernel.FunctionResultContent( + functionName: GetFunctionCallContent(frc.CallId)?.Name, + callId: frc.CallId, + result: frc.Result), _ => null }; @@ -45,6 +52,12 @@ internal static ChatMessageContent ToChatMessageContent(this ChatMessage message } return result; + + Microsoft.Extensions.AI.FunctionCallContent? GetFunctionCallContent(string callId) + => response?.Messages + .Select(m => m.Contents + .FirstOrDefault(c => c is Microsoft.Extensions.AI.FunctionCallContent fcc && fcc.CallId == callId) as Microsoft.Extensions.AI.FunctionCallContent) + .FirstOrDefault(fcc => fcc is not null); } /// Converts a list of to a . diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs index e8251d450a77..de57425e8d3b 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs @@ -50,6 +50,11 @@ public async Task> GetChatMessageContentsAsync { Verify.NotNull(chatHistory); + if (executionSettings is not null) + { + chatHistory = executionSettings.ChatClientPrepareChatHistoryForRequest(chatHistory); + } + var messageList = chatHistory.ToChatMessageList(); var currentSize = messageList.Count; @@ -79,6 +84,11 @@ public async IAsyncEnumerable GetStreamingChatMessa { Verify.NotNull(chatHistory); + if (executionSettings is not null) + { + chatHistory = executionSettings.ChatClientPrepareChatHistoryForRequest(chatHistory); + } + List fcContents = []; ChatRole? role = null; diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettings.cs b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettings.cs index b767c478b1bc..992651fda137 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettings.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettings.cs @@ -167,6 +167,22 @@ protected void ThrowIfFrozen() } } + /// + /// When some specialized is used, this method can be overridden to prepare the chat history before the request is sent based on the + /// current settings configuration + /// + /// Chat history to prepare. + /// Returns the prepared chat history. + protected virtual ChatHistory PrepareChatHistoryForRequest(ChatHistory chatHistory) => chatHistory; + + /// + /// This method is intended to be used only by the for applying any pre-request transformation to the chat history + /// without the need to make the public. + /// + /// Target chat history to prepare. + /// Prepared chat history. + internal ChatHistory ChatClientPrepareChatHistoryForRequest(ChatHistory chatHistory) => this.PrepareChatHistoryForRequest(chatHistory); + #region private ================================================================================ private string? _modelId; diff --git a/dotnet/src/SemanticKernel.Abstractions/Contents/ChatMessageContentExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Contents/ChatMessageContentExtensions.cs index 276d500ce787..1cd65399a297 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Contents/ChatMessageContentExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Contents/ChatMessageContentExtensions.cs @@ -1,15 +1,20 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Diagnostics.CodeAnalysis; using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel; -internal static class ChatMessageContentExtensions +/// Provides extension methods for . +[Experimental("SKEXP0001")] +public static class ChatMessageContentExtensions { /// Converts a to a . /// This conversion should not be necessary once SK eventually adopts the shared content types. - internal static ChatMessage ToChatMessage(this ChatMessageContent content) + public static ChatMessage ToChatMessage(this ChatMessageContent content) { + Verify.NotNull(content); + ChatMessage message = new() { AdditionalProperties = content.Metadata is not null ? new(content.Metadata) : null, diff --git a/dotnet/src/SemanticKernel.Abstractions/Contents/StreamingChatMessageContentExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Contents/StreamingChatMessageContentExtensions.cs index 78e2f8445a78..859099644804 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Contents/StreamingChatMessageContentExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Contents/StreamingChatMessageContentExtensions.cs @@ -1,18 +1,22 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Text.Json; using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel; /// Provides extension methods for . -internal static class StreamingChatMessageContentExtensions +[Experimental("SKEXP0001")] +public static class StreamingChatMessageContentExtensions { /// Converts a to a . /// This conversion should not be necessary once SK eventually adopts the shared content types. - internal static ChatResponseUpdate ToChatResponseUpdate(this StreamingChatMessageContent content) + public static ChatResponseUpdate ToChatResponseUpdate(this StreamingChatMessageContent content) { + Verify.NotNull(content); + ChatResponseUpdate update = new() { AdditionalProperties = content.Metadata is not null ? new AdditionalPropertiesDictionary(content.Metadata) : null, diff --git a/dotnet/src/SemanticKernel.Core/Data/TextSearchBehavior/TextSearchProvider.cs b/dotnet/src/SemanticKernel.Core/Data/TextSearchBehavior/TextSearchProvider.cs index 6d68ecd0cd45..6ee680d91826 100644 --- a/dotnet/src/SemanticKernel.Core/Data/TextSearchBehavior/TextSearchProvider.cs +++ b/dotnet/src/SemanticKernel.Core/Data/TextSearchBehavior/TextSearchProvider.cs @@ -75,7 +75,7 @@ public override async Task ModelInvokingAsync(ICollection SearchAsync(string userQuestion, CancellationToken c { var searchResults = await this._textSearch.GetTextSearchResultsAsync( userQuestion, - new() { Top = this.Options.Top }, + new() { Top = this.Options.Top, Filter = this.Options.Filter }, cancellationToken: cancellationToken).ConfigureAwait(false); var results = await searchResults.Results.ToListAsync(cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/SemanticKernel.Core/Data/TextSearchBehavior/TextSearchProviderOptions.cs b/dotnet/src/SemanticKernel.Core/Data/TextSearchBehavior/TextSearchProviderOptions.cs index 82ecc01c4c17..f6cd9008012b 100644 --- a/dotnet/src/SemanticKernel.Core/Data/TextSearchBehavior/TextSearchProviderOptions.cs +++ b/dotnet/src/SemanticKernel.Core/Data/TextSearchBehavior/TextSearchProviderOptions.cs @@ -33,6 +33,11 @@ public int Top } } + /// + /// Gets or sets the filter expression to apply to the search query. + /// + public TextSearchFilter? Filter { get; init; } + /// /// Gets or sets the time at which the text search is performed. /// diff --git a/dotnet/src/SemanticKernel.Core/Text/TextChunker.cs b/dotnet/src/SemanticKernel.Core/Text/TextChunker.cs index 333528bf5e50..d8f4a32b4e3c 100644 --- a/dotnet/src/SemanticKernel.Core/Text/TextChunker.cs +++ b/dotnet/src/SemanticKernel.Core/Text/TextChunker.cs @@ -52,7 +52,7 @@ private sealed class StringListWithTokenCount(TextChunker.TokenCounter? tokenCou public delegate int TokenCounter(string input); private static readonly char[] s_spaceChar = [' ']; - private static readonly string?[] s_plaintextSplitOptions = ["\n\r", ".。.", "?!", ";", ":", ",,、", ")]}", " ", "-", null]; + private static readonly string?[] s_plaintextSplitOptions = ["\n", ".。.", "?!", ";", ":", ",,、", ")]}", " ", "-", null]; private static readonly string?[] s_markdownSplitOptions = [".\u3002\uFF0E", "?!", ";", ":", ",\uFF0C\u3001", ")]}", " ", "-", "\n\r", null]; /// @@ -84,8 +84,21 @@ public static List SplitMarkDownLines(string text, int maxTokensPerLine, /// Text to be prepended to each individual chunk. /// Function to count tokens in a string. If not supplied, the default counter will be used. /// List of paragraphs. - public static List SplitPlainTextParagraphs(IEnumerable lines, int maxTokensPerParagraph, int overlapTokens = 0, string? chunkHeader = null, TokenCounter? tokenCounter = null) => - InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, chunkHeader, static (text, maxTokens, tokenCounter) => InternalSplitLines(text, maxTokens, trim: false, s_plaintextSplitOptions, tokenCounter), tokenCounter); + public static List SplitPlainTextParagraphs( + IEnumerable lines, + int maxTokensPerParagraph, + int overlapTokens = 0, + string? chunkHeader = null, + TokenCounter? tokenCounter = null) => + InternalSplitTextParagraphs( + lines.Select(line => line + .Replace("\r\n", "\n") + .Replace('\r', '\n')), + maxTokensPerParagraph, + overlapTokens, + chunkHeader, + static (text, maxTokens, tokenCounter) => InternalSplitLines(text, maxTokens, trim: false, s_plaintextSplitOptions, tokenCounter), + tokenCounter); /// /// Split markdown text into paragraphs. diff --git a/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/ChatClientChatCompletionServiceConversionTests.cs b/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/ChatClientChatCompletionServiceConversionTests.cs index 703d01a95d99..8ef515b416b1 100644 --- a/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/ChatClientChatCompletionServiceConversionTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/ChatClientChatCompletionServiceConversionTests.cs @@ -418,6 +418,277 @@ public async Task GetStreamingChatMessageContentsAsyncWithTextAndUsageContentCre Assert.Equal("test-model", message.ModelId); } + [Fact] + public async Task GetChatMessageContentsAsyncCallsPrepareChatHistoryToRequestAsync() + { + // Arrange + var originalChatHistory = new ChatHistory(); + originalChatHistory.AddUserMessage("Original message"); + + var modifiedChatHistory = new ChatHistory(); + modifiedChatHistory.AddSystemMessage("System message added by PrepareChatHistoryToRequestAsync"); + modifiedChatHistory.AddUserMessage("Original message"); + + var testSettings = new TestPromptExecutionSettings(modifiedChatHistory); + + using var chatClient = new TestChatClient + { + CompleteAsyncDelegate = (messages, options, cancellationToken) => + { + // Verify that the chat client receives the modified chat history + Assert.Equal(2, messages.Count()); + Assert.Equal("System message added by PrepareChatHistoryToRequestAsync", messages.First().Text); + Assert.Equal("Original message", messages.Last().Text); + + return Task.FromResult(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Test response")])); + } + }; + + var service = chatClient.AsChatCompletionService(); + + // Act + var result = await service.GetChatMessageContentsAsync(originalChatHistory, testSettings); + + // Assert + Assert.Single(result); + Assert.True(testSettings.PrepareChatHistoryWasCalled); + + // Verify that the original chat history reference was passed to PrepareChatHistoryToRequestAsync + Assert.Same(originalChatHistory, testSettings.ReceivedChatHistory); + } + + [Fact] + public async Task GetStreamingChatMessageContentsAsyncCallsPrepareChatHistoryToRequestAsync() + { + // Arrange + var originalChatHistory = new ChatHistory(); + originalChatHistory.AddUserMessage("Original message"); + + var modifiedChatHistory = new ChatHistory(); + modifiedChatHistory.AddSystemMessage("System message added by PrepareChatHistoryToRequestAsync"); + modifiedChatHistory.AddUserMessage("Original message"); + + var testSettings = new TestPromptExecutionSettings(modifiedChatHistory); + + using var chatClient = new TestChatClient + { + CompleteStreamingAsyncDelegate = (messages, options, cancellationToken) => + { + // Verify that the chat client receives the modified chat history + Assert.Equal(2, messages.Count()); + Assert.Equal("System message added by PrepareChatHistoryToRequestAsync", messages.First().Text); + Assert.Equal("Original message", messages.Last().Text); + + return new[] + { + new ChatResponseUpdate(ChatRole.Assistant, "Test streaming response") + }.ToAsyncEnumerable(); + } + }; + + var service = chatClient.AsChatCompletionService(); + + // Act + var results = new List(); + await foreach (var update in service.GetStreamingChatMessageContentsAsync(originalChatHistory, testSettings)) + { + results.Add(update); + } + + // Assert + Assert.Single(results); + Assert.True(testSettings.PrepareChatHistoryWasCalled); + + // Verify that the original chat history reference was passed to PrepareChatHistoryToRequestAsync + Assert.Same(originalChatHistory, testSettings.ReceivedChatHistory); + } + + [Fact] + public async Task GetChatMessageContentsAsyncWithNullExecutionSettingsDoesNotCallPrepareChatHistory() + { + // Arrange + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Test message"); + + using var chatClient = new TestChatClient + { + CompleteAsyncDelegate = (messages, options, cancellationToken) => + { + // Verify that the chat client receives the original chat history unchanged + Assert.Single(messages); + Assert.Equal("Test message", messages.First().Text); + + return Task.FromResult(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Test response")])); + } + }; + + var service = chatClient.AsChatCompletionService(); + + // Act + var result = await service.GetChatMessageContentsAsync(chatHistory, executionSettings: null); + + // Assert + Assert.Single(result); + } + + [Fact] + public async Task GetStreamingChatMessageContentsAsyncWithNullExecutionSettingsDoesNotCallPrepareChatHistory() + { + // Arrange + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Test message"); + + using var chatClient = new TestChatClient + { + CompleteStreamingAsyncDelegate = (messages, options, cancellationToken) => + { + // Verify that the chat client receives the original chat history unchanged + Assert.Single(messages); + Assert.Equal("Test message", messages.First().Text); + + return new[] + { + new ChatResponseUpdate(ChatRole.Assistant, "Test streaming response") + }.ToAsyncEnumerable(); + } + }; + + var service = chatClient.AsChatCompletionService(); + + // Act + var results = new List(); + await foreach (var update in service.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings: null)) + { + results.Add(update); + } + + // Assert + Assert.Single(results); + } + + [Fact] + public async Task GetChatMessageContentsAsyncWithMutatingPrepareChatHistoryPreservesChatHistoryMutations() + { + // Arrange + var originalChatHistory = new ChatHistory(); + originalChatHistory.AddUserMessage("Original message"); + + var testSettings = new MutatingTestPromptExecutionSettings(); + + using var chatClient = new TestChatClient + { + CompleteAsyncDelegate = (messages, options, cancellationToken) => + { + // Verify that the chat client receives the mutated chat history + Assert.Equal(2, messages.Count()); + Assert.Equal("System message added by mutation", messages.First().Text); + Assert.Equal("Original message", messages.Last().Text); + + return Task.FromResult(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Test response")])); + } + }; + + var service = chatClient.AsChatCompletionService(); + + // Act + var result = await service.GetChatMessageContentsAsync(originalChatHistory, testSettings); + + // Assert + Assert.Single(result); + Assert.True(testSettings.PrepareChatHistoryWasCalled); + + // Verify that the original chat history was mutated and the mutations are preserved + Assert.Equal(2, originalChatHistory.Count); + Assert.Equal("System message added by mutation", originalChatHistory[0].Content); + Assert.Equal("Original message", originalChatHistory[1].Content); + } + + [Fact] + public async Task GetStreamingChatMessageContentsAsyncWithMutatingPrepareChatHistoryPreservesChatHistoryMutations() + { + // Arrange + var originalChatHistory = new ChatHistory(); + originalChatHistory.AddUserMessage("Original message"); + + var testSettings = new MutatingTestPromptExecutionSettings(); + + using var chatClient = new TestChatClient + { + CompleteStreamingAsyncDelegate = (messages, options, cancellationToken) => + { + // Verify that the chat client receives the mutated chat history + Assert.Equal(2, messages.Count()); + Assert.Equal("System message added by mutation", messages.First().Text); + Assert.Equal("Original message", messages.Last().Text); + + return new[] + { + new ChatResponseUpdate(ChatRole.Assistant, "Test streaming response") + }.ToAsyncEnumerable(); + } + }; + + var service = chatClient.AsChatCompletionService(); + + // Act + var results = new List(); + await foreach (var update in service.GetStreamingChatMessageContentsAsync(originalChatHistory, testSettings)) + { + results.Add(update); + } + + // Assert + Assert.Single(results); + Assert.True(testSettings.PrepareChatHistoryWasCalled); + + // Verify that the original chat history was mutated and the mutations are preserved + Assert.Equal(2, originalChatHistory.Count); + Assert.Equal("System message added by mutation", originalChatHistory[0].Content); + Assert.Equal("Original message", originalChatHistory[1].Content); + } + + /// + /// Test implementation of PromptExecutionSettings that overrides PrepareChatHistoryToRequestAsync. + /// + private sealed class TestPromptExecutionSettings : PromptExecutionSettings + { + private readonly ChatHistory _modifiedChatHistory; + + public bool PrepareChatHistoryWasCalled { get; private set; } + public ChatHistory? ReceivedChatHistory { get; private set; } + + public TestPromptExecutionSettings(ChatHistory modifiedChatHistory) + { + this._modifiedChatHistory = modifiedChatHistory; + } + + protected override ChatHistory PrepareChatHistoryForRequest(ChatHistory chatHistory) + { + this.PrepareChatHistoryWasCalled = true; + this.ReceivedChatHistory = chatHistory; + return this._modifiedChatHistory; + } + } + + /// + /// Test implementation of PromptExecutionSettings that mutates the original chat history. + /// + private sealed class MutatingTestPromptExecutionSettings : PromptExecutionSettings + { + public bool PrepareChatHistoryWasCalled { get; private set; } + + protected override ChatHistory PrepareChatHistoryForRequest(ChatHistory chatHistory) + { + this.PrepareChatHistoryWasCalled = true; + + // Mutate the original chat history by inserting a system message at the beginning + chatHistory.Insert(0, new ChatMessageContent(AuthorRole.System, "System message added by mutation")); + + // Return the same mutated chat history + return chatHistory; + } + } + /// /// Test implementation of IChatClient for unit testing. /// diff --git a/dotnet/src/SemanticKernel.UnitTests/AI/PromptExecutionSettingsTests.cs b/dotnet/src/SemanticKernel.UnitTests/AI/PromptExecutionSettingsTests.cs index c8abbace96d2..3d730ca173fd 100644 --- a/dotnet/src/SemanticKernel.UnitTests/AI/PromptExecutionSettingsTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/AI/PromptExecutionSettingsTests.cs @@ -1,11 +1,19 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; +using System.Linq; using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Moq; using Xunit; namespace SemanticKernel.UnitTests.AI; + public class PromptExecutionSettingsTests { [Fact] @@ -98,4 +106,224 @@ public void PromptExecutionSettingsFreezeWorksAsExpected() executionSettings!.Freeze(); // idempotent Assert.True(executionSettings.IsFrozen); } + + [Theory] + [InlineData(true, false)] // System message only + [InlineData(false, true)] // Developer message only + [InlineData(true, true)] // Both system and developer messages + [InlineData(false, false)] // Neither message + public async Task ChatClientExtensionsGetResponseAsyncCallsPrepareChatHistoryForRequest(bool addSystemMessage, bool addDeveloperMessage) + { + // Arrange + var mockChatClient = new Mock(); + var capturedMessages = new List(); + + mockChatClient + .Setup(x => x.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())) + .Callback, ChatOptions?, CancellationToken>((messages, options, ct) => + { + capturedMessages.AddRange(messages); + }) + .ReturnsAsync(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Test response")])); + + var settings = new TestPromptExecutionSettings(addSystemMessage, addDeveloperMessage); + var prompt = "Hello, world!"; + + // Act + await mockChatClient.Object.GetResponseAsync(prompt, settings); + + // Assert + Assert.True(settings.PrepareChatHistoryWasCalled); + + // Verify the expected messages are present + var expectedMessageCount = 1; // Original user message + if (addSystemMessage) + { + expectedMessageCount++; + } + + if (addDeveloperMessage) + { + expectedMessageCount++; + } + + Assert.Equal(expectedMessageCount, capturedMessages.Count); + + var messageIndex = 0; + if (addSystemMessage) + { + Assert.Equal(ChatRole.System, capturedMessages[messageIndex].Role); + Assert.Equal("Test system message", capturedMessages[messageIndex].Text); + messageIndex++; + } + + if (addDeveloperMessage) + { + Assert.Equal(new ChatRole("developer"), capturedMessages[messageIndex].Role); + Assert.Equal("Test developer message", capturedMessages[messageIndex].Text); + messageIndex++; + } + + // Original user message should be last + Assert.Equal(ChatRole.User, capturedMessages[messageIndex].Role); + Assert.Equal("Hello, world!", capturedMessages[messageIndex].Text); + } + + [Theory] + [InlineData(true, false)] // System message only + [InlineData(false, true)] // Developer message only + [InlineData(true, true)] // Both system and developer messages + [InlineData(false, false)] // Neither message + public async Task ChatClientExtensionsGetStreamingResponseAsyncCallsPrepareChatHistoryForRequest(bool addSystemMessage, bool addDeveloperMessage) + { + // Arrange + var mockChatClient = new Mock(); + var capturedMessages = new List(); + + mockChatClient + .Setup(x => x.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())) + .Callback, ChatOptions?, CancellationToken>((messages, options, ct) => + { + capturedMessages.AddRange(messages); + }) + .Returns(new[] { new ChatResponseUpdate(ChatRole.Assistant, "Test streaming response") }.ToAsyncEnumerable()); + + var settings = new TestPromptExecutionSettings(addSystemMessage, addDeveloperMessage); + var prompt = "Hello, world!"; + + // Act + var responses = new List(); + await foreach (var response in mockChatClient.Object.GetStreamingResponseAsync(prompt, settings)) + { + responses.Add(response); + } + + // Assert + Assert.True(settings.PrepareChatHistoryWasCalled); + Assert.Single(responses); + + // Verify the expected messages are present + var expectedMessageCount = 1; // Original user message + if (addSystemMessage) + { + expectedMessageCount++; + } + + if (addDeveloperMessage) + { + expectedMessageCount++; + } + + Assert.Equal(expectedMessageCount, capturedMessages.Count); + + var messageIndex = 0; + if (addSystemMessage) + { + Assert.Equal(ChatRole.System, capturedMessages[messageIndex].Role); + Assert.Equal("Test system message", capturedMessages[messageIndex].Text); + messageIndex++; + } + + if (addDeveloperMessage) + { + Assert.Equal(new ChatRole("developer"), capturedMessages[messageIndex].Role); + Assert.Equal("Test developer message", capturedMessages[messageIndex].Text); + messageIndex++; + } + + // Original user message should be last + Assert.Equal(ChatRole.User, capturedMessages[messageIndex].Role); + Assert.Equal("Hello, world!", capturedMessages[messageIndex].Text); + } + + [Fact] + public async Task ChatClientExtensionsGetResponseAsyncWithNullSettingsDoesNotCallPrepareChatHistory() + { + // Arrange + var mockChatClient = new Mock(); + var capturedMessages = new List(); + + mockChatClient + .Setup(x => x.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())) + .Callback, ChatOptions?, CancellationToken>((messages, options, ct) => + { + capturedMessages.AddRange(messages); + }) + .ReturnsAsync(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Test response")])); + + var prompt = "Hello, world!"; + + // Act + await mockChatClient.Object.GetResponseAsync(prompt, executionSettings: null); + + // Assert + Assert.Single(capturedMessages); + Assert.Equal(ChatRole.User, capturedMessages[0].Role); + Assert.Equal("Hello, world!", capturedMessages[0].Text); + } + + [Fact] + public async Task ChatClientExtensionsGetStreamingResponseAsyncWithNullSettingsDoesNotCallPrepareChatHistory() + { + // Arrange + var mockChatClient = new Mock(); + var capturedMessages = new List(); + + mockChatClient + .Setup(x => x.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())) + .Callback, ChatOptions?, CancellationToken>((messages, options, ct) => + { + capturedMessages.AddRange(messages); + }) + .Returns(new[] { new ChatResponseUpdate(ChatRole.Assistant, "Test streaming response") }.ToAsyncEnumerable()); + + var prompt = "Hello, world!"; + + // Act + var responses = new List(); + await foreach (var response in mockChatClient.Object.GetStreamingResponseAsync(prompt, executionSettings: null)) + { + responses.Add(response); + } + + // Assert + Assert.Single(responses); + Assert.Single(capturedMessages); + Assert.Equal(ChatRole.User, capturedMessages[0].Role); + Assert.Equal("Hello, world!", capturedMessages[0].Text); + } + + /// + /// Test implementation of PromptExecutionSettings that overrides PrepareChatHistoryForRequest. + /// + private sealed class TestPromptExecutionSettings : PromptExecutionSettings + { + private readonly bool _addSystemMessage; + private readonly bool _addDeveloperMessage; + + public bool PrepareChatHistoryWasCalled { get; private set; } + + public TestPromptExecutionSettings(bool addSystemMessage, bool addDeveloperMessage) + { + this._addSystemMessage = addSystemMessage; + this._addDeveloperMessage = addDeveloperMessage; + } + + protected override ChatHistory PrepareChatHistoryForRequest(ChatHistory chatHistory) + { + this.PrepareChatHistoryWasCalled = true; + + if (this._addDeveloperMessage) + { + chatHistory.Insert(0, new ChatMessageContent(AuthorRole.Developer, "Test developer message")); + } + + if (this._addSystemMessage) + { + chatHistory.Insert(0, new ChatMessageContent(AuthorRole.System, "Test system message")); + } + + return chatHistory; + } + } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/TextSearchProviderTests.cs index 2aaa037265bc..28d37124a3c9 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Data/TextSearchProviderTests.cs @@ -263,4 +263,56 @@ public async Task ModelInvokingShouldUseOverrideContextFormatterIfProvidedAsync( // Assert Assert.Equal("Custom formatted context with 2 results.", result.Instructions); } + + [Fact] + public async Task SearchAsyncRespectsFilterOption() + { + // Arrange + var mockTextSearch = new Mock(); + var searchResults = new Mock>(); + var mockEnumerator = new Mock>(); + + // Simulate the filtered results + var filteredResult = new TextSearchResult("Filtered Content") { Name = "FilteredDoc", Link = "http://example.com/filtered" }; + var results = new List { filteredResult }; + + mockEnumerator.SetupSequence(e => e.MoveNextAsync()) + .ReturnsAsync(true) + .ReturnsAsync(false); + + mockEnumerator.SetupSequence(e => e.Current) + .Returns(filteredResult); + + searchResults.Setup(r => r.GetAsyncEnumerator(It.IsAny())) + .Returns(mockEnumerator.Object); + + TextSearchFilter? capturedFilter = null; + mockTextSearch.Setup(ts => ts.GetTextSearchResultsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Callback((q, opts, ct) => + { + capturedFilter = opts?.Filter; + }) + .ReturnsAsync(new KernelSearchResults(searchResults.Object)); + + var filter = new TextSearchFilter().Equality("Name", "FilteredDoc"); + var options = new TextSearchProviderOptions + { + Filter = filter + }; + + var provider = new TextSearchProvider(mockTextSearch.Object, options: options); + + // Act + var result = await provider.SearchAsync("Sample user question?", CancellationToken.None); + + // Assert + Assert.Contains("Filtered Content", result); + Assert.Contains("SourceDocName: FilteredDoc", result); + Assert.Contains("SourceDocLink: http://example.com/filtered", result); + Assert.NotNull(capturedFilter); + Assert.Equal(filter, capturedFilter); + } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Text/TextChunkerTests.cs b/dotnet/src/SemanticKernel.UnitTests/Text/TextChunkerTests.cs index 807282a2778a..a31f077eef66 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Text/TextChunkerTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Text/TextChunkerTests.cs @@ -777,4 +777,46 @@ public void CanSplitTextParagraphsWithOverlapAndHeaderAndCustomTokenCounter() Assert.Equal(expected, result); } + + [Fact] + public void SplitPlainTextParagraphsHandlesExampleFromIssue() + { + var lines = new[] { "First line\nSecond line\nThird line" }; + + var result = TextChunker.SplitPlainTextParagraphs(lines, 100); + + Assert.Equal("First line\nSecond line\nThird line", result[0]); + } + + [Theory] + [InlineData("First line\r\nSecond line\r\nThird line")] + [InlineData("First line\nSecond line\nThird line")] + [InlineData("First line\rSecond line\rThird line")] + public void SplitPlainTextParagraphsNormalizesNewlinesButDoesNotSplit(string input) + { + var lines = new[] { input }; + + var result = TextChunker.SplitPlainTextParagraphs(lines, 100); + + Assert.Single(result); + Assert.DoesNotContain('\r', result[0]); + Assert.Contains("First line", result[0]); + Assert.Contains("Second line", result[0]); + Assert.Contains("Third line", result[0]); + } + + [Fact] + public void SplitPlainTextParagraphsSplitsWhenExceedingTokenLimit() + { + var lines = new[] { "First line\nSecond line\nThird line" }; + + var result = TextChunker.SplitPlainTextParagraphs(lines, 5); + + Assert.True(result.Count > 1); + + var combined = string.Join(" ", result); + Assert.Contains("First line", combined); + Assert.Contains("Second line", combined); + Assert.Contains("Third line", combined); + } } diff --git a/dotnet/src/VectorData/AzureAISearch/AzureAISearchFilterTranslator.cs b/dotnet/src/VectorData/AzureAISearch/AzureAISearchFilterTranslator.cs index b756b4820fea..df8ab46e3d15 100644 --- a/dotnet/src/VectorData/AzureAISearch/AzureAISearchFilterTranslator.cs +++ b/dotnet/src/VectorData/AzureAISearch/AzureAISearchFilterTranslator.cs @@ -98,7 +98,6 @@ private void TranslateConstant(ConstantExpression constant) private void GenerateLiteral(object? value) { - // TODO: Nullable switch (value) { case byte b: @@ -114,6 +113,13 @@ private void GenerateLiteral(object? value) this._filter.Append(l); return; + case float f: + this._filter.Append(f); + return; + case double d: + this._filter.Append(d); + return; + case string untrustedInput: // This is the only place where we allow untrusted input to be passed in, so we need to quote and escape it. this._filter.Append('\'').Append(untrustedInput.Replace("'", "''")).Append('\''); @@ -125,9 +131,9 @@ private void GenerateLiteral(object? value) this._filter.Append('\'').Append(g.ToString()).Append('\''); return; - case DateTime: - case DateTimeOffset: - throw new NotImplementedException(); + case DateTimeOffset d: + this._filter.Append(d.ToString("o")); + return; case Array: throw new NotImplementedException(); @@ -364,11 +370,12 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Prop } // Now that we have the property, go over all wrapping Convert nodes again to ensure that they're compatible with the property type + var unwrappedPropertyType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; unwrappedExpression = expression; while (unwrappedExpression is UnaryExpression { NodeType: ExpressionType.Convert } convert) { var convertType = Nullable.GetUnderlyingType(convert.Type) ?? convert.Type; - if (convertType != property.Type && convertType != typeof(object)) + if (convertType != unwrappedPropertyType && convertType != typeof(object)) { throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convert.Type.Name}', but its configured type is '{property.Type.Name}'."); } diff --git a/dotnet/src/VectorData/Common/SqlFilterTranslator.cs b/dotnet/src/VectorData/Common/SqlFilterTranslator.cs index 614092bc1154..086efc010f55 100644 --- a/dotnet/src/VectorData/Common/SqlFilterTranslator.cs +++ b/dotnet/src/VectorData/Common/SqlFilterTranslator.cs @@ -138,7 +138,6 @@ static bool IsNull(Expression expression) protected virtual void TranslateConstant(object? value) { - // TODO: Nullable switch (value) { case byte b: @@ -154,6 +153,16 @@ protected virtual void TranslateConstant(object? value) this._sql.Append(l); return; + case float f: + this._sql.Append(f); + return; + case double d: + this._sql.Append(d); + return; + case decimal d: + this._sql.Append(d); + return; + case string untrustedInput: // This is the only place where we allow untrusted input to be passed in, so we need to quote and escape it. // Luckily for us, values are escaped in the same way for every provider that we support so far. @@ -169,7 +178,11 @@ protected virtual void TranslateConstant(object? value) case DateTime dateTime: case DateTimeOffset dateTimeOffset: case Array: - throw new NotImplementedException(); +#if NET8_0_OR_GREATER + case DateOnly dateOnly: + case TimeOnly timeOnly: +#endif + throw new UnreachableException("Database-specific format, needs to be implemented in the provider's derived translator."); case null: this._sql.Append("NULL"); @@ -350,11 +363,12 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Prop } // Now that we have the property, go over all wrapping Convert nodes again to ensure that they're compatible with the property type + var unwrappedPropertyType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; unwrappedExpression = expression; while (unwrappedExpression is UnaryExpression { NodeType: ExpressionType.Convert } convert) { var convertType = Nullable.GetUnderlyingType(convert.Type) ?? convert.Type; - if (convertType != property.Type && convertType != typeof(object)) + if (convertType != unwrappedPropertyType && convertType != typeof(object)) { throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convert.Type.Name}', but its configured type is '{property.Type.Name}'."); } diff --git a/dotnet/src/VectorData/CosmosMongoDB/CosmosMongoFilterTranslator.cs b/dotnet/src/VectorData/CosmosMongoDB/CosmosMongoFilterTranslator.cs index e0e579713a47..d14c6b37a73f 100644 --- a/dotnet/src/VectorData/CosmosMongoDB/CosmosMongoFilterTranslator.cs +++ b/dotnet/src/VectorData/CosmosMongoDB/CosmosMongoFilterTranslator.cs @@ -262,11 +262,12 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Prop } // Now that we have the property, go over all wrapping Convert nodes again to ensure that they're compatible with the property type + var unwrappedPropertyType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; unwrappedExpression = expression; while (unwrappedExpression is UnaryExpression { NodeType: ExpressionType.Convert } convert) { var convertType = Nullable.GetUnderlyingType(convert.Type) ?? convert.Type; - if (convertType != property.Type && convertType != typeof(object)) + if (convertType != unwrappedPropertyType && convertType != typeof(object)) { throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convert.Type.Name}', but its configured type is '{property.Type.Name}'."); } diff --git a/dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlCollection.cs b/dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlCollection.cs index 112c8e3bfb02..678eb0e8dd8b 100644 --- a/dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlCollection.cs +++ b/dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlCollection.cs @@ -299,10 +299,22 @@ public override Task DeleteAsync(TKey key, CancellationToken cancellationToken = Verify.NotNullOrWhiteSpace(compositeKey.RecordKey); Verify.NotNullOrWhiteSpace(compositeKey.PartitionKey); - return this.RunOperationAsync("DeleteItem", () => - this._database - .GetContainer(this.Name) - .DeleteItemAsync(compositeKey.RecordKey, new PartitionKey(compositeKey.PartitionKey), cancellationToken: cancellationToken)); + return this.RunOperationAsync("DeleteItem", async () => + { + try + { + await this._database + .GetContainer(this.Name) + .DeleteItemAsync(compositeKey.RecordKey, new PartitionKey(compositeKey.PartitionKey), cancellationToken: cancellationToken) + .ConfigureAwait(false); + return 0; + } + catch (CosmosException e) when (e.StatusCode == System.Net.HttpStatusCode.NotFound) + { + // Ignore not found errors + return 0; + } + }); } // TODO: Implement bulk delete, #11350 diff --git a/dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlFilterTranslator.cs b/dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlFilterTranslator.cs index 21c319f2f645..deb629f94813 100644 --- a/dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlFilterTranslator.cs +++ b/dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlFilterTranslator.cs @@ -1,9 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.Linq; using System.Linq.Expressions; using System.Text; @@ -100,46 +102,74 @@ private void TranslateBinary(BinaryExpression binary) } private void TranslateConstant(ConstantExpression constant) + => this.TranslateConstant(constant.Value); + + private void TranslateConstant(object? value) { - // TODO: Nullable - switch (constant.Value) + switch (value) { - case byte b: - this._sql.Append(b); + case byte v: + this._sql.Append(v); + return; + case short v: + this._sql.Append(v); + return; + case int v: + this._sql.Append(v); return; - case short s: - this._sql.Append(s); + case long v: + this._sql.Append(v); return; - case int i: - this._sql.Append(i); + + case float v: + this._sql.Append(v); return; - case long l: - this._sql.Append(l); + case double v: + this._sql.Append(v); return; - case string s: - this._sql.Append('"').Append(s.Replace(@"\", @"\\").Replace("\"", "\\\"")).Append('"'); + case string v: + this._sql.Append('"').Append(v.Replace(@"\", @"\\").Replace("\"", "\\\"")).Append('"'); + return; + case bool v: + this._sql.Append(v ? "true" : "false"); return; - case bool b: - this._sql.Append(b ? "true" : "false"); + case Guid v: + this._sql.Append('"').Append(v.ToString()).Append('"'); return; - case Guid g: - this._sql.Append('"').Append(g.ToString()).Append('"'); + + case DateTimeOffset v: + // Cosmos doesn't support DateTimeOffset with non-zero offset, so we convert it to UTC. + // See https://github.com/dotnet/efcore/issues/35310 + this._sql + .Append('"') + .Append(v.ToUniversalTime().ToString("yyyy-MM-ddTHH:mm:ss.FFFFFF", CultureInfo.InvariantCulture)) + .Append("Z\""); return; - case DateTime: - case DateTimeOffset: - throw new NotImplementedException(); + case IEnumerable v when v.GetType() is var type && (type.IsArray || type.IsGenericType && type.GetGenericTypeDefinition() == typeof(List<>)): + this._sql.Append('['); - case Array: - throw new NotImplementedException(); + var i = 0; + foreach (var element in v) + { + if (i++ > 0) + { + this._sql.Append(','); + } + + this.TranslateConstant(element); + } + + this._sql.Append(']'); + return; case null: this._sql.Append("null"); return; default: - throw new NotSupportedException("Unsupported constant type: " + constant.Value.GetType().Name); + throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name); } } diff --git a/dotnet/src/VectorData/MongoDB/MongoFilterTranslator.cs b/dotnet/src/VectorData/MongoDB/MongoFilterTranslator.cs index 767652743845..98ae37a4311a 100644 --- a/dotnet/src/VectorData/MongoDB/MongoFilterTranslator.cs +++ b/dotnet/src/VectorData/MongoDB/MongoFilterTranslator.cs @@ -72,7 +72,13 @@ private BsonDocument GenerateEqualityComparison(PropertyModel property, object? { if (value is null) { - throw new NotSupportedException("MongogDB does not support null checks in vector search pre-filters"); + throw new NotSupportedException("MongoDB does not support null checks in vector search pre-filters"); + } + + if (value is DateTime or decimal or IList) + { + // Operand type is not supported for $vectorSearch: date/decimal + throw new NotSupportedException($"MongoDB does not support type {value.GetType().Name} in vector search pre-filters."); } // Short form of equality (instead of $eq) @@ -261,11 +267,12 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Prop } // Now that we have the property, go over all wrapping Convert nodes again to ensure that they're compatible with the property type + var unwrappedPropertyType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; unwrappedExpression = expression; while (unwrappedExpression is UnaryExpression { NodeType: ExpressionType.Convert } convert) { var convertType = Nullable.GetUnderlyingType(convert.Type) ?? convert.Type; - if (convertType != property.Type && convertType != typeof(object)) + if (convertType != unwrappedPropertyType && convertType != typeof(object)) { throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convert.Type.Name}', but its configured type is '{property.Type.Name}'."); } diff --git a/dotnet/src/VectorData/PgVector/PostgresFilterTranslator.cs b/dotnet/src/VectorData/PgVector/PostgresFilterTranslator.cs index 46dd15893400..d05f83a4b26d 100644 --- a/dotnet/src/VectorData/PgVector/PostgresFilterTranslator.cs +++ b/dotnet/src/VectorData/PgVector/PostgresFilterTranslator.cs @@ -1,6 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections; using System.Collections.Generic; +using System.Globalization; using System.Linq.Expressions; using System.Text; using Microsoft.Extensions.VectorData.ProviderServices; @@ -9,7 +12,6 @@ namespace Microsoft.SemanticKernel.Connectors.PgVector; internal sealed class PostgresFilterTranslator : SqlFilterTranslator { - private readonly List _parameterValues = new(); private int _parameterIndex; internal PostgresFilterTranslator( @@ -21,7 +23,49 @@ internal PostgresFilterTranslator( this._parameterIndex = startParamIndex; } - internal List ParameterValues => this._parameterValues; + internal List ParameterValues { get; } = new(); + + protected override void TranslateConstant(object? value) + { + switch (value) + { + // TODO: This aligns with our mapping of DateTime to PG's timestamp (as opposed to timestamptz) - we probably want to + // change that to timestamptz (aligning with Npgsql and EF). See #10641. + case DateTime dateTime: + this._sql.Append('\'').Append(dateTime.ToString("yyyy-MM-ddTHH:mm:ss.FFFFFF", CultureInfo.InvariantCulture)).Append('\''); + return; + case DateTimeOffset dateTimeOffset: + if (dateTimeOffset.Offset != TimeSpan.Zero) + { + throw new NotSupportedException("DateTimeOffset with non-zero offset is not supported with PostgreSQL"); + } + + this._sql.Append('\'').Append(dateTimeOffset.ToString("yyyy-MM-ddTHH:mm:ss.FFFFFF", CultureInfo.InvariantCulture)).Append("Z'"); + return; + + // Array constants (ARRAY[1, 2, 3]) + case IEnumerable v when v.GetType() is var type && (type.IsArray || type.IsGenericType && type.GetGenericTypeDefinition() == typeof(List<>)): + this._sql.Append("ARRAY["); + + var i = 0; + foreach (var element in v) + { + if (i++ > 0) + { + this._sql.Append(','); + } + + this.TranslateConstant(element); + } + + this._sql.Append(']'); + return; + + default: + base.TranslateConstant(value); + break; + } + } protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) { @@ -49,7 +93,7 @@ protected override void TranslateQueryParameter(object? value) } else { - this._parameterValues.Add(value); + this.ParameterValues.Add(value); // The param name is just the index, so there is no need for escaping or quoting. this._sql.Append('$').Append(this._parameterIndex++); } diff --git a/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs b/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs index 3eb1d2646c4a..acfa930a67ec 100644 --- a/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs +++ b/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs @@ -60,14 +60,14 @@ internal static string BuildCreateTableSql(string schema, string tableName, Coll // Add the key column var keyPgTypeInfo = PostgresPropertyMapping.GetPostgresTypeName(model.KeyProperty.Type); - createTableCommand.AppendLine($" \"{keyName}\" {keyPgTypeInfo.PgType} {(keyPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + createTableCommand.AppendLine($" \"{keyName}\" {keyPgTypeInfo.PgType}{(keyPgTypeInfo.IsNullable ? "" : " NOT NULL")},"); // Add the data columns foreach (var dataProperty in model.DataProperties) { string columnName = dataProperty.StorageName; var dataPgTypeInfo = PostgresPropertyMapping.GetPostgresTypeName(dataProperty.Type); - createTableCommand.AppendLine($" \"{columnName}\" {dataPgTypeInfo.PgType} {(dataPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + createTableCommand.AppendLine($" \"{columnName}\" {dataPgTypeInfo.PgType}{(dataPgTypeInfo.IsNullable ? "" : " NOT NULL")},"); } // Add the vector columns @@ -75,7 +75,7 @@ internal static string BuildCreateTableSql(string schema, string tableName, Coll { string columnName = vectorProperty.StorageName; var vectorPgTypeInfo = PostgresPropertyMapping.GetPgVectorTypeName(vectorProperty); - createTableCommand.AppendLine($" \"{columnName}\" {vectorPgTypeInfo.PgType} {(vectorPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + createTableCommand.AppendLine($" \"{columnName}\" {vectorPgTypeInfo.PgType}{(vectorPgTypeInfo.IsNullable ? "" : " NOT NULL")},"); } createTableCommand.AppendLine($" PRIMARY KEY (\"{keyName}\")"); diff --git a/dotnet/src/VectorData/Pinecone/PineconeFieldMapping.cs b/dotnet/src/VectorData/Pinecone/PineconeFieldMapping.cs index 08ed42f5cae0..5ac27c018c55 100644 --- a/dotnet/src/VectorData/Pinecone/PineconeFieldMapping.cs +++ b/dotnet/src/VectorData/Pinecone/PineconeFieldMapping.cs @@ -17,14 +17,16 @@ internal static class PineconeFieldMapping => metadataValue.Value switch { null => null, - bool boolValue => boolValue, - string stringValue => stringValue, + bool v => v, + string v => v, + // Numeric values are not always coming from the SDK in the desired type // that the data model requires, so we need to convert them. - int intValue => ConvertToNumericValue(intValue, targetType), - long longValue => ConvertToNumericValue(longValue, targetType), - float floatValue => ConvertToNumericValue(floatValue, targetType), - double doubleValue => ConvertToNumericValue(doubleValue, targetType), + int v => ConvertToNumericValue(v, targetType), + long v => ConvertToNumericValue(v, targetType), + float v => ConvertToNumericValue(v, targetType), + double v => ConvertToNumericValue(v, targetType), + IEnumerable enumerable => DeserializeCollection(enumerable, targetType), _ => throw new InvalidOperationException($"Unsupported metadata type: '{metadataValue.Value?.GetType().FullName}'."), @@ -62,20 +64,15 @@ public static MetadataValue ConvertToMetadataValue(object? sourceValue) }; private static object? ConvertToNumericValue(object? number, Type targetType) - { - if (number is null) - { - return null; - } + => number is null + ? null + : (Nullable.GetUnderlyingType(targetType) ?? targetType) switch + { + Type t when t == typeof(int) => (object)Convert.ToInt32(number), + Type t when t == typeof(long) => Convert.ToInt64(number), + Type t when t == typeof(float) => Convert.ToSingle(number), + Type t when t == typeof(double) => Convert.ToDouble(number), - return targetType switch - { - Type intType when intType == typeof(int) || intType == typeof(int?) => Convert.ToInt32(number), - Type longType when longType == typeof(long) || longType == typeof(long?) => Convert.ToInt64(number), - Type floatType when floatType == typeof(float) || floatType == typeof(float?) => Convert.ToSingle(number), - Type doubleType when doubleType == typeof(double) || doubleType == typeof(double?) => Convert.ToDouble(number), - Type decimalType when decimalType == typeof(decimal) || decimalType == typeof(decimal?) => Convert.ToDecimal(number), - _ => throw new InvalidOperationException($"Unsupported target numeric type '{targetType.FullName}'."), - }; - } + _ => throw new InvalidOperationException($"Unsupported target numeric type '{targetType.FullName}'."), + }; } diff --git a/dotnet/src/VectorData/Pinecone/PineconeFilterTranslator.cs b/dotnet/src/VectorData/Pinecone/PineconeFilterTranslator.cs index 485637f79fe4..34c5fb39d7d7 100644 --- a/dotnet/src/VectorData/Pinecone/PineconeFilterTranslator.cs +++ b/dotnet/src/VectorData/Pinecone/PineconeFilterTranslator.cs @@ -264,11 +264,12 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Prop } // Now that we have the property, go over all wrapping Convert nodes again to ensure that they're compatible with the property type + var unwrappedPropertyType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; unwrappedExpression = expression; while (unwrappedExpression is UnaryExpression { NodeType: ExpressionType.Convert } convert) { var convertType = Nullable.GetUnderlyingType(convert.Type) ?? convert.Type; - if (convertType != property.Type && convertType != typeof(object)) + if (convertType != unwrappedPropertyType && convertType != typeof(object)) { throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convert.Type.Name}', but its configured type is '{property.Type.Name}'."); } diff --git a/dotnet/src/VectorData/Pinecone/PineconeModelBuilder.cs b/dotnet/src/VectorData/Pinecone/PineconeModelBuilder.cs index b1f947102ede..60e8bbb1de1d 100644 --- a/dotnet/src/VectorData/Pinecone/PineconeModelBuilder.cs +++ b/dotnet/src/VectorData/Pinecone/PineconeModelBuilder.cs @@ -28,7 +28,7 @@ protected override bool IsKeyPropertyTypeValid(Type type, [NotNullWhen(false)] o protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) { - supportedTypes = "bool, string, int, long, float, double, decimal, string[]/List"; + supportedTypes = "bool, string, int, long, float, double, string[]/List"; if (Nullable.GetUnderlyingType(type) is Type underlyingType) { @@ -41,7 +41,6 @@ protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] || type == typeof(long) || type == typeof(float) || type == typeof(double) - || type == typeof(decimal) || type == typeof(string[]) || type == typeof(List); } diff --git a/dotnet/src/VectorData/Qdrant/QdrantFilterTranslator.cs b/dotnet/src/VectorData/Qdrant/QdrantFilterTranslator.cs index 3e2d4cc6548f..d6420fe86446 100644 --- a/dotnet/src/VectorData/Qdrant/QdrantFilterTranslator.cs +++ b/dotnet/src/VectorData/Qdrant/QdrantFilterTranslator.cs @@ -8,9 +8,11 @@ using System.Linq; using System.Linq.Expressions; using Google.Protobuf.Collections; +using Google.Protobuf.WellKnownTypes; using Microsoft.Extensions.VectorData.ProviderServices; using Microsoft.Extensions.VectorData.ProviderServices.Filter; using Qdrant.Client.Grpc; +using Expression = System.Linq.Expressions.Expression; using Range = Qdrant.Client.Grpc.Range; namespace Microsoft.SemanticKernel.Connectors.Qdrant; @@ -80,12 +82,13 @@ private Filter GenerateEqual(string propertyStorageName, object? value, bool neg Key = propertyStorageName, Match = value switch { - string stringValue => new Match { Keyword = stringValue }, - int intValue => new Match { Integer = intValue }, - long longValue => new Match { Integer = longValue }, - bool boolValue => new Match { Boolean = boolValue }, + string v => new Match { Keyword = v }, + int v => new Match { Integer = v }, + long v => new Match { Integer = v }, + bool v => new Match { Boolean = v }, + DateTimeOffset v => new Match { Keyword = v.ToString("o") }, - _ => throw new InvalidOperationException($"Unsupported filter value type '{value.GetType().Name}'.") + _ => throw new NotSupportedException($"Unsupported filter value type '{value.GetType().Name}'.") } } }; @@ -114,35 +117,49 @@ private Filter TranslateComparison(BinaryExpression comparison) bool TryProcessComparison(Expression first, Expression second, [NotNullWhen(true)] out Filter? result) { - // TODO: Nullable if (this.TryBindProperty(first, out var property) && second is ConstantExpression { Value: var constantValue }) { - double doubleConstantValue = constantValue switch - { - double d => d, - int i => i, - long l => l, - _ => throw new NotSupportedException($"Can't perform comparison on type '{constantValue?.GetType().Name}', which isn't convertible to double") - }; - result = new Filter(); result.Must.Add(new Condition { - Field = new FieldCondition + Field = constantValue switch + { + double v => DoubleFieldCondition(v), + int v => DoubleFieldCondition(v), + long v => DoubleFieldCondition(v), + + DateTimeOffset v => new FieldCondition + { + Key = property.StorageName, + DatetimeRange = new DatetimeRange + { + Gt = Timestamp.FromDateTimeOffset(v), + Gte = Timestamp.FromDateTimeOffset(v), + Lt = Timestamp.FromDateTimeOffset(v), + Lte = Timestamp.FromDateTimeOffset(v) + } + }, + + _ => throw new NotSupportedException($"Can't perform comparison on type '{constantValue?.GetType().Name}'") + } + }); + + return true; + + FieldCondition DoubleFieldCondition(double d) + => new() { Key = property.StorageName, Range = comparison.NodeType switch { - ExpressionType.GreaterThan => new Range { Gt = doubleConstantValue }, - ExpressionType.GreaterThanOrEqual => new Range { Gte = doubleConstantValue }, - ExpressionType.LessThan => new Range { Lt = doubleConstantValue }, - ExpressionType.LessThanOrEqual => new Range { Lte = doubleConstantValue }, + ExpressionType.GreaterThan => new Range { Gt = d }, + ExpressionType.GreaterThanOrEqual => new Range { Gte = d }, + ExpressionType.LessThan => new Range { Lt = d }, + ExpressionType.LessThanOrEqual => new Range { Lte = d }, _ => throw new InvalidOperationException("Unreachable") } - } - }); - return true; + }; } result = null; @@ -380,11 +397,12 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Prop } // Now that we have the property, go over all wrapping Convert nodes again to ensure that they're compatible with the property type + var unwrappedPropertyType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; unwrappedExpression = expression; while (unwrappedExpression is UnaryExpression { NodeType: ExpressionType.Convert } convert) { var convertType = Nullable.GetUnderlyingType(convert.Type) ?? convert.Type; - if (convertType != property.Type && convertType != typeof(object)) + if (convertType != unwrappedPropertyType && convertType != typeof(object)) { throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convert.Type.Name}', but its configured type is '{property.Type.Name}'."); } diff --git a/dotnet/src/VectorData/Redis/RedisFilterTranslator.cs b/dotnet/src/VectorData/Redis/RedisFilterTranslator.cs index e58f2d2df73e..eec5ae6f3da5 100644 --- a/dotnet/src/VectorData/Redis/RedisFilterTranslator.cs +++ b/dotnet/src/VectorData/Redis/RedisFilterTranslator.cs @@ -99,7 +99,6 @@ private void TranslateEqualityComparison(BinaryExpression binary) bool TryProcessEqualityComparison(Expression first, Expression second) { - // TODO: Nullable if (this.TryBindProperty(first, out var property) && second is ConstantExpression { Value: var constantValue }) { // Numeric negation has a special syntax (!=), for the rest we nest in a NOT @@ -115,7 +114,7 @@ bool TryProcessEqualityComparison(Expression first, Expression second) this._filter.Append( binary.NodeType switch { - ExpressionType.Equal when constantValue is int or long or float or double => $" == {constantValue}", + ExpressionType.Equal when constantValue is byte or short or int or long or float or double => $" == {constantValue}", ExpressionType.Equal when constantValue is string stringValue #if NET8_0_OR_GREATER => $$""":{"{{stringValue.Replace("\"", "\\\"", StringComparison.Ordinal)}}"}""", @@ -233,11 +232,12 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Prop } // Now that we have the property, go over all wrapping Convert nodes again to ensure that they're compatible with the property type + var unwrappedPropertyType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; unwrappedExpression = expression; while (unwrappedExpression is UnaryExpression { NodeType: ExpressionType.Convert } convert) { var convertType = Nullable.GetUnderlyingType(convert.Type) ?? convert.Type; - if (convertType != property.Type && convertType != typeof(object)) + if (convertType != unwrappedPropertyType && convertType != typeof(object)) { throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convert.Type.Name}', but its configured type is '{property.Type.Name}'."); } diff --git a/dotnet/src/VectorData/Redis/RedisModelBuilder.cs b/dotnet/src/VectorData/Redis/RedisModelBuilder.cs index 81f37db020c1..c8ddf3c0c34c 100644 --- a/dotnet/src/VectorData/Redis/RedisModelBuilder.cs +++ b/dotnet/src/VectorData/Redis/RedisModelBuilder.cs @@ -28,7 +28,7 @@ protected override bool IsKeyPropertyTypeValid(Type type, [NotNullWhen(false)] o protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) { - supportedTypes = "string, int, uint, long, ulong, double, float, bool"; + supportedTypes = "string, int, uint, long, ulong, double, float"; if (Nullable.GetUnderlyingType(type) is Type underlyingType) { @@ -41,8 +41,7 @@ protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] || type == typeof(long) || type == typeof(ulong) || type == typeof(double) - || type == typeof(float) - || type == typeof(bool); + || type == typeof(float); } protected override bool IsVectorPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) diff --git a/dotnet/src/VectorData/SqlServer/SqlServerCommandBuilder.cs b/dotnet/src/VectorData/SqlServer/SqlServerCommandBuilder.cs index c5dc9b39db55..c847bce286ac 100644 --- a/dotnet/src/VectorData/SqlServer/SqlServerCommandBuilder.cs +++ b/dotnet/src/VectorData/SqlServer/SqlServerCommandBuilder.cs @@ -615,33 +615,39 @@ private static void AddParameter(this SqlCommand command, PropertyModel? propert case float[] vectorArray: command.Parameters.AddWithValue(name, JsonSerializer.Serialize(vectorArray, SqlServerJsonSerializerContext.Default.SingleArray)); break; + case DateTime dateTime: + command.Parameters.Add(name, System.Data.SqlDbType.DateTime2).Value = dateTime; + break; default: command.Parameters.AddWithValue(name, value); break; } } - private static string Map(PropertyModel property) => property.Type switch - { - Type t when t == typeof(byte) => "TINYINT", - Type t when t == typeof(short) => "SMALLINT", - Type t when t == typeof(int) => "INT", - Type t when t == typeof(long) => "BIGINT", - Type t when t == typeof(Guid) => "UNIQUEIDENTIFIER", - Type t when t == typeof(string) && property is KeyPropertyModel => "NVARCHAR(4000)", - Type t when t == typeof(string) && property is DataPropertyModel { IsIndexed: true } => "NVARCHAR(4000)", - Type t when t == typeof(string) => "NVARCHAR(MAX)", - Type t when t == typeof(byte[]) => "VARBINARY(MAX)", - Type t when t == typeof(bool) => "BIT", - Type t when t == typeof(DateTime) => "DATETIME2", + private static string Map(PropertyModel property) + => (Nullable.GetUnderlyingType(property.Type) ?? property.Type) switch + { + Type t when t == typeof(byte) => "TINYINT", + Type t when t == typeof(short) => "SMALLINT", + Type t when t == typeof(int) => "INT", + Type t when t == typeof(long) => "BIGINT", + Type t when t == typeof(Guid) => "UNIQUEIDENTIFIER", + Type t when t == typeof(string) && property is KeyPropertyModel => "NVARCHAR(4000)", + Type t when t == typeof(string) && property is DataPropertyModel { IsIndexed: true } => "NVARCHAR(4000)", + Type t when t == typeof(string) => "NVARCHAR(MAX)", + Type t when t == typeof(byte[]) => "VARBINARY(MAX)", + Type t when t == typeof(bool) => "BIT", + Type t when t == typeof(DateTime) => "DATETIME2", #if NET - Type t when t == typeof(TimeOnly) => "TIME", + Type t when t == typeof(DateOnly) => "DATE", + Type t when t == typeof(TimeOnly) => "TIME", #endif - Type t when t == typeof(decimal) => "DECIMAL", - Type t when t == typeof(double) => "FLOAT", - Type t when t == typeof(float) => "REAL", - _ => throw new NotSupportedException($"Type {property.Type} is not supported.") - }; + Type t when t == typeof(decimal) => "DECIMAL(18,2)", + Type t when t == typeof(double) => "FLOAT", + Type t when t == typeof(float) => "REAL", + + _ => throw new NotSupportedException($"Type {property.Type} is not supported.") + }; // Source: https://learn.microsoft.com/sql/t-sql/functions/vector-distance-transact-sql private static (string distanceMetric, string sorting) MapDistanceFunction(string name) => name switch diff --git a/dotnet/src/VectorData/SqlServer/SqlServerFilterTranslator.cs b/dotnet/src/VectorData/SqlServer/SqlServerFilterTranslator.cs index 74c1f60b15ef..6aea071e5210 100644 --- a/dotnet/src/VectorData/SqlServer/SqlServerFilterTranslator.cs +++ b/dotnet/src/VectorData/SqlServer/SqlServerFilterTranslator.cs @@ -3,6 +3,9 @@ using System; using System.Collections; using System.Collections.Generic; +#if NET8_0_OR_GREATER +using System.Globalization; +#endif using System.Linq.Expressions; using System.Text; using Microsoft.Extensions.VectorData.ProviderServices; @@ -34,11 +37,21 @@ protected override void TranslateConstant(object? value) this._sql.Append(boolValue ? "1" : "0"); return; case DateTime dateTime: - this._sql.AppendFormat("'{0:yyyy-MM-dd HH:mm:ss}'", dateTime); + this._sql.Append('\'').Append(dateTime.ToString("o")).Append('\''); return; case DateTimeOffset dateTimeOffset: - this._sql.AppendFormat("'{0:yyy-MM-dd HH:mm:ss zzz}'", dateTimeOffset); + this._sql.Append('\'').Append(dateTimeOffset.ToString("o")).Append('\''); return; +#if NET8_0_OR_GREATER + case DateOnly dateOnly: + this._sql.Append('\'').Append(dateOnly.ToString("o")).Append('\''); + return; + case TimeOnly timeOnly: + this._sql.AppendFormat(timeOnly.Ticks % 10000000 == 0 + ? string.Format(CultureInfo.InvariantCulture, @"'{0:HH\:mm\:ss}'", value) + : string.Format(CultureInfo.InvariantCulture, @"'{0:HH\:mm\:ss\.FFFFFFF}'", value)); + return; +#endif default: base.TranslateConstant(value); break; diff --git a/dotnet/src/VectorData/SqlServer/SqlServerMapper.cs b/dotnet/src/VectorData/SqlServer/SqlServerMapper.cs index 3e07c3a68437..671fe9cc5be8 100644 --- a/dotnet/src/VectorData/SqlServer/SqlServerMapper.cs +++ b/dotnet/src/VectorData/SqlServer/SqlServerMapper.cs @@ -67,7 +67,7 @@ static void PopulateValue(SqlDataReader reader, PropertyModel property, object r return; } - switch (property.Type) + switch (Nullable.GetUnderlyingType(property.Type) ?? property.Type) { case var t when t == typeof(byte): property.SetValue(record, reader.GetByte(ordinal)); // TINYINT @@ -110,6 +110,9 @@ static void PopulateValue(SqlDataReader reader, PropertyModel property, object r break; #if NET + case var t when t == typeof(DateOnly): + property.SetValue(record, reader.GetFieldValue(ordinal)); // DATE + break; case var t when t == typeof(TimeOnly): property.SetValue(record, reader.GetFieldValue(ordinal)); // TIME break; diff --git a/dotnet/src/VectorData/SqlServer/SqlServerModelBuilder.cs b/dotnet/src/VectorData/SqlServer/SqlServerModelBuilder.cs index cde5dc8bc8bb..fcb5b07396c8 100644 --- a/dotnet/src/VectorData/SqlServer/SqlServerModelBuilder.cs +++ b/dotnet/src/VectorData/SqlServer/SqlServerModelBuilder.cs @@ -32,7 +32,7 @@ protected override bool IsKeyPropertyTypeValid(Type type, [NotNullWhen(false)] o protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) { - supportedTypes = "string, int, long, double, float, bool, DateTimeOffset, or arrays/lists of these types"; + supportedTypes = "string, short, int, long, double, float, decimal, bool, DateTime, DateTimeOffset, DateOnly, TimeOnly, Guid, byte[]"; if (Nullable.GetUnderlyingType(type) is Type underlyingType) { @@ -49,8 +49,9 @@ protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] || type == typeof(bool) // BIT || type == typeof(DateTime) // DATETIME2 #if NET - // We don't support mapping TimeSpan to TIME on purpose - // See https://github.com/microsoft/semantic-kernel/pull/10623#discussion_r1980350721 + || type == typeof(DateOnly) // DATE + // We don't support mapping TimeSpan to TIME on purpose + // See https://github.com/microsoft/semantic-kernel/pull/10623#discussion_r1980350721 || type == typeof(TimeOnly) // TIME #endif || type == typeof(decimal) // DECIMAL diff --git a/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/Filter/FilterTranslationPreprocessor.cs b/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/Filter/FilterTranslationPreprocessor.cs index 2978e654c31f..78289e41c021 100644 --- a/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/Filter/FilterTranslationPreprocessor.cs +++ b/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/Filter/FilterTranslationPreprocessor.cs @@ -127,4 +127,56 @@ protected override Expression VisitMember(MemberExpression node) return new QueryParameterExpression(name, evaluatedValue, visited.Type); } + + /// + protected override Expression VisitNew(NewExpression node) + { + var visited = (NewExpression)base.VisitNew(node); + + // Recognize certain well-known constructors where we can evaluate immediately, converting the NewExpression to a ConstantExpression. + // This is particularly useful for converting inline instantiation of DateTime and DateTimeOffset to constants, which can then be easily translated. + switch (visited.Constructor) + { + case ConstructorInfo constructor when constructor.DeclaringType == typeof(DateTimeOffset) || constructor.DeclaringType == typeof(DateTime): + var constantArguments = new object?[visited.Arguments.Count]; + + // We first do a fast path to check if all arguments are constants; this catches the common case of e.g. new DateTime(2023, 10, 1). + // If an argument isn't a constant (e.g. new DateTimeOffset(..., TimeSpan.FromHours(2))), we fall back to trying the LINQ interpreter + // as a general-purpose expression evaluator - but note that this is considerably slower. + for (var i = 0; i < visited.Arguments.Count; i++) + { + if (visited.Arguments[i] is ConstantExpression constantArgument) + { + constantArguments[i] = constantArgument.Value; + } + else + { + // There's a non-constant argument - try the LINQ interpreter. +#pragma warning disable CA1031 // Do not catch general exception types + try + { + var evaluated = Expression.Lambda>(Expression.Convert(visited, typeof(object))) +#if NET8_0_OR_GREATER + .Compile(preferInterpretation: true) +#else + .Compile() +#endif + .Invoke(); + + return Expression.Constant(evaluated, constructor.DeclaringType); + } + catch + { + return visited; + } +#pragma warning restore CA1031 + } + } + + var constantValue = constructor.Invoke(constantArguments); + return Expression.Constant(constantValue, constructor.DeclaringType); + } + + return visited; + } } diff --git a/dotnet/src/VectorData/VectorData.Abstractions/VectorData.Abstractions.csproj b/dotnet/src/VectorData/VectorData.Abstractions/VectorData.Abstractions.csproj index 5ebb187299de..775ccb90c79f 100644 --- a/dotnet/src/VectorData/VectorData.Abstractions/VectorData.Abstractions.csproj +++ b/dotnet/src/VectorData/VectorData.Abstractions/VectorData.Abstractions.csproj @@ -12,10 +12,10 @@ - 9.6.0 + 9.7.0 9.0.0.0 - 9.5.0 + 9.6.0 Microsoft.Extensions.VectorData.Abstractions $(AssemblyName) Abstractions for vector database access. diff --git a/dotnet/src/VectorData/Weaviate/WeaviateFilterTranslator.cs b/dotnet/src/VectorData/Weaviate/WeaviateFilterTranslator.cs index fbd5eab8a6c1..4d2f1177ba4e 100644 --- a/dotnet/src/VectorData/Weaviate/WeaviateFilterTranslator.cs +++ b/dotnet/src/VectorData/Weaviate/WeaviateFilterTranslator.cs @@ -273,11 +273,12 @@ private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out Prop } // Now that we have the property, go over all wrapping Convert nodes again to ensure that they're compatible with the property type + var unwrappedPropertyType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; unwrappedExpression = expression; while (unwrappedExpression is UnaryExpression { NodeType: ExpressionType.Convert } convert) { var convertType = Nullable.GetUnderlyingType(convert.Type) ?? convert.Type; - if (convertType != property.Type && convertType != typeof(object)) + if (convertType != unwrappedPropertyType && convertType != typeof(object)) { throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convert.Type.Name}', but its configured type is '{property.Type.Name}'."); } diff --git a/dotnet/src/VectorData/Weaviate/WeaviateMapper.cs b/dotnet/src/VectorData/Weaviate/WeaviateMapper.cs index 3eb3cf72e2bd..730199f0a87c 100644 --- a/dotnet/src/VectorData/Weaviate/WeaviateMapper.cs +++ b/dotnet/src/VectorData/Weaviate/WeaviateMapper.cs @@ -145,7 +145,7 @@ public TRecord MapFromStorageToDataModel(JsonObject storageModel, bool includeVe if (dataPropertiesJson.TryGetPropertyValue(property.StorageName, out var dataValue)) { // TODO: NativeAOT support, #11963 - property.SetValueAsObject(record, dataValue.Deserialize(property.Type, this._jsonSerializerOptions)); + property.SetValueAsObject(record, dataValue?.Deserialize(property.Type, this._jsonSerializerOptions)); } } } diff --git a/dotnet/test/VectorData/AzureAISearch.ConformanceTests/AzureAISearchDataTypeTests.cs b/dotnet/test/VectorData/AzureAISearch.ConformanceTests/AzureAISearchDataTypeTests.cs new file mode 100644 index 000000000000..be9fe82afc9f --- /dev/null +++ b/dotnet/test/VectorData/AzureAISearch.ConformanceTests/AzureAISearchDataTypeTests.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft. All rights reserved. + +using AzureAISearch.ConformanceTests.Support; +using Microsoft.Extensions.VectorData; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using VectorData.ConformanceTests.Xunit; +using Xunit; + +namespace AzureAISearch.ConformanceTests; + +public class AzureAISearchDataTypeTests(AzureAISearchDataTypeTests.Fixture fixture) + : DataTypeTests(fixture), + IClassFixture +{ + public override Task Byte() => Task.CompletedTask; + public override Task Short() => Task.CompletedTask; + public override Task Decimal() => Task.CompletedTask; + + [ConditionalFact(Skip = "Guid not yet supported")] + public override Task Guid() => Task.CompletedTask; + + [ConditionalFact(Skip = "DateTime not yet supported")] + public override Task DateTime() => Task.CompletedTask; + + // [ConditionalFact(Skip = "DateTimeOffset not yet supported")] + // public override Task DateTimeOffset() => Task.CompletedTask; + + [ConditionalFact(Skip = "DateOnly not yet supported")] + public override Task DateOnly() => Task.CompletedTask; + + [ConditionalFact(Skip = "TimeOnly not yet supported")] + public override Task TimeOnly() => Task.CompletedTask; + + public override Task String_array() => Task.CompletedTask; + + protected override object? GenerateEmptyProperty(VectorStoreProperty property) + => property.Type switch + { + null => throw new InvalidOperationException($"Property '{property.Name}' has no type defined."), + + // In Azure AI Search, array fields must be non-null (at least for now) + var t when t.IsArray => Array.CreateInstance(t.GetElementType()!, 0), + + _ => base.GenerateEmptyProperty(property) + }; + + public new class Fixture : DataTypeTests.Fixture + { + public override TestStore TestStore => AzureAISearchTestStore.Instance; + + // Azure AI search only supports lowercase letters, digits or dashes. + public override string CollectionName => "data-type-tests" + AzureAISearchTestEnvironment.TestIndexPostfix; + + public override IList GetDataProperties() + => base.GetDataProperties().Where(p => + p.Type != typeof(byte) + && p.Type != typeof(short) + && p.Type != typeof(decimal) + && p.Type != typeof(Guid) + && p.Type != typeof(DateTime) +#if NET8_0_OR_GREATER + && p.Type != typeof(DateOnly) + && p.Type != typeof(TimeOnly) +#endif + ).ToList(); + + public class AzureAISearchRecord : RecordBase + { + public int Int { get; set; } + public long Long { get; set; } + public float Float { get; set; } + public double Double { get; set; } + + public string? String { get; set; } + public bool Bool { get; set; } + + public DateTimeOffset DateTimeOffset { get; set; } + + public string[] StringArray { get; set; } = null!; + + public int? NullableInt { get; set; } + } + } +} diff --git a/dotnet/test/VectorData/AzureAISearch.ConformanceTests/Support/AzureAISearchAllTypes.cs b/dotnet/test/VectorData/AzureAISearch.ConformanceTests/Support/AzureAISearchAllTypes.cs index 1ddb4d17d85f..1c099bc7f4d1 100644 --- a/dotnet/test/VectorData/AzureAISearch.ConformanceTests/Support/AzureAISearchAllTypes.cs +++ b/dotnet/test/VectorData/AzureAISearch.ConformanceTests/Support/AzureAISearchAllTypes.cs @@ -6,7 +6,6 @@ namespace AzureAISearch.ConformanceTests.Support; #pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable. -#pragma warning disable CA1819 // Properties should not return arrays public class AzureAISearchAllTypes { diff --git a/dotnet/test/VectorData/CosmosMongoDB.ConformanceTests/CosmosMongoDataTypeTests.cs b/dotnet/test/VectorData/CosmosMongoDB.ConformanceTests/CosmosMongoDataTypeTests.cs new file mode 100644 index 000000000000..789a54a76ba3 --- /dev/null +++ b/dotnet/test/VectorData/CosmosMongoDB.ConformanceTests/CosmosMongoDataTypeTests.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosMongoDB.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace CosmosMongoDB.ConformanceTests; + +public class CosmosMongoDataTypeTests(CosmosMongoDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + public override Task Decimal() + => this.Test( + "Decimal", 8.5m, 9.5m, + isFilterable: false); // TODO: Filtering doesn't fail but the data doesn't seem to appear... + + public override Task DateTime() + => this.Test( + "DateTime", + new DateTime(2020, 1, 1, 12, 30, 45, DateTimeKind.Utc), + new DateTime(2021, 2, 3, 13, 40, 55, DateTimeKind.Utc), + instantiationExpression: () => new DateTime(2020, 1, 1, 12, 30, 45, DateTimeKind.Utc)); + + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => CosmosMongoTestStore.Instance; + + // MongoDB does not support null checks in vector search pre-filters + public override bool IsNullFilteringSupported => false; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(byte), + typeof(short), + typeof(Guid), + typeof(DateTimeOffset), +#if NET8_0_OR_GREATER + typeof(DateOnly), + typeof(TimeOnly) +#endif + ]; + } +} diff --git a/dotnet/test/VectorData/CosmosNoSql.ConformanceTests/CosmosNoSqlDataTypeTests.cs b/dotnet/test/VectorData/CosmosNoSql.ConformanceTests/CosmosNoSqlDataTypeTests.cs new file mode 100644 index 000000000000..e4b541c631f5 --- /dev/null +++ b/dotnet/test/VectorData/CosmosNoSql.ConformanceTests/CosmosNoSqlDataTypeTests.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSql.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using VectorData.ConformanceTests.Xunit; +using Xunit; + +namespace CosmosNoSql.ConformanceTests; + +public class CosmosNoSqlDataTypeTests(CosmosNoSqlDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + // Cosmos doesn't support DateTimeOffset with non-zero offset, so we convert it to UTC. + // See https://github.com/dotnet/efcore/issues/35310 + [ConditionalFact(Skip = "Need to convert DateTimeOffset to UTC before sending to Cosmos")] + public override Task DateTimeOffset() + => this.Test( + "DateTimeOffset", + new DateTimeOffset(2020, 1, 1, 12, 30, 45, TimeSpan.FromHours(2)), + new DateTimeOffset(2021, 2, 3, 13, 40, 55, TimeSpan.FromHours(3)), + instantiationExpression: () => new DateTimeOffset(2020, 1, 1, 10, 30, 45, TimeSpan.FromHours(0))); + + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => CosmosNoSqlTestStore.Instance; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(byte), + typeof(short), + typeof(decimal), + typeof(Guid), + typeof(DateTime), +#if NET8_0_OR_GREATER + typeof(DateOnly), + typeof(TimeOnly) +#endif + ]; + } +} diff --git a/dotnet/test/VectorData/Directory.Build.props b/dotnet/test/VectorData/Directory.Build.props index 3f4402245e90..e58e3eb681b5 100644 --- a/dotnet/test/VectorData/Directory.Build.props +++ b/dotnet/test/VectorData/Directory.Build.props @@ -9,6 +9,8 @@ $(NoWarn);CA1716 $(NoWarn);CA1720 $(NoWarn);CA1721 + $(NoWarn);CA1819 + $(NoWarn);CS1819 $(NoWarn);CA1861 $(NoWarn);CA1863 $(NoWarn);CA2007;VSTHRD111 diff --git a/dotnet/test/VectorData/InMemory.ConformanceTests/InMemoryDataTypeTests.cs b/dotnet/test/VectorData/InMemory.ConformanceTests/InMemoryDataTypeTests.cs new file mode 100644 index 000000000000..d0e47b840f93 --- /dev/null +++ b/dotnet/test/VectorData/InMemory.ConformanceTests/InMemoryDataTypeTests.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +using InMemory.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace InMemory.ConformanceTests; + +public class InMemoryDataTypeTests(InMemoryDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => InMemoryTestStore.Instance; + + public override Type[] UnsupportedDefaultTypes { get; } = []; + } +} diff --git a/dotnet/test/VectorData/MongoDB.ConformanceTests/MongoDataTypeTests.cs b/dotnet/test/VectorData/MongoDB.ConformanceTests/MongoDataTypeTests.cs new file mode 100644 index 000000000000..6a9361bbff28 --- /dev/null +++ b/dotnet/test/VectorData/MongoDB.ConformanceTests/MongoDataTypeTests.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDB.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace MongoDB.ConformanceTests; + +public class MongoDataTypeTests(MongoDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + public override Task Decimal() + => this.Test( + "Decimal", 8.5m, 9.5m, + isFilterable: false); // Operand type is not supported for $vectorSearch: decimal + + public override Task DateTime() + => this.Test( + "DateTime", + new DateTime(2020, 1, 1, 12, 30, 45, DateTimeKind.Utc), + new DateTime(2021, 2, 3, 13, 40, 55, DateTimeKind.Utc), + instantiationExpression: () => new DateTime(2020, 1, 1, 12, 30, 45), + isFilterable: false); // Operand type is not supported for $vectorSearch: date + + public override Task String_array() + => this.Test( + "StringArray", + ["foo", "bar"], + ["foo", "baz"], + isFilterable: false); // Operand type is not supported for $vectorSearch: array + + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => MongoTestStore.Instance; + + // MongoDB does not support null checks in vector search pre-filters + public override bool IsNullFilteringSupported => false; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(byte), + typeof(short), + typeof(Guid), + typeof(DateTimeOffset), +#if NET8_0_OR_GREATER + typeof(DateOnly), + typeof(TimeOnly) +#endif + ]; + } +} diff --git a/dotnet/test/VectorData/PgVector.ConformanceTests/PostgresDataTypeTests.cs b/dotnet/test/VectorData/PgVector.ConformanceTests/PostgresDataTypeTests.cs new file mode 100644 index 000000000000..53245b2e089e --- /dev/null +++ b/dotnet/test/VectorData/PgVector.ConformanceTests/PostgresDataTypeTests.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PgVector.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace PgVector.ConformanceTests; + +public class PostgresDataTypeTests(PostgresDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + // PostgreSQL does not support representing an offset, so only DateTimeOffsets with offset=0 are supported. + public override Task DateTimeOffset() + => this.Test( + "DateTimeOffset", + new DateTimeOffset(2020, 1, 1, 12, 30, 45, TimeSpan.FromHours(0)), + new DateTimeOffset(2021, 2, 3, 13, 40, 55, TimeSpan.FromHours(0)), + instantiationExpression: () => new DateTimeOffset(2020, 1, 1, 12, 30, 45, TimeSpan.FromHours(0))); + + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => PostgresTestStore.Instance; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(byte), +#if NET8_0_OR_GREATER + typeof(DateOnly), + typeof(TimeOnly), +#endif + ]; + } +} diff --git a/dotnet/test/VectorData/PgVector.ConformanceTests/Support/PostgresTestStore.cs b/dotnet/test/VectorData/PgVector.ConformanceTests/Support/PostgresTestStore.cs index cff6dd48eeb3..093804607da6 100644 --- a/dotnet/test/VectorData/PgVector.ConformanceTests/Support/PostgresTestStore.cs +++ b/dotnet/test/VectorData/PgVector.ConformanceTests/Support/PostgresTestStore.cs @@ -14,7 +14,7 @@ internal sealed class PostgresTestStore : TestStore public static PostgresTestStore Instance { get; } = new(); private static readonly PostgreSqlContainer s_container = new PostgreSqlBuilder() - .WithImage("pgvector/pgvector:pg16") + .WithImage("pgvector/pgvector:pg17") .Build(); private NpgsqlDataSource? _dataSource; diff --git a/dotnet/test/VectorData/Pinecone.ConformanceTests/PineconeDataTypeTests.cs b/dotnet/test/VectorData/Pinecone.ConformanceTests/PineconeDataTypeTests.cs new file mode 100644 index 000000000000..ade59387ebf5 --- /dev/null +++ b/dotnet/test/VectorData/Pinecone.ConformanceTests/PineconeDataTypeTests.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Pinecone.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace Pinecone.ConformanceTests; + +public class PineconeDataTypeTests(PineconeDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => PineconeTestStore.Instance; + + // https://docs.pinecone.io/troubleshooting/restrictions-on-index-names + public override string CollectionName => "data-type-tests"; + + // Pincone does not support null checks in vector search pre-filters + public override bool IsNullFilteringSupported => false; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(byte), + typeof(short), + typeof(decimal), + typeof(Guid), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(string[]), // TODO: Error with gRPC status code 3 + +#if NET8_0_OR_GREATER + typeof(DateOnly), + typeof(TimeOnly) +#endif + ]; + } +} diff --git a/dotnet/test/VectorData/Qdrant.ConformanceTests/QdrantDataTypeTests.cs b/dotnet/test/VectorData/Qdrant.ConformanceTests/QdrantDataTypeTests.cs new file mode 100644 index 000000000000..e5cdb008ee82 --- /dev/null +++ b/dotnet/test/VectorData/Qdrant.ConformanceTests/QdrantDataTypeTests.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Qdrant.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using VectorData.ConformanceTests.Xunit; +using Xunit; + +namespace Qdrant.ConformanceTests; + +public class QdrantDataTypeTests(QdrantDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + // Qdrant doesn't seem to support filtering on float/double or string ararys, + // https://qdrant.tech/documentation/concepts/filtering/#match + [ConditionalFact] + public override Task Float() + => this.Test("Float", 8.5f, 9.5f, isFilterable: false); + + [ConditionalFact] + public override Task Double() + => this.Test("Double", 8.5d, 9.5d, isFilterable: false); + + [ConditionalFact] + public override Task String_array() + => this.Test( + "StringArray", + ["foo", "bar"], + ["foo", "baz"], + isFilterable: false); + + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => QdrantTestStore.UnnamedVectorInstance; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(byte), + typeof(short), + typeof(decimal), + typeof(Guid), + typeof(DateTime), +#if NET8_0_OR_GREATER + typeof(DateOnly), + typeof(TimeOnly) +#endif + ]; + } +} diff --git a/dotnet/test/VectorData/Redis.ConformanceTests/RedisHashSetDataTypeTests.cs b/dotnet/test/VectorData/Redis.ConformanceTests/RedisHashSetDataTypeTests.cs new file mode 100644 index 000000000000..068e3c49a120 --- /dev/null +++ b/dotnet/test/VectorData/Redis.ConformanceTests/RedisHashSetDataTypeTests.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Redis.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using VectorData.ConformanceTests.Xunit; +using Xunit; + +namespace Redis.ConformanceTests; + +public class RedisHashSetDataTypeTests(RedisHashSetDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + public override Task Bool() => Task.CompletedTask; + public override Task Decimal() => Task.CompletedTask; + public override Task DateTime() => Task.CompletedTask; + public override Task DateTimeOffset() => Task.CompletedTask; + public override Task DateOnly() => Task.CompletedTask; + public override Task TimeOnly() => Task.CompletedTask; + + [ConditionalFact(Skip = "Guid not yet supported")] + public override Task Guid() => Task.CompletedTask; + + public override Task String_array() + => this.Test( + "StringArray", + ["foo", "bar"], + ["foo", "baz"], + isFilterable: false); + + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => RedisTestStore.JsonInstance; + + public override bool IsNullSupported => false; + public override bool IsNullFilteringSupported => false; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(bool), + typeof(decimal), + typeof(Guid), + typeof(DateTime), + typeof(DateTimeOffset), +#if NET8_0_OR_GREATER + typeof(DateOnly), + typeof(TimeOnly) +#endif + ]; + } +} diff --git a/dotnet/test/VectorData/Redis.ConformanceTests/RedisJsonDataTypeTests.cs b/dotnet/test/VectorData/Redis.ConformanceTests/RedisJsonDataTypeTests.cs new file mode 100644 index 000000000000..426eebe2e478 --- /dev/null +++ b/dotnet/test/VectorData/Redis.ConformanceTests/RedisJsonDataTypeTests.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Redis.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace Redis.ConformanceTests; + +public class RedisJsonDataTypeTests(RedisJsonDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + public override Task String_array() + => this.Test( + "StringArray", + ["foo", "bar"], + ["foo", "baz"], + isFilterable: false); + + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => RedisTestStore.JsonInstance; + + public override bool IsNullSupported => false; + public override bool IsNullFilteringSupported => false; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(bool), + typeof(decimal), + typeof(Guid), + typeof(DateTime), + typeof(DateTimeOffset), +#if NET8_0_OR_GREATER + typeof(DateOnly), + typeof(TimeOnly) +#endif + ]; + } +} diff --git a/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetDynamicMappingTests.cs b/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetDynamicMappingTests.cs index b0bf08e1854f..a29fccfa75bf 100644 --- a/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetDynamicMappingTests.cs +++ b/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetDynamicMappingTests.cs @@ -35,14 +35,12 @@ public void MapFromDataToStorageModelMapsAllSupportedTypes() ["ULongData"] = 4ul, ["DoubleData"] = 5.5d, ["FloatData"] = 6.6f, - ["BoolData"] = true, ["NullableIntData"] = 7, ["NullableUIntData"] = 8u, ["NullableLongData"] = 9L, ["NullableULongData"] = 10ul, ["NullableDoubleData"] = 11.1d, ["NullableFloatData"] = 12.2f, - ["NullableBoolData"] = false, ["FloatVector"] = new ReadOnlyMemory(s_floatVector), ["DoubleVector"] = new ReadOnlyMemory(s_doubleVector), @@ -113,14 +111,12 @@ public void MapFromStorageToDataModelMapsAllSupportedTypes() Assert.Equal(4ul, dataModel["ULongData"]); Assert.Equal(5.5d, dataModel["DoubleData"]); Assert.Equal(6.6f, dataModel["FloatData"]); - Assert.True((bool)dataModel["BoolData"]!); Assert.Equal(7, dataModel["NullableIntData"]); Assert.Equal(8u, dataModel["NullableUIntData"]); Assert.Equal(9L, dataModel["NullableLongData"]); Assert.Equal(10ul, dataModel["NullableULongData"]); Assert.Equal(11.1d, dataModel["NullableDoubleData"]); Assert.Equal(12.2f, dataModel["NullableFloatData"]); - Assert.False((bool)dataModel["NullableBoolData"]!); Assert.Equal(new float[] { 1, 2, 3, 4 }, ((ReadOnlyMemory)dataModel["FloatVector"]!).ToArray()); Assert.Equal(new double[] { 5, 6, 7, 8 }, ((ReadOnlyMemory)dataModel["DoubleVector"]!).ToArray()); } diff --git a/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetMapperTests.cs b/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetMapperTests.cs index e2ef1c72c6f0..8a6b953f4e80 100644 --- a/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetMapperTests.cs +++ b/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetMapperTests.cs @@ -52,14 +52,12 @@ public void MapsAllFieldsFromStorageToDataModel() Assert.Equal(4ul, actual.ULongData); Assert.Equal(5.5d, actual.DoubleData); Assert.Equal(6.6f, actual.FloatData); - Assert.True(actual.BoolData); Assert.Equal(7, actual.NullableIntData); Assert.Equal(8u, actual.NullableUIntData); Assert.Equal(9, actual.NullableLongData); Assert.Equal(10ul, actual.NullableULongData); Assert.Equal(11.1d, actual.NullableDoubleData); Assert.Equal(12.2f, actual.NullableFloatData); - Assert.False(actual.NullableBoolData); Assert.Equal(new float[] { 1, 2, 3, 4 }, actual.FloatVector!.Value.ToArray()); Assert.Equal(new double[] { 5, 6, 7, 8 }, actual.DoubleVector!.Value.ToArray()); @@ -77,14 +75,12 @@ private static AllTypesModel CreateModel(string key) ULongData = 4, DoubleData = 5.5d, FloatData = 6.6f, - BoolData = true, NullableIntData = 7, NullableUIntData = 8, NullableLongData = 9, NullableULongData = 10, NullableDoubleData = 11.1d, NullableFloatData = 12.2f, - NullableBoolData = false, FloatVector = new float[] { 1, 2, 3, 4 }, DoubleVector = new double[] { 5, 6, 7, 8 }, NotAnnotated = "notAnnotated", @@ -117,9 +113,6 @@ private sealed class AllTypesModel [VectorStoreData] public float FloatData { get; set; } - [VectorStoreData] - public bool BoolData { get; set; } - [VectorStoreData] public int? NullableIntData { get; set; } @@ -138,9 +131,6 @@ private sealed class AllTypesModel [VectorStoreData] public float? NullableFloatData { get; set; } - [VectorStoreData] - public bool? NullableBoolData { get; set; } - [VectorStoreVector(10)] public ReadOnlyMemory? FloatVector { get; set; } diff --git a/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetMappingTestHelpers.cs b/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetMappingTestHelpers.cs index c18fa41e7355..5dab6b2977f7 100644 --- a/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetMappingTestHelpers.cs +++ b/dotnet/test/VectorData/Redis.UnitTests/RedisHashSetMappingTestHelpers.cs @@ -27,14 +27,12 @@ internal static class RedisHashSetMappingTestHelpers new VectorStoreDataProperty("ULongData", typeof(ulong)), new VectorStoreDataProperty("DoubleData", typeof(double)), new VectorStoreDataProperty("FloatData", typeof(float)), - new VectorStoreDataProperty("BoolData", typeof(bool)), new VectorStoreDataProperty("NullableIntData", typeof(int?)), new VectorStoreDataProperty("NullableUIntData", typeof(uint?)), new VectorStoreDataProperty("NullableLongData", typeof(long?)), new VectorStoreDataProperty("NullableULongData", typeof(ulong?)), new VectorStoreDataProperty("NullableDoubleData", typeof(double?)), new VectorStoreDataProperty("NullableFloatData", typeof(float?)), - new VectorStoreDataProperty("NullableBoolData", typeof(bool?)), new VectorStoreVectorProperty("FloatVector", typeof(ReadOnlyMemory), 10), new VectorStoreVectorProperty("DoubleVector", typeof(ReadOnlyMemory), 10), } @@ -42,7 +40,7 @@ internal static class RedisHashSetMappingTestHelpers public static HashEntry[] CreateHashSet() { - var hashSet = new HashEntry[17]; + var hashSet = new HashEntry[15]; hashSet[0] = new HashEntry("storage_string_data", "data 1"); hashSet[1] = new HashEntry("IntData", 1); hashSet[2] = new HashEntry("UIntData", 2); @@ -50,16 +48,14 @@ public static HashEntry[] CreateHashSet() hashSet[4] = new HashEntry("ULongData", 4); hashSet[5] = new HashEntry("DoubleData", 5.5); hashSet[6] = new HashEntry("FloatData", 6.6); - hashSet[7] = new HashEntry("BoolData", true); - hashSet[8] = new HashEntry("NullableIntData", 7); - hashSet[9] = new HashEntry("NullableUIntData", 8); - hashSet[10] = new HashEntry("NullableLongData", 9); - hashSet[11] = new HashEntry("NullableULongData", 10); - hashSet[12] = new HashEntry("NullableDoubleData", 11.1); - hashSet[13] = new HashEntry("NullableFloatData", 12.2); - hashSet[14] = new HashEntry("NullableBoolData", false); - hashSet[15] = new HashEntry("FloatVector", MemoryMarshal.AsBytes(new ReadOnlySpan(new float[] { 1, 2, 3, 4 })).ToArray()); - hashSet[16] = new HashEntry("DoubleVector", MemoryMarshal.AsBytes(new ReadOnlySpan(new double[] { 5, 6, 7, 8 })).ToArray()); + hashSet[7] = new HashEntry("NullableIntData", 7); + hashSet[8] = new HashEntry("NullableUIntData", 8); + hashSet[9] = new HashEntry("NullableLongData", 9); + hashSet[10] = new HashEntry("NullableULongData", 10); + hashSet[11] = new HashEntry("NullableDoubleData", 11.1); + hashSet[12] = new HashEntry("NullableFloatData", 12.2); + hashSet[13] = new HashEntry("FloatVector", MemoryMarshal.AsBytes(new ReadOnlySpan(new float[] { 1, 2, 3, 4 })).ToArray()); + hashSet[14] = new HashEntry("DoubleVector", MemoryMarshal.AsBytes(new ReadOnlySpan(new double[] { 5, 6, 7, 8 })).ToArray()); return hashSet; } @@ -86,34 +82,28 @@ public static void VerifyHashSet(HashEntry[] hashEntries) Assert.Equal("FloatData", hashEntries[6].Name.ToString()); Assert.Equal(6.6f, (float)hashEntries[6].Value); - Assert.Equal("BoolData", hashEntries[7].Name.ToString()); - Assert.True((bool)hashEntries[7].Value); + Assert.Equal("NullableIntData", hashEntries[7].Name.ToString()); + Assert.Equal(7, (int)hashEntries[7].Value); - Assert.Equal("NullableIntData", hashEntries[8].Name.ToString()); - Assert.Equal(7, (int)hashEntries[8].Value); + Assert.Equal("NullableUIntData", hashEntries[8].Name.ToString()); + Assert.Equal(8u, (uint)hashEntries[8].Value); - Assert.Equal("NullableUIntData", hashEntries[9].Name.ToString()); - Assert.Equal(8u, (uint)hashEntries[9].Value); + Assert.Equal("NullableLongData", hashEntries[9].Name.ToString()); + Assert.Equal(9, (long)hashEntries[9].Value); - Assert.Equal("NullableLongData", hashEntries[10].Name.ToString()); - Assert.Equal(9, (long)hashEntries[10].Value); + Assert.Equal("NullableULongData", hashEntries[10].Name.ToString()); + Assert.Equal(10ul, (ulong)hashEntries[10].Value); - Assert.Equal("NullableULongData", hashEntries[11].Name.ToString()); - Assert.Equal(10ul, (ulong)hashEntries[11].Value); + Assert.Equal("NullableDoubleData", hashEntries[11].Name.ToString()); + Assert.Equal(11.1d, (double)hashEntries[11].Value); - Assert.Equal("NullableDoubleData", hashEntries[12].Name.ToString()); - Assert.Equal(11.1d, (double)hashEntries[12].Value); + Assert.Equal("NullableFloatData", hashEntries[12].Name.ToString()); + Assert.Equal(12.2f, (float)hashEntries[12].Value); - Assert.Equal("NullableFloatData", hashEntries[13].Name.ToString()); - Assert.Equal(12.2f, (float)hashEntries[13].Value); + Assert.Equal("FloatVector", hashEntries[13].Name.ToString()); + Assert.Equal(new float[] { 1, 2, 3, 4 }, MemoryMarshal.Cast((byte[])hashEntries[13].Value!).ToArray()); - Assert.Equal("NullableBoolData", hashEntries[14].Name.ToString()); - Assert.False((bool)hashEntries[14].Value); - - Assert.Equal("FloatVector", hashEntries[15].Name.ToString()); - Assert.Equal(new float[] { 1, 2, 3, 4 }, MemoryMarshal.Cast((byte[])hashEntries[15].Value!).ToArray()); - - Assert.Equal("DoubleVector", hashEntries[16].Name.ToString()); - Assert.Equal(new double[] { 5, 6, 7, 8 }, MemoryMarshal.Cast((byte[])hashEntries[16].Value!).ToArray()); + Assert.Equal("DoubleVector", hashEntries[14].Name.ToString()); + Assert.Equal(new double[] { 5, 6, 7, 8 }, MemoryMarshal.Cast((byte[])hashEntries[14].Value!).ToArray()); } } diff --git a/dotnet/test/VectorData/SqlServer.ConformanceTests/SqlServerDataTypeTests.cs b/dotnet/test/VectorData/SqlServer.ConformanceTests/SqlServerDataTypeTests.cs new file mode 100644 index 000000000000..b307152d79af --- /dev/null +++ b/dotnet/test/VectorData/SqlServer.ConformanceTests/SqlServerDataTypeTests.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests; + +public class SqlServerDataTypeTests(SqlServerDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(DateTimeOffset), + typeof(string[]) + ]; + } +} diff --git a/dotnet/test/VectorData/SqlServer.ConformanceTests/SqlServerVectorStoreTests.cs b/dotnet/test/VectorData/SqlServer.ConformanceTests/SqlServerVectorStoreTests.cs index 217a3eaa5644..05d326e4d78e 100644 --- a/dotnet/test/VectorData/SqlServer.ConformanceTests/SqlServerVectorStoreTests.cs +++ b/dotnet/test/VectorData/SqlServer.ConformanceTests/SqlServerVectorStoreTests.cs @@ -395,9 +395,7 @@ public sealed class FancyTestModel public long Number64 { get; set; } [VectorStoreData(StorageName = "bytes")] -#pragma warning disable CA1819 // Properties should not return arrays public byte[]? Bytes { get; set; } -#pragma warning restore CA1819 // Properties should not return arrays [VectorStoreVector(Dimensions: 10, StorageName = "embedding")] public ReadOnlyMemory Floats { get; set; } diff --git a/dotnet/test/VectorData/SqliteVec.ConformanceTests/SqliteDataTypeTests.cs b/dotnet/test/VectorData/SqliteVec.ConformanceTests/SqliteDataTypeTests.cs new file mode 100644 index 000000000000..e4c73b7e509f --- /dev/null +++ b/dotnet/test/VectorData/SqliteVec.ConformanceTests/SqliteDataTypeTests.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqliteVec.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqliteVec.ConformanceTests; + +public class SqliteDataTypeTests(SqliteDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => SqliteTestStore.Instance; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(byte), + typeof(decimal), + typeof(Guid), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(string[]), +#if NET8_0_OR_GREATER + typeof(DateOnly), + typeof(TimeOnly) +#endif + ]; + } +} diff --git a/dotnet/test/VectorData/VectorData.ConformanceTests/DataTypeTests.cs b/dotnet/test/VectorData/VectorData.ConformanceTests/DataTypeTests.cs new file mode 100644 index 000000000000..a17c95f28c43 --- /dev/null +++ b/dotnet/test/VectorData/VectorData.ConformanceTests/DataTypeTests.cs @@ -0,0 +1,587 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.Extensions.VectorData; +using VectorData.ConformanceTests.Support; +using VectorData.ConformanceTests.Xunit; +using Xunit; + +namespace VectorData.ConformanceTests; + +/// +/// Tests that the various embedding types natively supported by the provider (ReadOnlyMemory<float>, ReadOnlyMemory<Half>...) work correctly. +/// +public abstract class DataTypeTests(DataTypeTests.Fixture fixture) : DataTypeTests() + where TKey : notnull + where TRecord : DataTypeTests.RecordBase, new() +{ + // Note: nullable value types are tested automatically within TestTypeStructAsync + + [ConditionalFact] + public virtual Task Byte() + => fixture.UnsupportedDefaultTypes.Contains(typeof(byte)) + ? Task.CompletedTask + : this.Test("Byte", 8, 9); + + [ConditionalFact] + public virtual Task Short() + => fixture.UnsupportedDefaultTypes.Contains(typeof(short)) + ? Task.CompletedTask + : this.Test("Short", 8, 9); + + [ConditionalFact] + public virtual Task Int() + => fixture.UnsupportedDefaultTypes.Contains(typeof(int)) + ? Task.CompletedTask + : this.Test("Int", 8, 9); + + [ConditionalFact] + public virtual Task Long() + => fixture.UnsupportedDefaultTypes.Contains(typeof(long)) + ? Task.CompletedTask + : this.Test("Long", 8L, 9L); + + [ConditionalFact] + public virtual Task Float() + => fixture.UnsupportedDefaultTypes.Contains(typeof(float)) + ? Task.CompletedTask + : this.Test("Float", 8.5f, 9.5f); + + [ConditionalFact] + public virtual Task Double() + => fixture.UnsupportedDefaultTypes.Contains(typeof(double)) + ? Task.CompletedTask + : this.Test("Double", 8.5d, 9.5d); + + [ConditionalFact] + public virtual Task Decimal() + => fixture.UnsupportedDefaultTypes.Contains(typeof(decimal)) + ? Task.CompletedTask + : this.Test("Decimal", 8.5m, 9.5m); + + [ConditionalFact] + public virtual Task String() + => fixture.UnsupportedDefaultTypes.Contains(typeof(string)) + ? Task.CompletedTask + : this.Test("String", "foo", "bar"); + + [ConditionalFact] + public virtual Task Bool() + => fixture.UnsupportedDefaultTypes.Contains(typeof(bool)) + ? Task.CompletedTask + : this.Test("Bool", true, false); + + [ConditionalFact] + public virtual Task Guid() + => fixture.UnsupportedDefaultTypes.Contains(typeof(Guid)) + ? Task.CompletedTask + : this.Test( + "Guid", + new Guid("603840bf-cf91-4521-8b8e-8b6a2e75910a"), + new Guid("e9a97807-8cf0-4741-8ce3-82df676ca0f0")); + + [ConditionalFact] + public virtual Task DateTime() + => fixture.UnsupportedDefaultTypes.Contains(typeof(DateTime)) + ? Task.CompletedTask + : this.Test( + "DateTime", + new DateTime(2020, 1, 1, 12, 30, 45), + new DateTime(2021, 2, 3, 13, 40, 55), + instantiationExpression: () => new DateTime(2020, 1, 1, 12, 30, 45)); + + [ConditionalFact] + public virtual Task DateTimeOffset() + => fixture.UnsupportedDefaultTypes.Contains(typeof(DateTimeOffset)) + ? Task.CompletedTask + : this.Test( + "DateTimeOffset", + new DateTimeOffset(2020, 1, 1, 12, 30, 45, TimeSpan.FromHours(2)), + new DateTimeOffset(2021, 2, 3, 13, 40, 55, TimeSpan.FromHours(3)), + instantiationExpression: () => new DateTimeOffset(2020, 1, 1, 12, 30, 45, TimeSpan.FromHours(2))); + + [ConditionalFact] + public virtual Task DateOnly() + { +#if NET8_0_OR_GREATER + return fixture.UnsupportedDefaultTypes.Contains(typeof(DateOnly)) + ? Task.CompletedTask + : this.Test( + "DateOnly", + new DateOnly(2020, 1, 1), + new DateOnly(2021, 2, 3)); +#else + return Task.CompletedTask; +#endif + } + + [ConditionalFact] + public virtual Task TimeOnly() + { +#if NET8_0_OR_GREATER + return fixture.UnsupportedDefaultTypes.Contains(typeof(TimeOnly)) + ? Task.CompletedTask + : this.Test( + "TimeOnly", + new TimeOnly(12, 30, 45), + new TimeOnly(13, 40, 55)); +#else + return Task.CompletedTask; +#endif + } + + [ConditionalFact] + public virtual Task String_array() + => fixture.UnsupportedDefaultTypes.Contains(typeof(string[])) + ? Task.CompletedTask + : this.Test( + "StringArray", + ["foo", "bar"], + ["foo", "baz"]); + + [ConditionalFact] + public virtual Task Nullable_value_type() + => fixture.UnsupportedDefaultTypes.Contains(typeof(int?)) + ? Task.CompletedTask + : this.Test("NullableInt", 8, 9); + + protected virtual async Task Test( + string propertyName, + TTestType mainValue, + TTestType otherValue, + bool isFilterable = true, + Action? comparisonAction = null, + Expression>? instantiationExpression = null) + { + if (propertyName is "Key" or "Vector") + { + throw new ArgumentException($"The property name '{propertyName}' is reserved and cannot be used for testing.", nameof(propertyName)); + } + + var property = typeof(TRecord).GetProperty(propertyName) + ?? throw new ArgumentException($"The type '{typeof(TRecord).Name}' does not have a property named '{propertyName}'.", nameof(propertyName)); + comparisonAction ??= (a, b) => Assert.Equal(a, b); + var instantiationExpressionBody = instantiationExpression is null + ? Expression.Constant(mainValue, typeof(TTestType)) + : instantiationExpression.Body; + + await fixture.Collection.DeleteAsync([fixture.MainRecordKey, fixture.OtherRecordKey, fixture.NullRecordKey]); + await fixture.TestStore.WaitForDataAsync(fixture.Collection, recordCount: 0); + + // Step 1: Insert data + await this.InsertData(property, mainValue, otherValue); + + // Step 2: Read the values back via GetAsync + TRecord result = await fixture.Collection.GetAsync(fixture.MainRecordKey) ?? throw new InvalidOperationException($"Record with key '{fixture.MainRecordKey}' was not found."); + comparisonAction(mainValue, (TTestType)property.GetValue(result)!); + + // Step 3: Exercise filtering by the value, using a constant in the filter expression + if (isFilterable) + { + await this.TestFiltering(fixture.Collection, property, mainValue, comparisonAction, instantiationExpressionBody); + } + + /////////////////////// + // Test dynamic mapping + /////////////////////// + if (fixture.RecreateCollection) + { + await fixture.Collection.EnsureCollectionDeletedAsync(); + } + else + { + await fixture.Collection.DeleteAsync([fixture.MainRecordKey, fixture.OtherRecordKey, fixture.NullRecordKey]); + await fixture.TestStore.WaitForDataAsync(fixture.Collection, recordCount: 0); + } + + var dynamicCollection = fixture.VectorStore.GetDynamicCollection(fixture.CollectionName, fixture.CreateRecordDefinition()); + + if (fixture.RecreateCollection) + { + await dynamicCollection.EnsureCollectionExistsAsync(); + } + + // Step 1: Insert data + await this.InsertDynamicData(dynamicCollection, propertyName, mainValue, otherValue); + + // Step 2: Read the values back via GetAsync + var dynamicResult = await dynamicCollection.GetAsync(fixture.MainRecordKey) ?? throw new InvalidOperationException($"Record with key '{fixture.MainRecordKey}' was not found."); + comparisonAction(mainValue, (TTestType)dynamicResult[propertyName]!); + + // Step 3: Exercise dynamic filtering by the value, using a constant in the filter expression + if (isFilterable) + { + await this.TestDynamicFiltering(dynamicCollection, propertyName, mainValue, comparisonAction, instantiationExpressionBody); + } + } + + private async Task InsertData(PropertyInfo property, TTestType mainValue, TTestType otherValue) + { + // Note that all records have the same vector + var mainRecord = GenerateEmptyRecord(); + mainRecord.Key = fixture.MainRecordKey; + mainRecord.Vector = fixture.Vector; + property.SetValue(mainRecord, mainValue); + + var otherRecord = GenerateEmptyRecord(); + otherRecord.Key = fixture.OtherRecordKey; + otherRecord.Vector = fixture.Vector; + property.SetValue(otherRecord, otherValue); + + List testData = [mainRecord, otherRecord]; + + if (default(TTestType) == null && fixture.IsNullSupported) + { + var nullRecord = GenerateEmptyRecord(); + nullRecord.Key = fixture.NullRecordKey; + nullRecord.Vector = fixture.Vector; + property.SetValue(nullRecord, null); + testData.Add(nullRecord); + } + + await fixture.Collection.UpsertAsync(testData); + await fixture.TestStore.WaitForDataAsync(fixture.Collection, recordCount: testData.Count); + + TRecord GenerateEmptyRecord() + { + var record = new TRecord(); + + foreach (var property in fixture.CreateRecordDefinition().Properties) + { + var propertyInfo = typeof(TRecord).GetProperty(property.Name) + ?? throw new InvalidOperationException($"Property '{property.Name}' not found on record type '{typeof(TRecord).Name}'."); + propertyInfo.SetValue(record, this.GenerateEmptyProperty(property)); + } + + return record; + } + } + + private async Task InsertDynamicData( + VectorStoreCollection> dynamicCollection, + string propertyName, + TTestType mainValue, + TTestType otherValue) + { + // Note that all records have the same vector + var mainRecord = GenerateEmptyRecord(); + mainRecord[nameof(RecordBase.Key)] = fixture.MainRecordKey; + mainRecord[nameof(RecordBase.Vector)] = fixture.Vector; + mainRecord[propertyName] = mainValue; + + var otherRecord = GenerateEmptyRecord(); + otherRecord[nameof(RecordBase.Key)] = fixture.OtherRecordKey; + otherRecord[nameof(RecordBase.Vector)] = fixture.Vector; + otherRecord[propertyName] = otherValue; + + List> testData = [mainRecord, otherRecord]; + + if (default(TTestType) == null && fixture.IsNullSupported) + { + var nullRecord = GenerateEmptyRecord(); + nullRecord[nameof(RecordBase.Key)] = fixture.NullRecordKey; + nullRecord[nameof(RecordBase.Vector)] = fixture.Vector; + nullRecord[propertyName] = null; + testData.Add(nullRecord); + } + + await dynamicCollection.UpsertAsync(testData); + await fixture.TestStore.WaitForDataAsync(dynamicCollection, recordCount: testData.Count); + + Dictionary GenerateEmptyRecord() + { + var record = new Dictionary(); + + foreach (var property in fixture.CreateRecordDefinition().Properties) + { + record[property.Name] = this.GenerateEmptyProperty(property); + } + + return record; + } + } + + protected virtual object? GenerateEmptyProperty(VectorStoreProperty property) + => property.Type switch + { + null => throw new InvalidOperationException($"Property '{property.Name}' has no type defined."), + + // For value types, we create an instance with the default value. + // This is necessary for relational providers where non-nullable columns are created. + var t when t.IsValueType => Activator.CreateInstance(t), + + // In some cases (Azure AI Search), array fields must be non-null + var t when t.IsArray => Array.CreateInstance(t.GetElementType()!, 0), + + _ => null + }; + + private async Task TestFiltering( + VectorStoreCollection collection, + PropertyInfo property, + TTestType mainValue, + Action comparisonAction, + Expression instantiationExpression) + { + // Note: we need to manually build the expression tree since the equality operator can't be used over + // unconstrained generic types. + var lambdaParameter = Expression.Parameter(typeof(TRecord), "r"); + var filter = Expression.Lambda>( + Expression.Equal( + Expression.Property(lambdaParameter, property), + instantiationExpression), + lambdaParameter); + + // Some databases (Mongo) update the filter index asynchronously, so we wait until the record appears under the filter, + // and then do the main search to make sure only the main record is returned. + await fixture.TestStore.WaitForDataAsync(collection, filter: filter, recordCount: 1); + var result = (await collection.SearchAsync(fixture.Vector, top: 100, new() { Filter = filter }).SingleAsync()).Record; + + Assert.Equal(fixture.MainRecordKey, result.Key); + comparisonAction(mainValue, (TTestType)property.GetValue(result)!); + + // Exercise filtering by a null value + if (default(TTestType) == null && fixture.IsNullFilteringSupported) + { + lambdaParameter = Expression.Parameter(typeof(TRecord), "r"); + filter = Expression.Lambda>( + Expression.Equal( + Expression.Property(lambdaParameter, property), + Expression.Constant(null, typeof(TTestType))), + lambdaParameter); + + result = (await collection.SearchAsync(fixture.Vector, top: 100, new() { Filter = filter }).SingleAsync()).Record; + + Assert.Equal(fixture.NullRecordKey, result.Key); + } + } + + private async Task TestDynamicFiltering( + VectorStoreCollection> dynamicCollection, + string propertyName, + TTestType mainValue, + Action comparisonAction, + Expression instantiationExpression) + { + // Note: we need to manually build the expression tree since we want the property name to be a constant + var lambdaParameter = Expression.Parameter(typeof(Dictionary), "r"); + var filter = Expression.Lambda, bool>>( + Expression.Equal( + Expression.Convert( + Expression.Call(lambdaParameter, DynamicDictionaryIndexer, Expression.Constant(propertyName)), + typeof(TTestType)), + instantiationExpression), + lambdaParameter); + + // Some databases (Mongo) update the filter index asynchronously, so we wait until the record appears under the filter, + // and then do the main search to make sure only the main record is returned. + await fixture.TestStore.WaitForDataAsync(dynamicCollection, filter: filter, recordCount: 1); + var result = (await dynamicCollection.SearchAsync(fixture.Vector, top: 100, new() { Filter = filter }).SingleAsync()).Record; + Assert.Equal(fixture.MainRecordKey, result[nameof(RecordBase.Key)]); + comparisonAction(mainValue, (TTestType)result[propertyName]!); + + // Exercise filtering by a null value + if (default(TTestType) == null && fixture.IsNullFilteringSupported) + { + lambdaParameter = Expression.Parameter(typeof(Dictionary), "r"); + filter = Expression.Lambda, bool>>( + Expression.Equal( + Expression.Convert( + Expression.Call(lambdaParameter, DynamicDictionaryIndexer, Expression.Constant(propertyName)), + typeof(TTestType)), + Expression.Constant(null, typeof(TTestType))), + lambdaParameter); + + result = (await dynamicCollection.SearchAsync(fixture.Vector, top: 100, new() { Filter = filter }).SingleAsync()).Record; + + Assert.Equal(fixture.NullRecordKey, result[nameof(RecordBase.Key)]); + } + } + + private static readonly MethodInfo DynamicDictionaryIndexer = typeof(Dictionary).GetMethod("get_Item")!; + + public abstract class Fixture : VectorStoreCollectionFixture + { + public override string CollectionName => "DataTypeTests"; + + public virtual bool IsNullSupported => true; + public virtual bool IsNullFilteringSupported => true; + + public virtual Type[] UnsupportedDefaultTypes { get; } = []; + + public virtual TKey MainRecordKey { get; protected set; } = default!; + public virtual TKey OtherRecordKey { get; protected set; } = default!; + public virtual TKey NullRecordKey { get; protected set; } = default!; + + public virtual float[] Vector { get; } = [1, 2, 3]; + + private readonly IList _defaultDataProperties = null!; + + /// + /// Whether the recreate the collection while testing, as opposed to deleting the records. + /// Necessary for InMemory, where the .NET mapped on the collection cannot be changed. + /// + public virtual bool RecreateCollection => false; + +#pragma warning disable CA2214 // Do not call overridable methods in constructors + protected Fixture() + { + this._defaultDataProperties = this.GetDataProperties(); + } +#pragma warning restore CA2214 + + public override async Task InitializeAsync() + { + await base.InitializeAsync(); + + this.MainRecordKey = this.GenerateNextKey(); + this.OtherRecordKey = this.GenerateNextKey(); + this.NullRecordKey = this.GenerateNextKey(); + } + + public override VectorStoreCollectionDefinition CreateRecordDefinition() + => new() + { + Properties = + [ + new VectorStoreKeyProperty(nameof(RecordBase.Key), typeof(TKey)), + new VectorStoreVectorProperty(nameof(RecordBase.Vector), typeof(float[]), 3) + { + DistanceFunction = this.DistanceFunction, + IndexKind = this.IndexKind + }, + + .. this._defaultDataProperties + ] + }; + + public virtual IList GetDataProperties() + { + var properties = new List(); + + if (!this.UnsupportedDefaultTypes.Contains(typeof(byte))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.Byte), typeof(byte)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(short))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.Short), typeof(short)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(int))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.Int), typeof(int)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(long))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.Long), typeof(long)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(float))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.Float), typeof(float)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(double))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.Double), typeof(double)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(decimal))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.Decimal), typeof(decimal)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(string))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.String), typeof(string)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(bool))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.Bool), typeof(bool)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(Guid))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.Guid), typeof(Guid)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(DateTime))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.DateTime), typeof(DateTime)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(DateTimeOffset))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.DateTimeOffset), typeof(DateTimeOffset)) { IsIndexed = true }); + } + +#if NET8_0_OR_GREATER + if (!this.UnsupportedDefaultTypes.Contains(typeof(DateOnly))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.DateOnly), typeof(DateOnly)) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(TimeOnly))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.TimeOnly), typeof(TimeOnly)) { IsIndexed = true }); + } +#endif + if (!this.UnsupportedDefaultTypes.Contains(typeof(string[]))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.StringArray), typeof(string[])) { IsIndexed = true }); + } + + if (!this.UnsupportedDefaultTypes.Contains(typeof(int?))) + { + properties.Add(new VectorStoreDataProperty(nameof(DefaultRecord.NullableInt), typeof(int?)) { IsIndexed = true }); + } + + return properties; + } + } +} + +// We have this base class so the Record type can be referenced in subtypes (the main TypeTests class +// is generic over the record type as well). +public abstract class DataTypeTests() + where TKey : notnull +{ + public class RecordBase + { + public TKey Key { get; set; } = default!; + public float[] Vector { get; set; } = default!; + } + + public class DefaultRecord : RecordBase + { + public byte Byte { get; set; } + public short Short { get; set; } + public int Int { get; set; } + public long Long { get; set; } + + public float Float { get; set; } + public double Double { get; set; } + public decimal Decimal { get; set; } + + public string? String { get; set; } + public bool Bool { get; set; } + public Guid Guid { get; set; } + + public DateTime DateTime { get; set; } + public DateTimeOffset DateTimeOffset { get; set; } + +#if NET8_0_OR_GREATER + public DateOnly DateOnly { get; set; } + public TimeOnly TimeOnly { get; set; } +#endif + + public string[] StringArray { get; set; } = null!; + + public int? NullableInt { get; set; } + } +} diff --git a/dotnet/test/VectorData/VectorData.ConformanceTests/EmbeddingGenerationTests.cs b/dotnet/test/VectorData/VectorData.ConformanceTests/EmbeddingGenerationTests.cs index 8002c8dc757a..60b562d60d6b 100644 --- a/dotnet/test/VectorData/VectorData.ConformanceTests/EmbeddingGenerationTests.cs +++ b/dotnet/test/VectorData/VectorData.ConformanceTests/EmbeddingGenerationTests.cs @@ -9,7 +9,6 @@ namespace VectorData.ConformanceTests; -#pragma warning disable CA1819 // Properties should not return arrays #pragma warning disable CA2000 // Don't actually need to dispose FakeEmbeddingGenerator #pragma warning disable CS8605 // Unboxing a possibly null value. diff --git a/dotnet/test/VectorData/VectorData.ConformanceTests/Filter/BasicFilterTests.cs b/dotnet/test/VectorData/VectorData.ConformanceTests/Filter/BasicFilterTests.cs index 85fa89dcff9e..b429059b73b2 100644 --- a/dotnet/test/VectorData/VectorData.ConformanceTests/Filter/BasicFilterTests.cs +++ b/dotnet/test/VectorData/VectorData.ConformanceTests/Filter/BasicFilterTests.cs @@ -539,8 +539,6 @@ protected virtual async Task TestLegacyFilterAsync( } } -#pragma warning disable CS1819 // Properties should not return arrays -#pragma warning disable CA1819 // Properties should not return arrays public class FilterRecord { public TKey Key { get; set; } = default!; @@ -553,8 +551,6 @@ public class FilterRecord public string[] StringArray { get; set; } = null!; public List StringList { get; set; } = null!; } -#pragma warning restore CA1819 // Properties should not return arrays -#pragma warning restore CS1819 public abstract class Fixture : VectorStoreCollectionFixture { diff --git a/dotnet/test/VectorData/VectorData.ConformanceTests/Support/TestStore.cs b/dotnet/test/VectorData/VectorData.ConformanceTests/Support/TestStore.cs index 8aeb85aef293..719153071515 100644 --- a/dotnet/test/VectorData/VectorData.ConformanceTests/Support/TestStore.cs +++ b/dotnet/test/VectorData/VectorData.ConformanceTests/Support/TestStore.cs @@ -100,17 +100,13 @@ public virtual async Task WaitForDataAsync( { var results = collection.SearchAsync( vector, - top: recordCount, + top: recordCount is 0 ? 1 : recordCount, new() { Filter = filter }); var count = await results.CountAsync(); if (count == recordCount) { return; } - if (count > recordCount) - { - throw new InvalidOperationException($"Expected at most {recordCount} records, but found {count}."); - } await Task.Delay(TimeSpan.FromMilliseconds(100)); } diff --git a/dotnet/test/VectorData/VectorData.ConformanceTests/Support/VectorStoreCollectionFixture.cs b/dotnet/test/VectorData/VectorData.ConformanceTests/Support/VectorStoreCollectionFixture.cs index 3bdadfeec684..87251e5ab2b0 100644 --- a/dotnet/test/VectorData/VectorData.ConformanceTests/Support/VectorStoreCollectionFixture.cs +++ b/dotnet/test/VectorData/VectorData.ConformanceTests/Support/VectorStoreCollectionFixture.cs @@ -17,7 +17,7 @@ public abstract class VectorStoreCollectionFixture : VectorStoreF private List? _testData; public abstract VectorStoreCollectionDefinition CreateRecordDefinition(); - protected abstract List BuildTestData(); + protected virtual List BuildTestData() => []; public virtual string CollectionName => Guid.NewGuid().ToString(); protected virtual string DistanceFunction => this.TestStore.DefaultDistanceFunction; @@ -47,8 +47,11 @@ public override async Task InitializeAsync() protected virtual async Task SeedAsync() { - await this.Collection.UpsertAsync(this.TestData); - await this.WaitForDataAsync(); + if (this.TestData.Count > 0) + { + await this.Collection.UpsertAsync(this.TestData); + await this.WaitForDataAsync(); + } } protected virtual Task WaitForDataAsync() diff --git a/dotnet/test/VectorData/Weaviate.ConformanceTests/WeaviateDataTypeTests.cs b/dotnet/test/VectorData/Weaviate.ConformanceTests/WeaviateDataTypeTests.cs new file mode 100644 index 000000000000..9a258fa2d472 --- /dev/null +++ b/dotnet/test/VectorData/Weaviate.ConformanceTests/WeaviateDataTypeTests.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Weaviate.ConformanceTests.Support; +using Xunit; + +namespace Weaviate.ConformanceTests; + +public class WeaviateDataTypeTests(WeaviateDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + public override Task String_array() + => this.Test( + "StringArray", + ["foo", "bar"], + ["foo", "baz"], + isFilterable: false); // TODO: We don't currently support filtering on arrays + + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => WeaviateTestStore.NamedVectorsInstance; + + // TODO: Weaviate requires special indexing for filtering on nulls, see #10358 + public override bool IsNullFilteringSupported => false; + + public override Type[] UnsupportedDefaultTypes { get; } = + [ + typeof(DateTime), +#if NET8_0_OR_GREATER + typeof(DateOnly), + typeof(TimeOnly) +#endif + ]; + } +} diff --git a/python/pyproject.toml b/python/pyproject.toml index e6d8477e1362..6f977994dd27 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -68,7 +68,7 @@ autogen = [ "autogen-agentchat >= 0.2, <0.4" ] aws = [ - "boto3>=1.36.4,<1.39.0", + "boto3>=1.36.4,<1.40.0", ] azure = [ "azure-ai-inference >= 1.0.0b6", @@ -107,7 +107,7 @@ mistralai = [ "mistralai >= 1.2,< 2.0" ] mongo = [ - "pymongo >= 4.8.0, < 4.13", + "pymongo >= 4.8.0, < 4.14", "motor >= 3.3.2,< 3.8.0" ] notebooks = [ diff --git a/python/semantic_kernel/__init__.py b/python/semantic_kernel/__init__.py index 58b35c9a2243..2410f2f0c43e 100644 --- a/python/semantic_kernel/__init__.py +++ b/python/semantic_kernel/__init__.py @@ -2,7 +2,7 @@ from semantic_kernel.kernel import Kernel -__version__ = "1.34.0" +__version__ = "1.35.0" DEFAULT_RC_VERSION = f"{__version__}-rc9" diff --git a/python/semantic_kernel/agents/azure_ai/agent_content_generation.py b/python/semantic_kernel/agents/azure_ai/agent_content_generation.py index 5e1082e857a8..c2152e894c3d 100644 --- a/python/semantic_kernel/agents/azure_ai/agent_content_generation.py +++ b/python/semantic_kernel/agents/azure_ai/agent_content_generation.py @@ -299,26 +299,33 @@ def generate_bing_grounding_content( @experimental def generate_azure_ai_search_content( agent_name: str, azure_ai_search_tool_call: "RunStepAzureAISearchToolCall" -) -> ChatMessageContent: +) -> ChatMessageContent | None: """Generate function result content related to an Azure AI Search Tool.""" - message_content: ChatMessageContent = ChatMessageContent(role=AuthorRole.ASSISTANT, name=agent_name) # type: ignore + items: list[FunctionCallContent | FunctionResultContent] = [] + # Azure AI Search tool call contains both tool call input and output - message_content.items.append( - FunctionCallContent( - id=azure_ai_search_tool_call.id, - name=azure_ai_search_tool_call.type, - function_name=azure_ai_search_tool_call.type, - arguments=azure_ai_search_tool_call.azure_ai_search.get("input"), + arguments = azure_ai_search_tool_call.azure_ai_search.get("input") + if arguments: + items.append( + FunctionCallContent( + id=azure_ai_search_tool_call.id, + name=azure_ai_search_tool_call.type, + function_name=azure_ai_search_tool_call.type, + arguments=arguments, + inner_content=azure_ai_search_tool_call, + ) ) - ) - message_content.items.append( - FunctionResultContent( - function_name=azure_ai_search_tool_call.type, - id=azure_ai_search_tool_call.id, - result=azure_ai_search_tool_call.azure_ai_search.get("output"), + result = azure_ai_search_tool_call.azure_ai_search.get("output") + if result: + items.append( + FunctionResultContent( + function_name=azure_ai_search_tool_call.type, + id=azure_ai_search_tool_call.id, + result=result, + inner_content=azure_ai_search_tool_call, + ) ) - ) - return message_content + return ChatMessageContent(role=AuthorRole.ASSISTANT, name=agent_name, items=items) if items else None # type: ignore @experimental @@ -506,29 +513,39 @@ def generate_streaming_azure_ai_search_content( for index, tool in enumerate(step_details.tool_calls): if tool.type == "azure_ai_search": azure_ai_search_tool = cast(RunStepAzureAISearchToolCall, tool) - arguments = getattr(azure_ai_search_tool, "azure_ai_search", None) - items.append( - FunctionCallContent( - id=azure_ai_search_tool.id, - index=index, - name=azure_ai_search_tool.type, - function_name=azure_ai_search_tool.type, - arguments=arguments, + azure_ai_search_dict: dict = azure_ai_search_tool.get("azure_ai_search", None) + arguments = azure_ai_search_dict.get("input", {}) if azure_ai_search_dict else None + if arguments: + items.append( + FunctionCallContent( + id=azure_ai_search_tool.id, + index=index, + name=azure_ai_search_tool.type, + function_name=azure_ai_search_tool.type, + arguments=arguments, + inner_content=azure_ai_search_tool, + ) ) - ) - items.append( - FunctionResultContent( - function_name=azure_ai_search_tool.type, - id=azure_ai_search_tool.id, - result=azure_ai_search_tool.azure_ai_search.get("output"), + result = azure_ai_search_dict.get("output", {}) if azure_ai_search_dict else None + if result: + items.append( + FunctionResultContent( + function_name=azure_ai_search_tool.type, + id=azure_ai_search_tool.id, + result=result, + inner_content=azure_ai_search_tool, + ) ) - ) - return StreamingChatMessageContent( - role=AuthorRole.ASSISTANT, - name=agent_name, - choice_index=0, - items=items, # type: ignore + return ( + StreamingChatMessageContent( + role=AuthorRole.ASSISTANT, + name=agent_name, + choice_index=0, + items=items, # type: ignore + ) + if items + else None ) # type: ignore diff --git a/python/semantic_kernel/agents/orchestration/agent_actor_base.py b/python/semantic_kernel/agents/orchestration/agent_actor_base.py index c3a4f4d6146d..8d4ae0458d63 100644 --- a/python/semantic_kernel/agents/orchestration/agent_actor_base.py +++ b/python/semantic_kernel/agents/orchestration/agent_actor_base.py @@ -2,8 +2,10 @@ import inspect +import logging import sys from collections.abc import Awaitable, Callable +from functools import wraps from typing import Any from semantic_kernel.agents.agent import Agent, AgentThread @@ -18,11 +20,27 @@ else: from typing_extensions import override # pragma: no cover +logger: logging.Logger = logging.getLogger(__name__) + @experimental class ActorBase(RoutedAgent): """A base class for actors running in the AgentRuntime.""" + def __init__( + self, + description: str, + exception_callback: Callable[[BaseException], None], + ): + """Initialize the actor with a description and an exception callback. + + Args: + description (str): A description of the actor. + exception_callback (Callable[[BaseException], None]): A callback function to handle exceptions. + """ + super().__init__(description=description) + self._exception_callback = exception_callback + @override async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None: """Handle a message. @@ -34,6 +52,46 @@ async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None return await super().on_message_impl(message, ctx) + @staticmethod + def exception_handler(func: Callable[..., Any]) -> Callable[..., Any]: + """Decorator that wraps a function in a try-catch block and calls the exception callback on errors. + + This decorator can be used on both synchronous and asynchronous functions. When an exception + occurs during function execution, it will call the exception_callback with the exception + and then re-raise the exception. + + Args: + func: The function to be wrapped. + + Returns: + The wrapped function. + """ + log_message_template = "Exception occurred in agent {agent_id}:\n{exception}" + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(self, *args, **kwargs): + try: + return await func(self, *args, **kwargs) + except BaseException as e: + self._exception_callback(e) + logger.error(log_message_template.format(agent_id=self.id, exception=e)) + raise + + return async_wrapper + + @wraps(func) + def sync_wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except BaseException as e: + self._exception_callback(e) + logger.error(log_message_template.format(agent_id=self.id, exception=e)) + raise + + return sync_wrapper + @experimental class AgentActorBase(ActorBase): @@ -43,6 +101,7 @@ def __init__( self, agent: Agent, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], agent_response_callback: Callable[[DefaultTypeAlias], Awaitable[None] | None] | None = None, streaming_agent_response_callback: Callable[[StreamingChatMessageContent, bool], Awaitable[None] | None] | None = None, @@ -52,6 +111,7 @@ def __init__( Args: agent (Agent): An agent to be run in the container. internal_topic_type (str): The topic type of the internal topic. + exception_callback (Callable): A function that is called when an exception occurs. agent_response_callback (Callable | None): A function that is called when a full response is produced by the agents. streaming_agent_response_callback (Callable | None): A function that is called when a streaming response @@ -63,10 +123,10 @@ def __init__( self._streaming_agent_response_callback = streaming_agent_response_callback self._agent_thread: AgentThread | None = None - # Chat history to temporarily store messages before the agent thread is created - self._chat_history = ChatHistory() + # Chat history to temporarily store messages before each invoke. + self._message_cache: ChatHistory = ChatHistory() - ActorBase.__init__(self, description=agent.description or "Semantic Kernel Agent") + super().__init__(agent.description or "Semantic Kernel Actor", exception_callback) async def _call_agent_response_callback(self, message: DefaultTypeAlias) -> None: """Call the agent_response_callback function if it is set. @@ -97,6 +157,7 @@ async def _call_streaming_agent_response_callback( else: self._streaming_agent_response_callback(message_chunk, is_final) + @ActorBase.exception_handler async def _invoke_agent(self, additional_messages: DefaultTypeAlias | None = None, **kwargs) -> ChatMessageContent: """Invoke the agent with the current chat history or thread and optionally additional messages. @@ -146,7 +207,10 @@ def _create_messages(self, additional_messages: DefaultTypeAlias | None = None) Returns: list[ChatMessageContent]: A list of messages to be sent to the agent. """ - base_messages = self._chat_history.messages[:] if self._agent_thread is None else [] + base_messages = self._message_cache.messages[:] + + # Clear the message cache for the next invoke. + self._message_cache.clear() if additional_messages is None: return base_messages diff --git a/python/semantic_kernel/agents/orchestration/concurrent.py b/python/semantic_kernel/agents/orchestration/concurrent.py index c669114ca7d3..576b97210ab3 100644 --- a/python/semantic_kernel/agents/orchestration/concurrent.py +++ b/python/semantic_kernel/agents/orchestration/concurrent.py @@ -56,6 +56,7 @@ def __init__( agent: Agent, internal_topic_type: str, collection_agent_type: str, + exception_callback: Callable[[BaseException], None], agent_response_callback: Callable[[DefaultTypeAlias], Awaitable[None] | None] | None = None, streaming_agent_response_callback: Callable[[StreamingChatMessageContent, bool], Awaitable[None] | None] | None = None, @@ -66,6 +67,7 @@ def __init__( agent: The agent to be executed. internal_topic_type: The internal topic type for the actor. collection_agent_type: The collection agent type for the actor. + exception_callback: A callback function to handle exceptions. agent_response_callback: A callback function to handle the full response from the agent. streaming_agent_response_callback: A callback function to handle streaming responses from the agent. """ @@ -73,6 +75,7 @@ def __init__( super().__init__( agent=agent, internal_topic_type=internal_topic_type, + exception_callback=exception_callback, agent_response_callback=agent_response_callback, streaming_agent_response_callback=streaming_agent_response_callback, ) @@ -102,6 +105,7 @@ def __init__( self, description: str, expected_answer_count: int, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]] | None = None, ) -> None: """Initialize the collection agent container.""" @@ -110,7 +114,7 @@ def __init__( self._results: list[ChatMessageContent] = [] self._lock = asyncio.Lock() - super().__init__(description=description) + super().__init__(description, exception_callback) @message_handler async def _handle_message(self, message: ConcurrentResponseMessage, _: MessageContext) -> None: @@ -147,17 +151,20 @@ async def _prepare( self, runtime: CoreRuntime, internal_topic_type: str, - result_callback: Callable[[DefaultTypeAlias], Awaitable[None]] | None = None, + exception_callback: Callable[[BaseException], None], + result_callback: Callable[[DefaultTypeAlias], Awaitable[None]], ) -> None: """Register the actors and orchestrations with the runtime and add the required subscriptions.""" await asyncio.gather(*[ self._register_members( runtime, internal_topic_type, + exception_callback, ), self._register_collection_actor( runtime, internal_topic_type, + exception_callback, result_callback=result_callback, ), self._add_subscriptions( @@ -170,6 +177,7 @@ async def _register_members( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], ) -> None: """Register the members.""" @@ -180,7 +188,8 @@ async def _internal_helper(agent: Agent) -> None: lambda agent=agent: ConcurrentAgentActor( # type: ignore[misc] agent, internal_topic_type, - collection_agent_type=self._get_collection_actor_type(internal_topic_type), + self._get_collection_actor_type(internal_topic_type), + exception_callback, agent_response_callback=self._agent_response_callback, streaming_agent_response_callback=self._streaming_agent_response_callback, ), @@ -192,6 +201,7 @@ async def _register_collection_actor( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]] | None = None, ) -> None: await CollectionActor.register( @@ -200,6 +210,7 @@ async def _register_collection_actor( lambda: CollectionActor( description="An internal agent that is responsible for collection results", expected_answer_count=len(self._members), + exception_callback=exception_callback, result_callback=result_callback, ), ) diff --git a/python/semantic_kernel/agents/orchestration/group_chat.py b/python/semantic_kernel/agents/orchestration/group_chat.py index 4b4fe07b2cc7..65c4640e72a5 100644 --- a/python/semantic_kernel/agents/orchestration/group_chat.py +++ b/python/semantic_kernel/agents/orchestration/group_chat.py @@ -107,27 +107,17 @@ async def _handle_start_message(self, message: GroupChatStartMessage, ctx: Messa """Handle the start message for the group chat.""" logger.debug(f"{self.id}: Received group chat start message.") if isinstance(message.body, ChatMessageContent): - if self._agent_thread: - await self._agent_thread.on_new_message(message.body) - else: - self._chat_history.add_message(message.body) + self._message_cache.add_message(message.body) elif isinstance(message.body, list) and all(isinstance(m, ChatMessageContent) for m in message.body): - if self._agent_thread: - for m in message.body: - await self._agent_thread.on_new_message(m) - else: - for m in message.body: - self._chat_history.add_message(m) + for m in message.body: + self._message_cache.add_message(m) else: raise ValueError(f"Invalid message body type: {type(message.body)}. Expected {DefaultTypeAlias}.") @message_handler async def _handle_response_message(self, message: GroupChatResponseMessage, ctx: MessageContext) -> None: logger.debug(f"{self.id}: Received group chat response message.") - if self._agent_thread is not None: - await self._agent_thread.on_new_message(message.body) - else: - self._chat_history.add_message(message.body) + self._message_cache.add_message(message.body) @message_handler async def _handle_request_message(self, message: GroupChatRequestMessage, ctx: MessageContext) -> None: @@ -267,6 +257,7 @@ def __init__( manager: GroupChatManager, internal_topic_type: str, participant_descriptions: dict[str, str], + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]] | None = None, ): """Initialize the group chat manager actor. @@ -275,8 +266,7 @@ def __init__( manager (GroupChatManager): The group chat manager that manages the flow of the group chat. internal_topic_type (str): The topic type of the internal topic. participant_descriptions (dict[str, str]): The descriptions of the participants in the group chat. - agent_response_callback (Callable | None): A function that is called when a response is produced - by the agents. + exception_callback (Callable[[BaseException], None]): A function that is called when an exception occurs. result_callback (Callable | None): A function that is called when the group chat manager produces a result. """ self._manager = manager @@ -285,7 +275,7 @@ def __init__( self._participant_descriptions = participant_descriptions self._result_callback = result_callback - super().__init__(description="An actor for the group chat manager.") + super().__init__("An actor for the group chat manager.", exception_callback) @message_handler async def _handle_start_message(self, message: GroupChatStartMessage, ctx: MessageContext) -> None: @@ -314,6 +304,7 @@ async def _handle_response_message(self, message: GroupChatResponseMessage, ctx: await self._determine_state_and_take_action(ctx.cancellation_token) + @ActorBase.exception_handler async def _determine_state_and_take_action(self, cancellation_token: CancellationToken) -> None: """Determine the state of the group chat and take action accordingly.""" # User input state @@ -458,14 +449,20 @@ async def _prepare( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]], ) -> None: """Register the actors and orchestrations with the runtime and add the required subscriptions.""" - await self._register_members(runtime, internal_topic_type) - await self._register_manager(runtime, internal_topic_type, result_callback=result_callback) + await self._register_members(runtime, internal_topic_type, exception_callback) + await self._register_manager(runtime, internal_topic_type, exception_callback, result_callback) await self._add_subscriptions(runtime, internal_topic_type) - async def _register_members(self, runtime: CoreRuntime, internal_topic_type: str) -> None: + async def _register_members( + self, + runtime: CoreRuntime, + internal_topic_type: str, + exception_callback: Callable[[BaseException], None], + ) -> None: """Register the agents.""" await asyncio.gather(*[ GroupChatAgentActor.register( @@ -474,6 +471,7 @@ async def _register_members(self, runtime: CoreRuntime, internal_topic_type: str lambda agent=agent: GroupChatAgentActor( # type: ignore[misc] agent, internal_topic_type, + exception_callback=exception_callback, agent_response_callback=self._agent_response_callback, streaming_agent_response_callback=self._streaming_agent_response_callback, ), @@ -485,6 +483,7 @@ async def _register_manager( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]] | None = None, ) -> None: """Register the group chat manager.""" @@ -495,6 +494,7 @@ async def _register_manager( self._manager, internal_topic_type=internal_topic_type, participant_descriptions={agent.name: agent.description for agent in self._members}, # type: ignore[misc] + exception_callback=exception_callback, result_callback=result_callback, ), ) diff --git a/python/semantic_kernel/agents/orchestration/handoffs.py b/python/semantic_kernel/agents/orchestration/handoffs.py index c680f2815242..569c0598188a 100644 --- a/python/semantic_kernel/agents/orchestration/handoffs.py +++ b/python/semantic_kernel/agents/orchestration/handoffs.py @@ -9,7 +9,7 @@ from functools import partial from semantic_kernel.agents.agent import Agent -from semantic_kernel.agents.orchestration.agent_actor_base import AgentActorBase +from semantic_kernel.agents.orchestration.agent_actor_base import ActorBase, AgentActorBase from semantic_kernel.agents.orchestration.orchestration_base import DefaultTypeAlias, OrchestrationBase, TIn, TOut from semantic_kernel.agents.runtime.core.cancellation_token import CancellationToken from semantic_kernel.agents.runtime.core.core_runtime import CoreRuntime @@ -161,6 +161,7 @@ def __init__( agent: Agent, internal_topic_type: str, handoff_connections: AgentHandoffs, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]] | None = None, agent_response_callback: Callable[[DefaultTypeAlias], Awaitable[None] | None] | None = None, streaming_agent_response_callback: Callable[[StreamingChatMessageContent, bool], Awaitable[None] | None] @@ -181,6 +182,7 @@ def __init__( super().__init__( agent=agent, internal_topic_type=internal_topic_type, + exception_callback=exception_callback, agent_response_callback=agent_response_callback, streaming_agent_response_callback=streaming_agent_response_callback, ) @@ -250,16 +252,10 @@ async def _complete_task(self, task_summary: str) -> None: async def _handle_start_message(self, message: HandoffStartMessage, cts: MessageContext) -> None: logger.debug(f"{self.id}: Received handoff start message.") if isinstance(message.body, ChatMessageContent): - if self._agent_thread: - await self._agent_thread.on_new_message(message.body) - else: - self._chat_history.add_message(message.body) + self._message_cache.add_message(message.body) elif isinstance(message.body, list) and all(isinstance(m, ChatMessageContent) for m in message.body): for m in message.body: - if self._agent_thread: - await self._agent_thread.on_new_message(m) - else: - self._chat_history.add_message(m) + self._message_cache.add_message(m) else: raise ValueError(f"Invalid message body type: {type(message.body)}. Expected {DefaultTypeAlias}.") @@ -267,10 +263,7 @@ async def _handle_start_message(self, message: HandoffStartMessage, cts: Message async def _handle_response_message(self, message: HandoffResponseMessage, cts: MessageContext) -> None: """Handle a response message from an agent in the handoff group.""" logger.debug(f"{self.id}: Received handoff response message.") - if self._agent_thread is not None: - await self._agent_thread.on_new_message(message.body) - else: - self._chat_history.add_message(message.body) + self._message_cache.add_message(message.body) @message_handler async def _handle_request_message(self, message: HandoffRequestMessage, cts: MessageContext) -> None: @@ -325,6 +318,7 @@ async def _call_human_response_function(self) -> ChatMessageContent: return await self._human_response_function() return self._human_response_function() # type: ignore[return-value] + @ActorBase.exception_handler async def _invoke_agent_with_potentially_no_response( self, additional_messages: DefaultTypeAlias | None = None, **kwargs ) -> ChatMessageContent | None: @@ -462,16 +456,18 @@ async def _prepare( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]], ) -> None: """Register the actors and orchestrations with the runtime and add the required subscriptions.""" - await self._register_members(runtime, internal_topic_type, result_callback) + await self._register_members(runtime, internal_topic_type, exception_callback, result_callback) await self._add_subscriptions(runtime, internal_topic_type) async def _register_members( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]] | None = None, ) -> None: """Register the members with the runtime.""" @@ -485,6 +481,7 @@ async def _register_helper(agent: Agent) -> None: agent, internal_topic_type, handoff_connections, + exception_callback, result_callback=result_callback, agent_response_callback=self._agent_response_callback, streaming_agent_response_callback=self._streaming_agent_response_callback, diff --git a/python/semantic_kernel/agents/orchestration/magentic.py b/python/semantic_kernel/agents/orchestration/magentic.py index c644da66ae46..45b8e29ec89c 100644 --- a/python/semantic_kernel/agents/orchestration/magentic.py +++ b/python/semantic_kernel/agents/orchestration/magentic.py @@ -477,6 +477,7 @@ def __init__( manager: MagenticManagerBase, internal_topic_type: str, participant_descriptions: dict[str, str], + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]] | None = None, ) -> None: """Initialize the Magentic One manager actor. @@ -485,6 +486,7 @@ def __init__( manager (MagenticManagerBase): The Magentic One manager. internal_topic_type (str): The internal topic type. participant_descriptions (dict[str, str]): The participant descriptions. + exception_callback (Callable[[BaseException], None]): A callback function to handle exceptions. result_callback (Callable | None): A callback function to handle the final answer. """ self._manager = manager @@ -494,9 +496,10 @@ def __init__( self._context: MagenticContext | None = None self._task_ledger: ChatMessageContent | None = None - super().__init__(description="Magentic One Manager") + super().__init__("Magentic One Manager", exception_callback) @message_handler + @ActorBase.exception_handler async def _handle_start_message(self, message: MagenticStartMessage, ctx: MessageContext) -> None: """Handle the start message for the Magentic One manager.""" logger.debug(f"{self.id}: Received Magentic One start message.") @@ -512,6 +515,7 @@ async def _handle_start_message(self, message: MagenticStartMessage, ctx: Messag await self._run_outer_loop(ctx.cancellation_token) @message_handler + @ActorBase.exception_handler async def _handle_response_message(self, message: MagenticResponseMessage, ctx: MessageContext) -> None: """Handle the response message for the Magentic One manager.""" if self._context is None or self._task_ledger is None: @@ -647,19 +651,32 @@ async def _check_within_limits(self) -> bool: if self._context is None: raise RuntimeError("The Magentic manager is not started yet. Make sure to send a start message first.") - if ( + hit_round_limit = ( self._manager.max_round_count is not None and self._context.round_count >= self._manager.max_round_count - ) or (self._manager.max_reset_count is not None and self._context.reset_count > self._manager.max_reset_count): - message = ( - "Max round count reached." - if self._manager.max_round_count and self._context.round_count >= self._manager.max_round_count - else "Max reset count reached." + ) + hit_reset_limit = ( + self._manager.max_reset_count is not None and self._context.reset_count > self._manager.max_reset_count + ) + + if hit_round_limit or hit_reset_limit: + limit_type = "round" if hit_round_limit else "reset" + logger.debug(f"Max {limit_type} count reached.") + + # Retrieve the latest assistant content produced so far + partial_result = next( + (m for m in reversed(self._context.chat_history.messages) if m.role == AuthorRole.ASSISTANT), + None, ) - logger.debug(message) - if self._result_callback: - await self._result_callback( - ChatMessageContent(role=AuthorRole.ASSISTANT, content=message, name=self.__class__.__name__) + if partial_result is None: + partial_result = ChatMessageContent( + role=AuthorRole.ASSISTANT, + content=f"Stopped because the maximum {limit_type} limit was reached. No partial result available.", + name=self.__class__.__name__, ) + + if self._result_callback: + await self._result_callback(partial_result) + return False return True @@ -677,10 +694,7 @@ class MagenticAgentActor(AgentActorBase): @message_handler async def _handle_response_message(self, message: MagenticResponseMessage, ctx: MessageContext) -> None: logger.debug(f"{self.id}: Received response message.") - if self._agent_thread is not None: - await self._agent_thread.on_new_message(message.body) - else: - self._chat_history.add_message(message.body) + self._message_cache.add_message(message.body) @message_handler async def _handle_request_message(self, message: MagenticRequestMessage, ctx: MessageContext) -> None: @@ -703,7 +717,7 @@ async def _handle_request_message(self, message: MagenticRequestMessage, ctx: Me async def _handle_reset_message(self, message: MagenticResetMessage, ctx: MessageContext) -> None: """Handle the reset message for the Magentic One group chat.""" logger.debug(f"{self.id}: Received reset message.") - self._chat_history.clear() + self._message_cache.clear() if self._agent_thread: await self._agent_thread.delete() self._agent_thread = None @@ -785,14 +799,20 @@ async def _prepare( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]], ) -> None: """Register the actors and orchestrations with the runtime and add the required subscriptions.""" - await self._register_members(runtime, internal_topic_type) - await self._register_manager(runtime, internal_topic_type, result_callback=result_callback) + await self._register_members(runtime, internal_topic_type, exception_callback) + await self._register_manager(runtime, internal_topic_type, exception_callback, result_callback=result_callback) await self._add_subscriptions(runtime, internal_topic_type) - async def _register_members(self, runtime: CoreRuntime, internal_topic_type: str) -> None: + async def _register_members( + self, + runtime: CoreRuntime, + internal_topic_type: str, + exception_callback: Callable[[BaseException], None], + ) -> None: """Register the agents.""" await asyncio.gather(*[ MagenticAgentActor.register( @@ -801,6 +821,7 @@ async def _register_members(self, runtime: CoreRuntime, internal_topic_type: str lambda agent=agent: MagenticAgentActor( # type: ignore[misc] agent, internal_topic_type, + exception_callback, self._agent_response_callback, self._streaming_agent_response_callback, ), @@ -812,6 +833,7 @@ async def _register_manager( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]] | None = None, ) -> None: """Register the group chat manager.""" @@ -822,6 +844,7 @@ async def _register_manager( self._manager, internal_topic_type=internal_topic_type, participant_descriptions={agent.name: agent.description for agent in self._members}, # type: ignore[misc] + exception_callback=exception_callback, result_callback=result_callback, ), ) diff --git a/python/semantic_kernel/agents/orchestration/orchestration_base.py b/python/semantic_kernel/agents/orchestration/orchestration_base.py index ed547828ba0f..8ff4f1938b71 100644 --- a/python/semantic_kernel/agents/orchestration/orchestration_base.py +++ b/python/semantic_kernel/agents/orchestration/orchestration_base.py @@ -34,6 +34,7 @@ class OrchestrationResult(KernelBaseModel, Generic[TOut]): """The result of an invocation of an orchestration.""" + background_task: asyncio.Task | None = None value: TOut | None = None exception: BaseException | None = None event: asyncio.Event = Field(default_factory=asyncio.Event) @@ -205,6 +206,11 @@ async def result_callback(result: DefaultTypeAlias) -> None: orchestration_result.value = transformed_result orchestration_result.event.set() + def inner_exception_callback(exception: BaseException) -> None: + nonlocal orchestration_result + orchestration_result.exception = exception + orchestration_result.event.set() + # This unique topic type is used to isolate the orchestration run from others. internal_topic_type = uuid.uuid4().hex @@ -212,6 +218,7 @@ async def result_callback(result: DefaultTypeAlias) -> None: runtime, internal_topic_type=internal_topic_type, result_callback=result_callback, + exception_callback=inner_exception_callback, ) if isinstance(task, str): @@ -235,8 +242,8 @@ async def result_callback(result: DefaultTypeAlias) -> None: ) ) - # Add a callback to surface any exceptions that occur during the task execution. - def exception_callback(task: asyncio.Task) -> None: + # Add a callback to surface any exceptions that occur during outside of the runtime. + def outer_exception_callback(task: asyncio.Task) -> None: nonlocal orchestration_result try: task.result() @@ -244,7 +251,8 @@ def exception_callback(task: asyncio.Task) -> None: orchestration_result.exception = e orchestration_result.event.set() - background_task.add_done_callback(exception_callback) + background_task.add_done_callback(outer_exception_callback) + orchestration_result.background_task = background_task return orchestration_result @@ -271,6 +279,7 @@ async def _prepare( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]], ) -> None: """Register the actors and orchestrations with the runtime and add the required subscriptions. @@ -278,8 +287,7 @@ async def _prepare( Args: runtime (CoreRuntime): The runtime environment for the agents. internal_topic_type (str): The internal topic type for the orchestration that this actor is part of. - external_topic_type (str | None): The external topic type for the orchestration. - direct_actor_type (str | None): The direct actor type for which this actor will relay the output message to. + exception_callback (Callable): A function that is called when an exception occurs. result_callback (Callable): A function that is called when the result is available. """ pass diff --git a/python/semantic_kernel/agents/orchestration/sequential.py b/python/semantic_kernel/agents/orchestration/sequential.py index d04369d88c02..c28fec0d4812 100644 --- a/python/semantic_kernel/agents/orchestration/sequential.py +++ b/python/semantic_kernel/agents/orchestration/sequential.py @@ -48,6 +48,7 @@ def __init__( agent: Agent, internal_topic_type: str, next_agent_type: str, + exception_callback: Callable[[BaseException], None], agent_response_callback: Callable[[DefaultTypeAlias], Awaitable[None] | None] | None = None, streaming_agent_response_callback: Callable[[StreamingChatMessageContent, bool], Awaitable[None] | None] | None = None, @@ -57,6 +58,7 @@ def __init__( super().__init__( agent=agent, internal_topic_type=internal_topic_type, + exception_callback=exception_callback, agent_response_callback=agent_response_callback, streaming_agent_response_callback=streaming_agent_response_callback, ) @@ -85,12 +87,13 @@ class CollectionActor(ActorBase): def __init__( self, description: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]], ) -> None: """Initialize the collection actor.""" self._result_callback = result_callback - super().__init__(description=description) + super().__init__(description, exception_callback) @message_handler async def _handle_message(self, message: SequentialRequestMessage, _: MessageContext) -> None: @@ -123,16 +126,18 @@ async def _prepare( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]], ) -> None: """Register the actors and orchestrations with the runtime and add the required subscriptions.""" - await self._register_members(runtime, internal_topic_type) - await self._register_collection_actor(runtime, internal_topic_type, result_callback) + await self._register_members(runtime, internal_topic_type, exception_callback) + await self._register_collection_actor(runtime, internal_topic_type, exception_callback, result_callback) async def _register_members( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], ) -> None: """Register the members. @@ -143,6 +148,7 @@ async def _register_members( Args: runtime (CoreRuntime): The agent runtime. internal_topic_type (str): The internal topic type for the orchestration that this actor is part of. + exception_callback (Callable[[BaseException], None]): A callback function to handle exceptions. Returns: str: The first actor type in the sequence. @@ -156,6 +162,7 @@ async def _register_members( agent, internal_topic_type, next_agent_type=next_actor_type, + exception_callback=exception_callback, agent_response_callback=self._agent_response_callback, streaming_agent_response_callback=self._streaming_agent_response_callback, ), @@ -167,6 +174,7 @@ async def _register_collection_actor( self, runtime: CoreRuntime, internal_topic_type: str, + exception_callback: Callable[[BaseException], None], result_callback: Callable[[DefaultTypeAlias], Awaitable[None]], ) -> None: """Register the collection actor.""" @@ -175,6 +183,7 @@ async def _register_collection_actor( self._get_collection_actor_type(internal_topic_type), lambda: CollectionActor( description="An internal agent that is responsible for collection results", + exception_callback=exception_callback, result_callback=result_callback, ), ) diff --git a/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_text_to_image_execution_settings.py b/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_text_to_image_execution_settings.py index f33c92ad6cb8..3c4d45122940 100644 --- a/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_text_to_image_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_text_to_image_execution_settings.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import logging -from typing import Any +from typing import Any, Literal from pydantic import Field, model_validator @@ -40,6 +40,10 @@ class OpenAITextToImageExecutionSettings(PromptExecutionSettings): size: ImageSize | None = None quality: str | None = None style: str | None = None + output_compression: int | None = None + background: Literal["transparent", "opaque", "auto"] | None = None + n: int | None = Field(default=1, ge=1, le=10) + moderation: Literal["auto", "low"] | None = None @model_validator(mode="before") @classmethod diff --git a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_handler.py b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_handler.py index 081a67b07ad0..59af2d709dc0 100644 --- a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_handler.py +++ b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_handler.py @@ -5,6 +5,7 @@ from typing import Any, Union from openai import AsyncOpenAI, AsyncStream, BadRequestError, _legacy_response +from openai._types import NOT_GIVEN, FileTypes, NotGiven from openai.lib._parsing._completions import type_to_response_format_param from openai.types import Completion, CreateEmbeddingResponse from openai.types.audio import Transcription @@ -122,12 +123,41 @@ async def _send_embedding_request(self, settings: OpenAIEmbeddingPromptExecution async def _send_text_to_image_request(self, settings: OpenAITextToImageExecutionSettings) -> ImagesResponse: """Send a request to the OpenAI text to image endpoint.""" try: - return await self.client.images.generate( + response: ImagesResponse = await self.client.images.generate( **settings.prepare_settings_dict(), ) + self.store_usage(response) + return response except Exception as ex: raise ServiceResponseException(f"Failed to generate image: {ex}") from ex + async def _send_image_edit_request( + self, + image: list[FileTypes], + settings: OpenAITextToImageExecutionSettings, + mask: FileTypes | NotGiven = NOT_GIVEN, + ) -> ImagesResponse: + """Send a request to the OpenAI image edit endpoint. + + Args: + image: List of image files to edit. Accepts file paths or bytes. + settings: Image edit execution settings. + mask: Optional mask image. Accepts file path or bytes. + + Returns: + ImagesResponse: The response from the image edit API. + """ + try: + response: ImagesResponse = await self.client.images.edit( + image=image, + mask=mask, + **settings.prepare_settings_dict(), + ) + self.store_usage(response) + return response + except Exception as ex: + raise ServiceResponseException(f"Failed to edit image: {ex}") from ex + async def _send_audio_to_text_request(self, settings: OpenAIAudioToTextExecutionSettings) -> Transcription: """Send a request to the OpenAI audio to text endpoint.""" if not settings.filename: @@ -187,12 +217,19 @@ def store_usage( | Completion | AsyncStream[ChatCompletionChunk] | AsyncStream[Completion] - | CreateEmbeddingResponse, + | CreateEmbeddingResponse + | ImagesResponse, ): """Store the usage information from the response.""" - if not isinstance(response, AsyncStream) and response.usage: + if isinstance(response, ImagesResponse) and hasattr(response, "usage") and response.usage: + logger.info(f"OpenAI image usage: {response.usage}") + self.prompt_tokens += response.usage.input_tokens + self.total_tokens += response.usage.total_tokens + self.completion_tokens += response.usage.output_tokens + return + if not isinstance(response, AsyncStream) and not isinstance(response, ImagesResponse) and response.usage: logger.info(f"OpenAI usage: {response.usage}") self.prompt_tokens += response.usage.prompt_tokens self.total_tokens += response.usage.total_tokens if hasattr(response.usage, "completion_tokens"): - self.completion_tokens += response.usage.completion_tokens + self.completion_tokens += response.usage.completion_tokens # type: ignore diff --git a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_text_to_image_base.py b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_text_to_image_base.py index 1cfef29d087a..a8ba3977ca63 100644 --- a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_text_to_image_base.py +++ b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_text_to_image_base.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any +from pathlib import Path +from typing import IO, Any from warnings import warn +from openai._types import NOT_GIVEN, FileTypes, NotGiven from openai.types.images_response import ImagesResponse from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_text_to_image_execution_settings import ( @@ -38,6 +40,7 @@ async def generate_image( Returns: bytes | str: Image bytes or image URL. """ + warn("generate_image is deprecated. Use generate_images.", DeprecationWarning, stacklevel=2) if not settings: settings = OpenAITextToImageExecutionSettings(**kwargs) if not isinstance(settings, OpenAITextToImageExecutionSettings): @@ -70,6 +73,177 @@ async def generate_image( return response.data[0].url + async def generate_images( + self, + prompt: str, + settings: PromptExecutionSettings | None = None, + **kwargs: Any, + ) -> list[str]: + """Generate one or more images from text. Returns URLs or base64-encoded images. + + Args: + prompt: Description of the image(s) to generate. + settings: Execution settings for the prompt. + kwargs: Additional arguments, check the openai images.generate documentation for the supported arguments. + + Returns: + list[str]: Image URLs or base64-encoded images. + + Example: + Generate images and save them as PNG files: + + ```python + from semantic_kernel.connectors.ai.open_ai import AzureTextToImage + import base64, os + + service = AzureTextToImage( + service_id="image1", + deployment_name="gpt-image-1", + endpoint="https://your-endpoint.cognitiveservices.azure.com", + api_key="your-api-key", + api_version="2025-04-01-preview", + ) + settings = service.get_prompt_execution_settings_class()(service_id="image1") + settings.n = 3 + images_b64 = await service.generate_images("A cute cat wearing a whimsical striped hat", settings=settings) + ``` + """ + if not settings: + settings = OpenAITextToImageExecutionSettings(**kwargs) + if not isinstance(settings, OpenAITextToImageExecutionSettings): + settings = OpenAITextToImageExecutionSettings.from_prompt_execution_settings(settings) + if prompt: + settings.prompt = prompt + + if not settings.prompt: + raise ServiceInvalidRequestError("Prompt is required.") + + if not settings.ai_model_id: + settings.ai_model_id = self.ai_model_id + + response = await self._send_request(settings) + + assert isinstance(response, ImagesResponse) # nosec + if not response.data or not isinstance(response.data, list) or len(response.data) == 0: + raise ServiceResponseException("Failed to generate image.") + + results: list[str] = [] + for image in response.data: + url: str | None = getattr(image, "url", None) + b64_json: str | None = getattr(image, "b64_json", None) + if url: + results.append(url) + elif b64_json: + results.append(b64_json) + else: + continue + + if len(results) == 0: + raise ServiceResponseException("No valid image data found in response.") + return results + + async def edit_image( + self, + prompt: str, + image_paths: list[str] | None = None, + image_files: list[IO[bytes]] | None = None, + mask_path: str | None = None, + mask_file: IO[bytes] | None = None, + settings: PromptExecutionSettings | None = None, + **kwargs: Any, + ) -> list[str]: + """Edit images using the OpenAI image edit API. + + Args: + prompt: Instructional prompt for image editing. + image_paths: List of image file paths to edit. + image_files: List of file-like objects (opened in binary mode) to edit. + mask_path: Optional mask image file path. + mask_file: Optional mask image file-like object (opened in binary mode). + settings: Optional execution settings. If not provided, will be constructed from kwargs. + kwargs: Additional API parameters. + + Returns: + list[str]: List of edited image URLs or base64-encoded strings. + + Example: + Edit images from file path and save results: + + ```python + from semantic_kernel.connectors.ai.open_ai import AzureTextToImage + import base64, os + + service = AzureTextToImage( + service_id="image1", + deployment_name="gpt-image-1", + endpoint="https://your-endpoint.cognitiveservices.azure.com", + api_key="your-api-key", + api_version="2025-04-01-preview", + ) + file_paths = ["./new_images/img_1.png", "./new_images/img_2.png"] + settings = service.get_prompt_execution_settings_class()(service_id="image1") + settings.n = 2 + results = await service.edit_image( + prompt="Make the cat wear a wizard hat", + image_paths=file_paths, + settings=settings, + ) + ``` + + Edit images from file object: + + ```python + with open("./new_images/img_1.png", "rb") as f: + results = await service.edit_image( + prompt="Make the cat wear a wizard hat", + image_files=[f], + ) + ``` + """ + if not settings: + settings = OpenAITextToImageExecutionSettings(**kwargs) + if not isinstance(settings, OpenAITextToImageExecutionSettings): + settings = OpenAITextToImageExecutionSettings.from_prompt_execution_settings(settings) + settings.prompt = prompt + + if not settings.prompt: + raise ServiceInvalidRequestError("Prompt is required.") + if (image_paths is None and image_files is None) or (image_paths is not None and image_files is not None): + raise ServiceInvalidRequestError("Provide either 'image_paths' or 'image_files', and only one.") + + images: list[FileTypes] = [] + if image_paths is not None: + images = [Path(p) for p in image_paths] + elif image_files is not None: + images = list(image_files) + + mask: FileTypes | NotGiven = NOT_GIVEN + if mask_path is not None: + mask = Path(mask_path) + elif mask_file is not None: + mask = mask_file + + response: ImagesResponse = await self._send_image_edit_request( + image=images, + mask=mask, + settings=settings, + ) + + if not response or not response.data or not isinstance(response.data, list): + raise ServiceResponseException("Failed to edit image.") + + results: list[str] = [] + for img in response.data: + b64_json: str | None = getattr(img, "b64_json", None) + url: str | None = getattr(img, "url", None) + if b64_json: + results.append(b64_json) + elif url: + results.append(url) + if not results: + raise ServiceResponseException("No valid image data found in response.") + return results + def get_prompt_execution_settings_class(self) -> type[PromptExecutionSettings]: """Get the request settings class.""" return OpenAITextToImageExecutionSettings diff --git a/python/tests/integration/agents/openai_assistant_agent/test_openai_assistant_agent_integration.py b/python/tests/integration/agents/openai_assistant_agent/test_openai_assistant_agent_integration.py index 7ffbde5ce70e..dd8eda10a551 100644 --- a/python/tests/integration/agents/openai_assistant_agent/test_openai_assistant_agent_integration.py +++ b/python/tests/integration/agents/openai_assistant_agent/test_openai_assistant_agent_integration.py @@ -193,8 +193,12 @@ async def test_invoke_stream_with_thread( @pytest.mark.parametrize( "assistant_agent", [ - ("azure", {"enable_code_interpreter": True}), - ("openai", {"enable_code_interpreter": True}), + pytest.param( + ("azure", {"enable_code_interpreter": True}), marks=pytest.mark.xfail(reason="Service outage") + ), + pytest.param( + ("openai", {"enable_code_interpreter": True}), + ), ], indirect=["assistant_agent"], ids=["azure-code-interpreter", "openai-code-interpreter"], @@ -219,8 +223,12 @@ async def test_code_interpreter_get_response( @pytest.mark.parametrize( "assistant_agent", [ - ("azure", {"enable_code_interpreter": True}), - ("openai", {"enable_code_interpreter": True}), + pytest.param( + ("azure", {"enable_code_interpreter": True}), marks=pytest.mark.xfail(reason="Service outage") + ), + pytest.param( + ("openai", {"enable_code_interpreter": True}), + ), ], indirect=["assistant_agent"], ids=["azure-code-interpreter", "openai-code-interpreter"], @@ -245,8 +253,12 @@ async def test_code_interpreter_invoke(self, assistant_agent: OpenAIAssistantAge @pytest.mark.parametrize( "assistant_agent", [ - ("azure", {"enable_code_interpreter": True}), - ("openai", {"enable_code_interpreter": True}), + pytest.param( + ("azure", {"enable_code_interpreter": True}), marks=pytest.mark.xfail(reason="Service outage") + ), + pytest.param( + ("openai", {"enable_code_interpreter": True}), + ), ], indirect=["assistant_agent"], ids=["azure-code-interpreter", "openai-code-interpreter"], diff --git a/python/tests/unit/agents/orchestration/conftest.py b/python/tests/unit/agents/orchestration/conftest.py index 6a286f4aaed5..8035d5bf469b 100644 --- a/python/tests/unit/agents/orchestration/conftest.py +++ b/python/tests/unit/agents/orchestration/conftest.py @@ -90,6 +90,35 @@ async def invoke_stream( ) +class MockAgentWithException(MockAgent): + """A mock agent that raises an exception for testing purposes.""" + + @override + async def invoke_stream( + self, + messages: str | ChatMessageContent | list[str | ChatMessageContent] | None = None, + *, + thread: AgentThread | None = None, + on_intermediate_message: Callable[[ChatMessageContent], Awaitable[None]] | None = None, + **kwargs, + ) -> AsyncIterable[AgentResponseItem[StreamingChatMessageContent]]: + """Simulate streaming response from the agent that raises an exception.""" + # Simulate some processing time + await asyncio.sleep(0.05) + + yield AgentResponseItem[StreamingChatMessageContent]( + message=StreamingChatMessageContent( + role=AuthorRole.ASSISTANT, + name=self.name, + content="mock", + choice_index=0, + ), + thread=thread or MockAgentThread(), + ) + + raise RuntimeError("Mock agent exception") + + class MockRuntime(CoreRuntime): """A mock agent runtime for testing purposes.""" diff --git a/python/tests/unit/agents/orchestration/test_concurrent.py b/python/tests/unit/agents/orchestration/test_concurrent.py index 9c1d96eaac28..75e04250417b 100644 --- a/python/tests/unit/agents/orchestration/test_concurrent.py +++ b/python/tests/unit/agents/orchestration/test_concurrent.py @@ -11,7 +11,7 @@ from semantic_kernel.agents.runtime.in_process.in_process_runtime import InProcessRuntime from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent -from tests.unit.agents.orchestration.conftest import MockAgent, MockRuntime +from tests.unit.agents.orchestration.conftest import MockAgent, MockAgentWithException, MockRuntime async def test_prepare(): @@ -183,3 +183,22 @@ async def test_invoke_with_double_get_result(): assert len(result) == 2 finally: await runtime.stop_when_idle() + + +async def test_invoke_with_agent_raising_exception(): + """Test the invoke method of the ConcurrentOrchestration with an agent raising an exception.""" + agent_a = MockAgent() + agent_b = MockAgentWithException() + + runtime = InProcessRuntime() + runtime.start() + + try: + orchestration = ConcurrentOrchestration(members=[agent_a, agent_b]) + orchestration_result = await orchestration.invoke(task="test_message", runtime=runtime) + + with pytest.raises(RuntimeError, match="Mock agent exception"): + await orchestration_result.get(1.0) + assert orchestration_result.exception is not None + finally: + await runtime.stop_when_idle() diff --git a/python/tests/unit/agents/orchestration/test_group_chat.py b/python/tests/unit/agents/orchestration/test_group_chat.py index 54ab1746b68f..b217f35e36c5 100644 --- a/python/tests/unit/agents/orchestration/test_group_chat.py +++ b/python/tests/unit/agents/orchestration/test_group_chat.py @@ -17,7 +17,7 @@ from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent from semantic_kernel.contents.utils.author_role import AuthorRole -from tests.unit.agents.orchestration.conftest import MockAgent, MockRuntime +from tests.unit.agents.orchestration.conftest import MockAgent, MockAgentWithException, MockRuntime if sys.version_info >= (3, 12): from typing import override # pragma: no cover @@ -289,6 +289,29 @@ async def test_invoke_cancel_after_completion(): await runtime.stop_when_idle() +async def test_invoke_with_agent_raising_exception(): + """Test the invoke method of the GroupChatOrchestration with an agent raising an exception.""" + agent_a = MockAgent(description="test agent") + agent_b = MockAgentWithException(description="test agent") + + runtime = InProcessRuntime() + runtime.start() + + try: + orchestration = GroupChatOrchestration( + members=[agent_a, agent_b], + manager=RoundRobinGroupChatManager(max_rounds=3), + ) + + orchestration_result = await orchestration.invoke(task="test_message", runtime=runtime) + + with pytest.raises(RuntimeError, match="Mock agent exception"): + await orchestration_result.get(1.0) + assert orchestration_result.exception is not None + finally: + await runtime.stop_when_idle() + + # endregion GroupChatOrchestration # region RoundRobinGroupChatManager diff --git a/python/tests/unit/agents/orchestration/test_handoff.py b/python/tests/unit/agents/orchestration/test_handoff.py index 6b4ccc887671..4797fc8bd792 100644 --- a/python/tests/unit/agents/orchestration/test_handoff.py +++ b/python/tests/unit/agents/orchestration/test_handoff.py @@ -23,7 +23,7 @@ from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent from semantic_kernel.contents.utils.author_role import AuthorRole from semantic_kernel.kernel import Kernel -from tests.unit.agents.orchestration.conftest import MockAgent, MockRuntime +from tests.unit.agents.orchestration.conftest import MockAgent, MockAgentWithException, MockRuntime if sys.version_info >= (3, 12): from typing import override # pragma: no cover @@ -695,4 +695,27 @@ async def test_invoke_cancel_after_completion(): await runtime.stop_when_idle() +async def test_invoke_with_agent_raising_exception(): + """Test the invoke method of the HandoffOrchestration with an agent raising an exception.""" + agent_a = MockAgentWithException() + agent_b = MockAgent() + + runtime = InProcessRuntime() + runtime.start() + + try: + orchestration = HandoffOrchestration( + members=[agent_a, agent_b], + handoffs={agent_a.name: {agent_b.name: "test"}}, + ) + + orchestration_result = await orchestration.invoke(task="test_message", runtime=runtime) + + with pytest.raises(RuntimeError, match="Mock agent exception"): + await orchestration_result.get(1.0) + assert orchestration_result.exception is not None + finally: + await runtime.stop_when_idle() + + # endregion diff --git a/python/tests/unit/agents/orchestration/test_magentic.py b/python/tests/unit/agents/orchestration/test_magentic.py index 6525303bc02c..ce7a4ec4bc8a 100644 --- a/python/tests/unit/agents/orchestration/test_magentic.py +++ b/python/tests/unit/agents/orchestration/test_magentic.py @@ -31,7 +31,7 @@ from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent from semantic_kernel.contents.utils.author_role import AuthorRole -from tests.unit.agents.orchestration.conftest import MockAgent, MockRuntime +from tests.unit.agents.orchestration.conftest import MockAgent, MockAgentWithException, MockRuntime class MockChatCompletionService(ChatCompletionClientBase): @@ -245,6 +245,41 @@ async def test_invoke_with_list_error(): await orchestration_result.get(1.0) +async def test_invoke_with_agent_raising_exception(): + """Test the invoke method of the MagenticOrchestration with a list of messages which raises an error.""" + with ( + patch.object( + MockChatCompletionService, "get_chat_message_content", new_callable=AsyncMock + ) as mock_get_chat_message_content, + patch.object( + StandardMagenticManager, "create_progress_ledger", new_callable=AsyncMock, side_effect=ManagerProgressList + ), + ): + mock_get_chat_message_content.return_value = ChatMessageContent(role="assistant", content="mock_response") + chat_completion_service = MockChatCompletionService(ai_model_id="test") + prompt_execution_settings = MockPromptExecutionSettings() + + manager = StandardMagenticManager( + chat_completion_service=chat_completion_service, + prompt_execution_settings=prompt_execution_settings, + ) + + agent_a = MockAgentWithException(name="agent_a", description="test agent") + agent_b = MockAgent(name="agent_b", description="test agent") + + runtime = InProcessRuntime() + runtime.start() + + orchestration = MagenticOrchestration(members=[agent_a, agent_b], manager=manager) + try: + orchestration_result = await orchestration.invoke(task="test_message", runtime=runtime) + with pytest.raises(RuntimeError, match="Mock agent exception"): + await orchestration_result.get(1.0) + assert orchestration_result.exception is not None + finally: + await runtime.stop_when_idle() + + @pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.10 doesn't bound the original function provided to the wraps argument of the patch object.", @@ -425,7 +460,8 @@ async def test_invoke_with_max_round_count_exceeded(): finally: await runtime.stop_when_idle() - assert result.content == "Max round count reached." + # Partial result will be returned when max round count is exceeded. + assert result.content == mock_get_chat_message_content.return_value.content assert mock_invoke_stream.call_count == 1 # Planning will be called once, so the facts and plan will be created once. assert mock_get_chat_message_content.call_count == 2 @@ -472,7 +508,9 @@ async def test_invoke_with_max_reset_count_exceeded(): finally: await runtime.stop_when_idle() - assert result.content == "Max reset count reached." + # Partial result will be returned when max reset count is exceeded. The test emits content based on the prompt + # so check that the content is not None and not an exact match to a mock response. + assert result.content is not None assert mock_invoke_stream.call_count == 1 # Planning and replanning will be each called once, so the facts and plan will be created twice. assert mock_get_chat_message_content.call_count == 4 diff --git a/python/tests/unit/agents/orchestration/test_orchestration_base.py b/python/tests/unit/agents/orchestration/test_orchestration_base.py index 803e9def9195..1a65a1348bdf 100644 --- a/python/tests/unit/agents/orchestration/test_orchestration_base.py +++ b/python/tests/unit/agents/orchestration/test_orchestration_base.py @@ -65,7 +65,7 @@ class MockOrchestration(OrchestrationBase[TIn, TOut]): async def _start(self, task, runtime, internal_topic_type, collection_agent_type): pass - async def _prepare(self, runtime, internal_topic_type, result_callback): + async def _prepare(self, runtime, internal_topic_type, exception_callback, result_callback): pass diff --git a/python/tests/unit/agents/orchestration/test_sequential.py b/python/tests/unit/agents/orchestration/test_sequential.py index 46aef154252b..5883e9d2c0e4 100644 --- a/python/tests/unit/agents/orchestration/test_sequential.py +++ b/python/tests/unit/agents/orchestration/test_sequential.py @@ -11,7 +11,7 @@ from semantic_kernel.agents.runtime.in_process.in_process_runtime import InProcessRuntime from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent -from tests.unit.agents.orchestration.conftest import MockAgent, MockRuntime +from tests.unit.agents.orchestration.conftest import MockAgent, MockAgentWithException, MockRuntime async def test_prepare(): @@ -181,3 +181,22 @@ async def test_invoke_with_double_get_result(): assert result.content == "mock_response" finally: await runtime.stop_when_idle() + + +async def test_invoke_with_agent_raising_exception(): + """Test the invoke method of the SequentialOrchestration with an agent raising an exception.""" + agent_a = MockAgent() + agent_b = MockAgentWithException() + + runtime = InProcessRuntime() + runtime.start() + + try: + orchestration = SequentialOrchestration(members=[agent_a, agent_b]) + orchestration_result = await orchestration.invoke(task="test_message", runtime=runtime) + + with pytest.raises(RuntimeError, match="Mock agent exception"): + await orchestration_result.get(1.0) + assert orchestration_result.exception is not None + finally: + await runtime.stop_when_idle() diff --git a/python/tests/unit/connectors/ai/open_ai/services/test_azure_text_to_image.py b/python/tests/unit/connectors/ai/open_ai/services/test_azure_text_to_image.py index 4d462fcbf251..20e46f27fcc0 100644 --- a/python/tests/unit/connectors/ai/open_ai/services/test_azure_text_to_image.py +++ b/python/tests/unit/connectors/ai/open_ai/services/test_azure_text_to_image.py @@ -81,15 +81,19 @@ def test_azure_text_to_image_init_with_from_dict(azure_openai_unit_test_env) -> @patch.object(AsyncImages, "generate", return_value=AsyncMock(spec=ImagesResponse)) async def test_azure_text_to_image_calls_with_parameters(mock_generate, azure_openai_unit_test_env) -> None: mock_generate.return_value.data = [Image(url="abc")] + mock_generate.return_value.usage = None prompt = "A painting of a vase with flowers" width = 512 - azure_text_to_image = AzureTextToImage() + azure_text_to_image = AzureTextToImage( + deployment_name=azure_openai_unit_test_env["AZURE_OPENAI_TEXT_TO_IMAGE_DEPLOYMENT_NAME"] + ) await azure_text_to_image.generate_image(prompt, width=width, height=width) mock_generate.assert_awaited_once_with( prompt=prompt, model=azure_openai_unit_test_env["AZURE_OPENAI_TEXT_TO_IMAGE_DEPLOYMENT_NAME"], size=f"{width}x{width}", + n=1, ) diff --git a/python/tests/unit/connectors/ai/open_ai/services/test_openai_text_to_image.py b/python/tests/unit/connectors/ai/open_ai/services/test_openai_text_to_image.py index 09c61eecff90..18ff4b749d0f 100644 --- a/python/tests/unit/connectors/ai/open_ai/services/test_openai_text_to_image.py +++ b/python/tests/unit/connectors/ai/open_ai/services/test_openai_text_to_image.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. +import os import warnings from unittest.mock import AsyncMock, patch +import pydantic import pytest from openai import AsyncClient from openai.resources.images import AsyncImages @@ -10,15 +12,20 @@ from openai.types.images_response import ImagesResponse from semantic_kernel.connectors.ai.open_ai import OpenAITextToImage, OpenAITextToImageExecutionSettings +from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_to_image_base import OpenAITextToImageBase from semantic_kernel.exceptions.service_exceptions import ( ServiceInitializationError, ServiceInvalidExecutionSettingsError, + ServiceInvalidRequestError, ServiceResponseException, ) +sample_img = os.path.join(os.path.dirname(__file__), "../../../../../assets/sample_image.jpg") + def test_init(openai_unit_test_env): - openai_text_to_image = OpenAITextToImage() + """Test that OpenAITextToImage initializes with the correct model id and client.""" + openai_text_to_image = OpenAITextToImage(ai_model_id=openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]) assert openai_text_to_image.client is not None assert isinstance(openai_text_to_image.client, AsyncClient) @@ -26,11 +33,13 @@ def test_init(openai_unit_test_env): def test_init_validation_fail() -> None: + """Test that initialization fails when required parameters are missing.""" with pytest.raises(ServiceInitializationError): - OpenAITextToImage(api_key="34523", ai_model_id={"test": "dict"}) + OpenAITextToImage(api_key="34523", ai_model_id=None) def test_init_to_from_dict(openai_unit_test_env): + """Test to_dict and from_dict methods for correct serialization and deserialization.""" default_headers = {"X-Unit-Test": "test-guid"} settings = { @@ -46,6 +55,7 @@ def test_init_to_from_dict(openai_unit_test_env): @pytest.mark.parametrize("exclude_list", [["OPENAI_API_KEY"]], indirect=True) def test_init_with_empty_api_key(openai_unit_test_env) -> None: + """Test that initialization fails when API key is missing.""" with pytest.raises(ServiceInitializationError): OpenAITextToImage( env_file_path="test.env", @@ -54,6 +64,7 @@ def test_init_with_empty_api_key(openai_unit_test_env) -> None: @pytest.mark.parametrize("exclude_list", [["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]], indirect=True) def test_init_with_no_model_id(openai_unit_test_env) -> None: + """Test that initialization fails when model id is missing.""" with pytest.raises(ServiceInitializationError): OpenAITextToImage( env_file_path="test.env", @@ -61,13 +72,16 @@ def test_init_with_no_model_id(openai_unit_test_env) -> None: def test_prompt_execution_settings_class(openai_unit_test_env) -> None: + """Test that the correct prompt execution settings class is returned.""" openai_text_to_image = OpenAITextToImage() assert openai_text_to_image.get_prompt_execution_settings_class() == OpenAITextToImageExecutionSettings @patch.object(AsyncImages, "generate", return_value=AsyncMock(spec=ImagesResponse)) async def test_generate_calls_with_parameters(mock_generate, openai_unit_test_env) -> None: + """Test that generate_image calls the OpenAI API with correct parameters.""" mock_generate.return_value.data = [Image(url="abc")] + mock_generate.return_value.usage = None ai_model_id = "test_model_id" prompt = "painting of flowers in vase" @@ -82,12 +96,14 @@ async def test_generate_calls_with_parameters(mock_generate, openai_unit_test_en prompt=prompt, model=ai_model_id, size=f"{width}x{width}", + n=1, ) - assert len(w) == 2 + assert len(w) == 3 @patch.object(AsyncImages, "generate", new_callable=AsyncMock, side_effect=Exception) async def test_generate_fail(mock_generate, openai_unit_test_env) -> None: + """Test that generate_image raises ServiceResponseException on API failure.""" ai_model_id = "test_model_id" width = 512 @@ -97,6 +113,7 @@ async def test_generate_fail(mock_generate, openai_unit_test_env) -> None: async def test_generate_invalid_image_size(openai_unit_test_env) -> None: + """Test that invalid image size raises ServiceInvalidExecutionSettingsError.""" ai_model_id = "test_model_id" width = 100 @@ -106,6 +123,7 @@ async def test_generate_invalid_image_size(openai_unit_test_env) -> None: async def test_generate_empty_description(openai_unit_test_env) -> None: + """Test that empty description raises ServiceInvalidExecutionSettingsError.""" ai_model_id = "test_model_id" width = 100 @@ -116,10 +134,135 @@ async def test_generate_empty_description(openai_unit_test_env) -> None: @patch.object(AsyncImages, "generate", new_callable=AsyncMock) async def test_generate_no_result(mock_generate, openai_unit_test_env) -> None: - mock_generate.return_value = ImagesResponse(created=0, data=[]) + """Test that no result from API raises ServiceResponseException.""" + mock_generate.return_value = ImagesResponse(created=0, data=[], usage=None) ai_model_id = "test_model_id" width = 512 openai_text_to_image = OpenAITextToImage(ai_model_id=ai_model_id) with pytest.raises(ServiceResponseException): await openai_text_to_image.generate_image(description="painting of flowers in vase", width=width, height=width) + + +@patch.object(OpenAITextToImageBase, "_send_image_edit_request", new_callable=AsyncMock) +async def test_edit_image_with_path_success(mock_edit, openai_unit_test_env): + """Test editing an image using a file path returns the expected URL.""" + mock_edit.return_value = ImagesResponse(created=1, data=[Image(url="edited_url")], usage=None) + service = OpenAITextToImage(ai_model_id=openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]) + result = await service.edit_image( + prompt="edit this image", + image_paths=[sample_img], + ) + assert result == ["edited_url"] + mock_edit.assert_awaited() + + +@patch.object(OpenAITextToImageBase, "_send_image_edit_request", new_callable=AsyncMock) +async def test_edit_image_with_file_success(mock_edit, openai_unit_test_env): + """Test editing an image using a file object returns the expected URL.""" + mock_edit.return_value = ImagesResponse(created=1, data=[Image(url="edited_url")], usage=None) + service = OpenAITextToImage(ai_model_id=openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]) + with open(sample_img, "rb") as f: + result = await service.edit_image( + prompt="edit this image", + image_files=[f], + ) + assert result == ["edited_url"] + mock_edit.assert_awaited() + + +@patch.object(OpenAITextToImageBase, "_send_image_edit_request", new_callable=AsyncMock) +async def test_edit_image_with_mask_path_and_file(mock_edit, openai_unit_test_env): + """Test editing an image with both mask path and mask file returns the expected URL.""" + mock_edit.return_value = ImagesResponse(created=1, data=[Image(url="edited_url")], usage=None) + service = OpenAITextToImage(ai_model_id=openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]) + # mask_path + result = await service.edit_image( + prompt="edit with mask", + image_paths=[sample_img], + mask_path=sample_img, + ) + assert result == ["edited_url"] + # mask_file + with open(sample_img, "rb") as mf: + result2 = await service.edit_image( + prompt="edit with mask", + image_paths=[sample_img], + mask_file=mf, + ) + assert result2 == ["edited_url"] + + +@pytest.mark.asyncio +async def test_edit_image_prompt_required(openai_unit_test_env): + """Test that an empty prompt raises ServiceInvalidRequestError.""" + service = OpenAITextToImage(ai_model_id=openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]) + with pytest.raises(ServiceInvalidRequestError): + await service.edit_image(prompt="", image_paths=[sample_img]) + + +@pytest.mark.asyncio +async def test_edit_image_both_path_and_file_error(openai_unit_test_env): + """Test that providing both image_paths and image_files raises ServiceInvalidRequestError.""" + service = OpenAITextToImage(ai_model_id=openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]) + with ( + open(sample_img, "rb") as f, + pytest.raises(ServiceInvalidRequestError), + ): + await service.edit_image( + prompt="edit", + image_paths=[sample_img], + image_files=[f], + ) + + +@patch.object(OpenAITextToImageBase, "_send_image_edit_request", new_callable=AsyncMock) +async def test_edit_image_no_valid_data_in_response(mock_edit, openai_unit_test_env): + """Test that no valid data in edit response raises ServiceResponseException.""" + mock_edit.return_value = ImagesResponse(created=1, data=[], usage=None) + service = OpenAITextToImage(ai_model_id=openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]) + with pytest.raises(ServiceResponseException): + await service.edit_image( + prompt="edit", + image_paths=[sample_img], + ) + + +@patch.object(OpenAITextToImageBase, "_send_request", new_callable=AsyncMock) +async def test_generate_images_with_n_parameter(mock_generate, openai_unit_test_env): + """Test that generate_images returns correct URLs when n parameter is set.""" + mock_generate.return_value = ImagesResponse(created=3, data=[Image(url=f"url_{i}") for i in range(3)], usage=None) + service = OpenAITextToImage(ai_model_id=openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]) + settings = OpenAITextToImageExecutionSettings(n=3) + result = await service.generate_images("prompt", settings=settings) + assert result == [f"url_{i}" for i in range(3)] + + +@patch.object(OpenAITextToImageBase, "_send_request", new_callable=AsyncMock) +async def test_generate_images_with_output_compression_and_background(mock_generate, openai_unit_test_env): + """Test that output_compression and background parameters are handled correctly.""" + mock_generate.return_value = ImagesResponse(created=1, data=[Image(url="url")], usage=None) + service = OpenAITextToImage(ai_model_id=openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]) + settings = OpenAITextToImageExecutionSettings(output_compression=5, background="transparent") + await service.generate_images("prompt", settings=settings) + called_settings = mock_generate.call_args[0][0] + assert called_settings.output_compression == 5 + assert called_settings.background == "transparent" + + +@patch.object(OpenAITextToImageBase, "store_usage") +def test_store_usage_for_images_response(mock_store_usage, openai_unit_test_env): + """Test that store_usage is called for ImagesResponse.""" + service = OpenAITextToImage(ai_model_id=openai_unit_test_env["OPENAI_TEXT_TO_IMAGE_MODEL_ID"]) + response = ImagesResponse(created=1, data=[Image(url="url")], usage=None) + service.store_usage(response) + mock_store_usage.assert_called() + + +@pytest.mark.asyncio +async def test_edit_image_invalid_n_parameter(): + """Test that invalid n parameter raises pydantic.ValidationError.""" + with pytest.raises(pydantic.ValidationError): + OpenAITextToImageExecutionSettings(n=0) + with pytest.raises(pydantic.ValidationError): + OpenAITextToImageExecutionSettings(n=11) diff --git a/python/uv.lock b/python/uv.lock index ca583e833a0f..46f5f6d4c1d7 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -1555,7 +1555,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.95.1" +version = "1.97.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -1572,9 +1572,9 @@ dependencies = [ { name = "shapely", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/89/4e/af3cb47f7d28dacda9cba9853ce4cc998b0d3ba445907a28755d056bdda7/google_cloud_aiplatform-1.95.1.tar.gz", hash = "sha256:75beb3bf79d58648d40380e25a6863c02d5424558e1c6fdbffc6ed0ebce098fb", size = 9184978, upload-time = "2025-05-30T04:16:28.82Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9b/ea/38224d2972e16c82ee16c13407e647586e25671bd2f75d4455491c678c92/google_cloud_aiplatform-1.97.0.tar.gz", hash = "sha256:01277ac5648abe7d2af688b123d7d050c1a34922e9f4297e51e44d165cb79b45", size = 9229557, upload-time = "2025-06-11T06:40:19.907Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b4/8f/7bdc805192941d4859c720c8e7351eead720fae76c19cd7900424c0891f0/google_cloud_aiplatform-1.95.1-py2.py3-none-any.whl", hash = "sha256:f8a072857aef12391ee6cf128b2775d0f133baafaceed639f053dc551c356e05", size = 7651112, upload-time = "2025-05-30T04:16:25.506Z" }, + { url = "https://files.pythonhosted.org/packages/9a/b8/f9ca10a648bc2596e904c30270c49e72528e2b3b583d886eeeec5080b27d/google_cloud_aiplatform-1.97.0-py2.py3-none-any.whl", hash = "sha256:4db9455308110b1e8c1b587bd3ff34449fa459fda45c4466b9b2d9ae259a7af6", size = 7687924, upload-time = "2025-06-11T06:40:16.947Z" }, ] [[package]] @@ -5684,7 +5684,7 @@ requires-dist = [ { name = "defusedxml", specifier = "~=0.7" }, { name = "faiss-cpu", marker = "extra == 'faiss'", specifier = ">=1.10.0" }, { name = "flask-dapr", marker = "extra == 'dapr'", specifier = ">=1.14.0" }, - { name = "google-cloud-aiplatform", marker = "extra == 'google'", specifier = "==1.95.1" }, + { name = "google-cloud-aiplatform", marker = "extra == 'google'", specifier = "==1.97.0" }, { name = "google-generativeai", marker = "extra == 'google'", specifier = "~=0.8" }, { name = "ipykernel", marker = "extra == 'notebooks'", specifier = "~=6.29" }, { name = "jinja2", specifier = "~=3.1" }, @@ -5718,7 +5718,7 @@ requires-dist = [ { name = "redisvl", marker = "extra == 'redis'", specifier = "~=0.4" }, { name = "scipy", specifier = ">=1.15.1" }, { name = "sentence-transformers", marker = "extra == 'hugging-face'", specifier = ">=2.2,<5.0" }, - { name = "torch", marker = "extra == 'hugging-face'", specifier = "==2.7.0" }, + { name = "torch", marker = "extra == 'hugging-face'", specifier = "==2.7.1" }, { name = "transformers", extras = ["torch"], marker = "extra == 'hugging-face'", specifier = "~=4.28" }, { name = "types-redis", marker = "extra == 'redis'", specifier = "~=4.6.0.20240425" }, { name = "typing-extensions", specifier = ">=4.13" }, @@ -6130,7 +6130,7 @@ wheels = [ [[package]] name = "torch" -version = "2.7.0" +version = "2.7.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -6158,26 +6158,26 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/46/c2/3fb87940fa160d956ee94d644d37b99a24b9c05a4222bf34f94c71880e28/torch-2.7.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c9afea41b11e1a1ab1b258a5c31afbd646d6319042bfe4f231b408034b51128b", size = 99158447, upload-time = "2025-04-23T14:35:10.557Z" }, - { url = "https://files.pythonhosted.org/packages/cc/2c/91d1de65573fce563f5284e69d9c56b57289625cffbbb6d533d5d56c36a5/torch-2.7.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0b9960183b6e5b71239a3e6c883d8852c304e691c0b2955f7045e8a6d05b9183", size = 865164221, upload-time = "2025-04-23T14:33:27.864Z" }, - { url = "https://files.pythonhosted.org/packages/7f/7e/1b1cc4e0e7cc2666cceb3d250eef47a205f0821c330392cf45eb08156ce5/torch-2.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:2ad79d0d8c2a20a37c5df6052ec67c2078a2c4e9a96dd3a8b55daaff6d28ea29", size = 212521189, upload-time = "2025-04-23T14:34:53.898Z" }, - { url = "https://files.pythonhosted.org/packages/dc/0b/b2b83f30b8e84a51bf4f96aa3f5f65fdf7c31c591cc519310942339977e2/torch-2.7.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:34e0168ed6de99121612d72224e59b2a58a83dae64999990eada7260c5dd582d", size = 68559462, upload-time = "2025-04-23T14:35:39.889Z" }, - { url = "https://files.pythonhosted.org/packages/40/da/7378d16cc636697f2a94f791cb496939b60fb8580ddbbef22367db2c2274/torch-2.7.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2b7813e904757b125faf1a9a3154e1d50381d539ced34da1992f52440567c156", size = 99159397, upload-time = "2025-04-23T14:35:35.304Z" }, - { url = "https://files.pythonhosted.org/packages/0e/6b/87fcddd34df9f53880fa1f0c23af7b6b96c935856473faf3914323588c40/torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:fd5cfbb4c3bbadd57ad1b27d56a28008f8d8753733411a140fcfb84d7f933a25", size = 865183681, upload-time = "2025-04-23T14:34:21.802Z" }, - { url = "https://files.pythonhosted.org/packages/13/85/6c1092d4b06c3db1ed23d4106488750917156af0b24ab0a2d9951830b0e9/torch-2.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:58df8d5c2eeb81305760282b5069ea4442791a6bbf0c74d9069b7b3304ff8a37", size = 212520100, upload-time = "2025-04-23T14:35:27.473Z" }, - { url = "https://files.pythonhosted.org/packages/aa/3f/85b56f7e2abcfa558c5fbf7b11eb02d78a4a63e6aeee2bbae3bb552abea5/torch-2.7.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:0a8d43caa342b9986101ec5feb5bbf1d86570b5caa01e9cb426378311258fdde", size = 68569377, upload-time = "2025-04-23T14:35:20.361Z" }, - { url = "https://files.pythonhosted.org/packages/aa/5e/ac759f4c0ab7c01feffa777bd68b43d2ac61560a9770eeac074b450f81d4/torch-2.7.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:36a6368c7ace41ad1c0f69f18056020b6a5ca47bedaca9a2f3b578f5a104c26c", size = 99013250, upload-time = "2025-04-23T14:35:15.589Z" }, - { url = "https://files.pythonhosted.org/packages/9c/58/2d245b6f1ef61cf11dfc4aceeaacbb40fea706ccebac3f863890c720ab73/torch-2.7.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:15aab3e31c16feb12ae0a88dba3434a458874636f360c567caa6a91f6bfba481", size = 865042157, upload-time = "2025-04-23T14:32:56.011Z" }, - { url = "https://files.pythonhosted.org/packages/44/80/b353c024e6b624cd9ce1d66dcb9d24e0294680f95b369f19280e241a0159/torch-2.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f56d4b2510934e072bab3ab8987e00e60e1262fb238176168f5e0c43a1320c6d", size = 212482262, upload-time = "2025-04-23T14:35:03.527Z" }, - { url = "https://files.pythonhosted.org/packages/ee/8d/b2939e5254be932db1a34b2bd099070c509e8887e0c5a90c498a917e4032/torch-2.7.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:30b7688a87239a7de83f269333651d8e582afffce6f591fff08c046f7787296e", size = 68574294, upload-time = "2025-04-23T14:34:47.098Z" }, - { url = "https://files.pythonhosted.org/packages/14/24/720ea9a66c29151b315ea6ba6f404650834af57a26b2a04af23ec246b2d5/torch-2.7.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:868ccdc11798535b5727509480cd1d86d74220cfdc42842c4617338c1109a205", size = 99015553, upload-time = "2025-04-23T14:34:41.075Z" }, - { url = "https://files.pythonhosted.org/packages/4b/27/285a8cf12bd7cd71f9f211a968516b07dcffed3ef0be585c6e823675ab91/torch-2.7.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b52347118116cf3dff2ab5a3c3dd97c719eb924ac658ca2a7335652076df708", size = 865046389, upload-time = "2025-04-23T14:32:01.16Z" }, - { url = "https://files.pythonhosted.org/packages/74/c8/2ab2b6eadc45554af8768ae99668c5a8a8552e2012c7238ded7e9e4395e1/torch-2.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:434cf3b378340efc87c758f250e884f34460624c0523fe5c9b518d205c91dd1b", size = 212490304, upload-time = "2025-04-23T14:33:57.108Z" }, - { url = "https://files.pythonhosted.org/packages/28/fd/74ba6fde80e2b9eef4237fe668ffae302c76f0e4221759949a632ca13afa/torch-2.7.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:edad98dddd82220465b106506bb91ee5ce32bd075cddbcf2b443dfaa2cbd83bf", size = 68856166, upload-time = "2025-04-23T14:34:04.012Z" }, - { url = "https://files.pythonhosted.org/packages/cb/b4/8df3f9fe6bdf59e56a0e538592c308d18638eb5f5dc4b08d02abb173c9f0/torch-2.7.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a885fc25afefb6e6eb18a7d1e8bfa01cc153e92271d980a49243b250d5ab6d9", size = 99091348, upload-time = "2025-04-23T14:33:48.975Z" }, - { url = "https://files.pythonhosted.org/packages/9d/f5/0bd30e9da04c3036614aa1b935a9f7e505a9e4f1f731b15e165faf8a4c74/torch-2.7.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:176300ff5bc11a5f5b0784e40bde9e10a35c4ae9609beed96b4aeb46a27f5fae", size = 865104023, upload-time = "2025-04-23T14:30:40.537Z" }, - { url = "https://files.pythonhosted.org/packages/d1/b7/2235d0c3012c596df1c8d39a3f4afc1ee1b6e318d469eda4c8bb68566448/torch-2.7.0-cp313-cp313t-win_amd64.whl", hash = "sha256:d0ca446a93f474985d81dc866fcc8dccefb9460a29a456f79d99c29a78a66993", size = 212750916, upload-time = "2025-04-23T14:32:22.91Z" }, - { url = "https://files.pythonhosted.org/packages/90/48/7e6477cf40d48cc0a61fa0d41ee9582b9a316b12772fcac17bc1a40178e7/torch-2.7.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:27f5007bdf45f7bb7af7f11d1828d5c2487e030690afb3d89a651fd7036a390e", size = 68575074, upload-time = "2025-04-23T14:32:38.136Z" }, + { url = "https://files.pythonhosted.org/packages/6a/27/2e06cb52adf89fe6e020963529d17ed51532fc73c1e6d1b18420ef03338c/torch-2.7.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a103b5d782af5bd119b81dbcc7ffc6fa09904c423ff8db397a1e6ea8fd71508f", size = 99089441, upload-time = "2025-06-04T17:38:48.268Z" }, + { url = "https://files.pythonhosted.org/packages/0a/7c/0a5b3aee977596459ec45be2220370fde8e017f651fecc40522fd478cb1e/torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:fe955951bdf32d182ee8ead6c3186ad54781492bf03d547d31771a01b3d6fb7d", size = 821154516, upload-time = "2025-06-04T17:36:28.556Z" }, + { url = "https://files.pythonhosted.org/packages/f9/91/3d709cfc5e15995fb3fe7a6b564ce42280d3a55676dad672205e94f34ac9/torch-2.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:885453d6fba67d9991132143bf7fa06b79b24352f4506fd4d10b309f53454162", size = 216093147, upload-time = "2025-06-04T17:39:38.132Z" }, + { url = "https://files.pythonhosted.org/packages/92/f6/5da3918414e07da9866ecb9330fe6ffdebe15cb9a4c5ada7d4b6e0a6654d/torch-2.7.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:d72acfdb86cee2a32c0ce0101606f3758f0d8bb5f8f31e7920dc2809e963aa7c", size = 68630914, upload-time = "2025-06-04T17:39:31.162Z" }, + { url = "https://files.pythonhosted.org/packages/11/56/2eae3494e3d375533034a8e8cf0ba163363e996d85f0629441fa9d9843fe/torch-2.7.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:236f501f2e383f1cb861337bdf057712182f910f10aeaf509065d54d339e49b2", size = 99093039, upload-time = "2025-06-04T17:39:06.963Z" }, + { url = "https://files.pythonhosted.org/packages/e5/94/34b80bd172d0072c9979708ccd279c2da2f55c3ef318eceec276ab9544a4/torch-2.7.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:06eea61f859436622e78dd0cdd51dbc8f8c6d76917a9cf0555a333f9eac31ec1", size = 821174704, upload-time = "2025-06-04T17:37:03.799Z" }, + { url = "https://files.pythonhosted.org/packages/50/9e/acf04ff375b0b49a45511c55d188bcea5c942da2aaf293096676110086d1/torch-2.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:8273145a2e0a3c6f9fd2ac36762d6ee89c26d430e612b95a99885df083b04e52", size = 216095937, upload-time = "2025-06-04T17:39:24.83Z" }, + { url = "https://files.pythonhosted.org/packages/5b/2b/d36d57c66ff031f93b4fa432e86802f84991477e522adcdffd314454326b/torch-2.7.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:aea4fc1bf433d12843eb2c6b2204861f43d8364597697074c8d38ae2507f8730", size = 68640034, upload-time = "2025-06-04T17:39:17.989Z" }, + { url = "https://files.pythonhosted.org/packages/87/93/fb505a5022a2e908d81fe9a5e0aa84c86c0d5f408173be71c6018836f34e/torch-2.7.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:27ea1e518df4c9de73af7e8a720770f3628e7f667280bce2be7a16292697e3fa", size = 98948276, upload-time = "2025-06-04T17:39:12.852Z" }, + { url = "https://files.pythonhosted.org/packages/56/7e/67c3fe2b8c33f40af06326a3d6ae7776b3e3a01daa8f71d125d78594d874/torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:c33360cfc2edd976c2633b3b66c769bdcbbf0e0b6550606d188431c81e7dd1fc", size = 821025792, upload-time = "2025-06-04T17:34:58.747Z" }, + { url = "https://files.pythonhosted.org/packages/a1/37/a37495502bc7a23bf34f89584fa5a78e25bae7b8da513bc1b8f97afb7009/torch-2.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:d8bf6e1856ddd1807e79dc57e54d3335f2b62e6f316ed13ed3ecfe1fc1df3d8b", size = 216050349, upload-time = "2025-06-04T17:38:59.709Z" }, + { url = "https://files.pythonhosted.org/packages/3a/60/04b77281c730bb13460628e518c52721257814ac6c298acd25757f6a175c/torch-2.7.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:787687087412c4bd68d315e39bc1223f08aae1d16a9e9771d95eabbb04ae98fb", size = 68645146, upload-time = "2025-06-04T17:38:52.97Z" }, + { url = "https://files.pythonhosted.org/packages/66/81/e48c9edb655ee8eb8c2a6026abdb6f8d2146abd1f150979ede807bb75dcb/torch-2.7.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:03563603d931e70722dce0e11999d53aa80a375a3d78e6b39b9f6805ea0a8d28", size = 98946649, upload-time = "2025-06-04T17:38:43.031Z" }, + { url = "https://files.pythonhosted.org/packages/3a/24/efe2f520d75274fc06b695c616415a1e8a1021d87a13c68ff9dce733d088/torch-2.7.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:d632f5417b6980f61404a125b999ca6ebd0b8b4bbdbb5fbbba44374ab619a412", size = 821033192, upload-time = "2025-06-04T17:38:09.146Z" }, + { url = "https://files.pythonhosted.org/packages/dd/d9/9c24d230333ff4e9b6807274f6f8d52a864210b52ec794c5def7925f4495/torch-2.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:23660443e13995ee93e3d844786701ea4ca69f337027b05182f5ba053ce43b38", size = 216055668, upload-time = "2025-06-04T17:38:36.253Z" }, + { url = "https://files.pythonhosted.org/packages/95/bf/e086ee36ddcef9299f6e708d3b6c8487c1651787bb9ee2939eb2a7f74911/torch-2.7.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:0da4f4dba9f65d0d203794e619fe7ca3247a55ffdcbd17ae8fb83c8b2dc9b585", size = 68925988, upload-time = "2025-06-04T17:38:29.273Z" }, + { url = "https://files.pythonhosted.org/packages/69/6a/67090dcfe1cf9048448b31555af6efb149f7afa0a310a366adbdada32105/torch-2.7.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:e08d7e6f21a617fe38eeb46dd2213ded43f27c072e9165dc27300c9ef9570934", size = 99028857, upload-time = "2025-06-04T17:37:50.956Z" }, + { url = "https://files.pythonhosted.org/packages/90/1c/48b988870823d1cc381f15ec4e70ed3d65e043f43f919329b0045ae83529/torch-2.7.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:30207f672328a42df4f2174b8f426f354b2baa0b7cca3a0adb3d6ab5daf00dc8", size = 821098066, upload-time = "2025-06-04T17:37:33.939Z" }, + { url = "https://files.pythonhosted.org/packages/7b/eb/10050d61c9d5140c5dc04a89ed3257ef1a6b93e49dd91b95363d757071e0/torch-2.7.1-cp313-cp313t-win_amd64.whl", hash = "sha256:79042feca1c634aaf6603fe6feea8c6b30dfa140a6bbc0b973e2260c7e79a22e", size = 216336310, upload-time = "2025-06-04T17:36:09.862Z" }, + { url = "https://files.pythonhosted.org/packages/b1/29/beb45cdf5c4fc3ebe282bf5eafc8dfd925ead7299b3c97491900fe5ed844/torch-2.7.1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:988b0cbc4333618a1056d2ebad9eb10089637b659eb645434d0809d8d937b946", size = 68645708, upload-time = "2025-06-04T17:34:39.852Z" }, ] [[package]] @@ -6249,17 +6249,17 @@ torch = [ [[package]] name = "triton" -version = "3.3.0" +version = "3.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "setuptools", marker = "sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/76/04/d54d3a6d077c646624dc9461b0059e23fd5d30e0dbe67471e3654aec81f9/triton-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fad99beafc860501d7fcc1fb7045d9496cbe2c882b1674640304949165a916e7", size = 156441993, upload-time = "2025-04-09T20:27:25.107Z" }, - { url = "https://files.pythonhosted.org/packages/3c/c5/4874a81131cc9e934d88377fbc9d24319ae1fb540f3333b4e9c696ebc607/triton-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3161a2bf073d6b22c4e2f33f951f3e5e3001462b2570e6df9cd57565bdec2984", size = 156528461, upload-time = "2025-04-09T20:27:32.599Z" }, - { url = "https://files.pythonhosted.org/packages/11/53/ce18470914ab6cfbec9384ee565d23c4d1c55f0548160b1c7b33000b11fd/triton-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b68c778f6c4218403a6bd01be7484f6dc9e20fe2083d22dd8aef33e3b87a10a3", size = 156504509, upload-time = "2025-04-09T20:27:40.413Z" }, - { url = "https://files.pythonhosted.org/packages/7d/74/4bf2702b65e93accaa20397b74da46fb7a0356452c1bb94dbabaf0582930/triton-3.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47bc87ad66fa4ef17968299acacecaab71ce40a238890acc6ad197c3abe2b8f1", size = 156516468, upload-time = "2025-04-09T20:27:48.196Z" }, - { url = "https://files.pythonhosted.org/packages/0a/93/f28a696fa750b9b608baa236f8225dd3290e5aff27433b06143adc025961/triton-3.3.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce4700fc14032af1e049005ae94ba908e71cd6c2df682239aed08e49bc71b742", size = 156580729, upload-time = "2025-04-09T20:27:55.424Z" }, + { url = "https://files.pythonhosted.org/packages/8d/a9/549e51e9b1b2c9b854fd761a1d23df0ba2fbc60bd0c13b489ffa518cfcb7/triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b74db445b1c562844d3cfad6e9679c72e93fdfb1a90a24052b03bb5c49d1242e", size = 155600257, upload-time = "2025-05-29T23:39:36.085Z" }, + { url = "https://files.pythonhosted.org/packages/21/2f/3e56ea7b58f80ff68899b1dbe810ff257c9d177d288c6b0f55bf2fe4eb50/triton-3.3.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b31e3aa26f8cb3cc5bf4e187bf737cbacf17311e1112b781d4a059353dfd731b", size = 155689937, upload-time = "2025-05-29T23:39:44.182Z" }, + { url = "https://files.pythonhosted.org/packages/24/5f/950fb373bf9c01ad4eb5a8cd5eaf32cdf9e238c02f9293557a2129b9c4ac/triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9999e83aba21e1a78c1f36f21bce621b77bcaa530277a50484a7cb4a822f6e43", size = 155669138, upload-time = "2025-05-29T23:39:51.771Z" }, + { url = "https://files.pythonhosted.org/packages/74/1f/dfb531f90a2d367d914adfee771babbd3f1a5b26c3f5fbc458dee21daa78/triton-3.3.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b89d846b5a4198317fec27a5d3a609ea96b6d557ff44b56c23176546023c4240", size = 155673035, upload-time = "2025-05-29T23:40:02.468Z" }, + { url = "https://files.pythonhosted.org/packages/28/71/bd20ffcb7a64c753dc2463489a61bf69d531f308e390ad06390268c4ea04/triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42", size = 155735832, upload-time = "2025-05-29T23:40:10.522Z" }, ] [[package]] pFad - Phonifier reborn

    Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

    Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


    Alternative Proxies:

    Alternative Proxy

    pFad Proxy

    pFad v3 Proxy

    pFad v4 Proxy