Skip to content

Commit

Permalink
Conversation builder consistency changes (#1423)
Browse files Browse the repository at this point in the history
* Corrected several unit tests

Signed-off-by: Whit Waldo <[email protected]>

* Updated extension name for consistency

Signed-off-by: Whit Waldo <[email protected]>

* Updated registration name for consistency

Signed-off-by: Whit Waldo <[email protected]>

---------

Signed-off-by: Whit Waldo <[email protected]>
  • Loading branch information
WhitWaldo authored Dec 11, 2024
1 parent 8bc0318 commit ccf2bfd
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 14 deletions.
2 changes: 1 addition & 1 deletion examples/AI/ConversationalAI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

var builder = WebApplication.CreateBuilder(args);

builder.Services.AddDaprAiConversation();
builder.Services.AddDaprConversationClient();

var app = builder.Build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public static class DaprAiConversationBuilderExtensions
/// Registers the necessary functionality for the Dapr AI conversation functionality.
/// </summary>
/// <returns></returns>
public static IDaprAiConversationBuilder AddDaprAiConversation(this IServiceCollection services, Action<IServiceProvider, DaprConversationClientBuilder>? configure = null, ServiceLifetime lifetime = ServiceLifetime.Singleton)
public static IDaprAiConversationBuilder AddDaprConversationClient(this IServiceCollection services, Action<IServiceProvider, DaprConversationClientBuilder>? configure = null, ServiceLifetime lifetime = ServiceLifetime.Singleton)
{
ArgumentNullException.ThrowIfNull(services, nameof(services));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using Dapr.AI.Conversation;
using Dapr.AI.Conversation.Extensions;
using Microsoft.Extensions.Configuration;
Expand All @@ -34,7 +36,7 @@ public void AddDaprConversationClient_FromIConfiguration()
var services = new ServiceCollection();
services.AddSingleton<IConfiguration>(configuration);

services.AddDaprAiConversation();
services.AddDaprConversationClient();

var app = services.BuildServiceProvider();

Expand All @@ -45,18 +47,66 @@ public void AddDaprConversationClient_FromIConfiguration()
}

[Fact]
public void AddDaprAiConversation_WithoutConfigure_ShouldAddServices()
public void AddDaprConversationClient_RegistersDaprClientOnlyOnce()
{
var services = new ServiceCollection();
var builder = services.AddDaprAiConversation();

var clientBuilder = new Action<IServiceProvider, DaprConversationClientBuilder>((sp, builder) =>
{
builder.UseDaprApiToken("abc");
});

services.AddDaprConversationClient(); //Sets a default API token value of an empty string
services.AddDaprConversationClient(clientBuilder); //Sets the API token value

var serviceProvider = services.BuildServiceProvider();
var daprConversationClient = serviceProvider.GetService<DaprConversationClient>();

Assert.NotNull(daprConversationClient!.HttpClient);
Assert.False(daprConversationClient.HttpClient.DefaultRequestHeaders.TryGetValues("dapr-api-token", out var _));
}

[Fact]
public void AddDaprConversationClient_RegistersUsingDependencyFromIServiceProvider()
{
var services = new ServiceCollection();
services.AddSingleton<TestSecretRetriever>();
services.AddDaprConversationClient((provider, builder) =>
{
var configProvider = provider.GetRequiredService<TestSecretRetriever>();
var apiToken = configProvider.GetApiTokenValue();
builder.UseDaprApiToken(apiToken);
});

var serviceProvider = services.BuildServiceProvider();
var client = serviceProvider.GetRequiredService<DaprConversationClient>();

//Validate it's set on the GrpcClient - note that it doesn't get set on the HttpClient
Assert.NotNull(client);
Assert.NotNull(client.DaprApiToken);
Assert.Equal("abcdef", client.DaprApiToken);
Assert.NotNull(client.HttpClient);

if (!client.HttpClient.DefaultRequestHeaders.TryGetValues("dapr-api-token", out var daprApiToken))
{
Assert.Fail();
}
Assert.Equal("abcdef", daprApiToken.FirstOrDefault());
}

[Fact]
public void AddDaprConversationClient_WithoutConfigure_ShouldAddServices()
{
var services = new ServiceCollection();
var builder = services.AddDaprConversationClient();
Assert.NotNull(builder);
}

[Fact]
public void AddDaprAiConversation_RegistersIHttpClientFactory()
public void AddDaprConversationClient_RegistersIHttpClientFactory()
{
var services = new ServiceCollection();
services.AddDaprAiConversation();
services.AddDaprConversationClient();
var serviceProvider = services.BuildServiceProvider();

var httpClientFactory = serviceProvider.GetService<IHttpClientFactory>();
Expand All @@ -67,9 +117,66 @@ public void AddDaprAiConversation_RegistersIHttpClientFactory()
}

[Fact]
public void AddDaprAiConversation_NullServices_ShouldThrowException()
public void AddDaprConversationClient_NullServices_ShouldThrowException()
{
IServiceCollection services = null;
Assert.Throws<ArgumentNullException>(() => services.AddDaprAiConversation());
Assert.Throws<ArgumentNullException>(() => services.AddDaprConversationClient());
}

[Fact]
public void AddDaprConversationClient_ShouldRegisterSingleton_WhenLifetimeIsSingleton()
{
var services = new ServiceCollection();

services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Singleton);
var serviceProvider = services.BuildServiceProvider();

var daprConversationClient1 = serviceProvider.GetService<DaprConversationClient>();
var daprConversationClient2 = serviceProvider.GetService<DaprConversationClient>();

Assert.NotNull(daprConversationClient1);
Assert.NotNull(daprConversationClient2);

Assert.Same(daprConversationClient1, daprConversationClient2);
}

[Fact]
public async Task AddDaprConversationClient_ShouldRegisterScoped_WhenLifetimeIsScoped()
{
var services = new ServiceCollection();

services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Scoped);
var serviceProvider = services.BuildServiceProvider();

await using var scope1 = serviceProvider.CreateAsyncScope();
var daprConversationClient1 = scope1.ServiceProvider.GetService<DaprConversationClient>();

await using var scope2 = serviceProvider.CreateAsyncScope();
var daprConversationClient2 = scope2.ServiceProvider.GetService<DaprConversationClient>();

Assert.NotNull(daprConversationClient1);
Assert.NotNull(daprConversationClient2);
Assert.NotSame(daprConversationClient1, daprConversationClient2);
}

[Fact]
public void AddDaprConversationClient_ShouldRegisterTransient_WhenLifetimeIsTransient()
{
var services = new ServiceCollection();

services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Transient);
var serviceProvider = services.BuildServiceProvider();

var daprConversationClient1 = serviceProvider.GetService<DaprConversationClient>();
var daprConversationClient2 = serviceProvider.GetService<DaprConversationClient>();

Assert.NotNull(daprConversationClient1);
Assert.NotNull(daprConversationClient2);
Assert.NotSame(daprConversationClient1, daprConversationClient2);
}

private class TestSecretRetriever
{
public string GetApiTokenValue() => "abcdef";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public void AddDaprJobsClient_RegistersUsingDependencyFromIServiceProvider()
services.AddDaprJobsClient((provider, builder) =>
{
var configProvider = provider.GetRequiredService<TestSecretRetriever>();
var apiToken = TestSecretRetriever.GetApiTokenValue();
var apiToken = configProvider.GetApiTokenValue();
builder.UseDaprApiToken(apiToken);
});

Expand All @@ -114,7 +114,7 @@ public void RegisterJobsClient_ShouldRegisterSingleton_WhenLifetimeIsSingleton()
{
var services = new ServiceCollection();

services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Singleton);
services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Singleton);
var serviceProvider = services.BuildServiceProvider();

var daprJobsClient1 = serviceProvider.GetService<DaprJobsClient>();
Expand All @@ -131,7 +131,7 @@ public async Task RegisterJobsClient_ShouldRegisterScoped_WhenLifetimeIsScoped()
{
var services = new ServiceCollection();

services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Scoped);
services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Scoped);
var serviceProvider = services.BuildServiceProvider();

await using var scope1 = serviceProvider.CreateAsyncScope();
Expand All @@ -150,7 +150,7 @@ public void RegisterJobsClient_ShouldRegisterTransient_WhenLifetimeIsTransient()
{
var services = new ServiceCollection();

services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Transient);
services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Transient);
var serviceProvider = services.BuildServiceProvider();

var daprJobsClient1 = serviceProvider.GetService<DaprJobsClient>();
Expand All @@ -163,6 +163,6 @@ public void RegisterJobsClient_ShouldRegisterTransient_WhenLifetimeIsTransient()

private class TestSecretRetriever
{
public static string GetApiTokenValue() => "abcdef";
public string GetApiTokenValue() => "abcdef";
}
}

0 comments on commit ccf2bfd

Please sign in to comment.