From 1a939e922e08cd02c5c76030082ea5f5f4f68073 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Mon, 29 Jul 2024 09:46:02 -0400 Subject: [PATCH] [ML] Create and inject APM Inference Metrics (#111293) We are migrating from an in-memory cumulative counter to a Time Series Data Stream delta counter. The goal is to avoid metrics suddenly dropping to zero when a node restarts, hopefully increasing the accuracy of the metric. Co-authored-by: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> --- .../inference/ServiceSettings.java | 2 + .../xpack/inference/InferencePlugin.java | 11 +- .../action/TransportInferenceAction.java | 7 +- .../embeddings/CohereEmbeddingsModel.java | 4 +- .../OpenAiEmbeddingsServiceSettings.java | 2 +- .../telemetry/ApmInferenceStats.java | 49 ++++++++ .../telemetry/InferenceAPMStats.java | 47 ------- .../inference/telemetry/InferenceStats.java | 52 ++------ .../xpack/inference/telemetry/Stats.java | 30 ----- .../xpack/inference/telemetry/StatsMap.java | 57 --------- .../telemetry/ApmInferenceStatsTests.java | 69 ++++++++++ .../inference/telemetry/StatsMapTests.java | 119 ------------------ 12 files changed, 142 insertions(+), 307 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStats.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStats.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/Stats.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/StatsMap.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStatsTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/StatsMapTests.java diff --git a/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java 
b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java index 34a58f83963ce..58e87105f70a3 100644 --- a/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java @@ -9,6 +9,7 @@ package org.elasticsearch.inference; import org.elasticsearch.common.io.stream.VersionedNamedWriteable; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.xcontent.ToXContentObject; @@ -48,5 +49,6 @@ default DenseVectorFieldMapper.ElementType elementType() { * be chosen when initializing a deployment within their service. In this situation, return null. * @return the model used to perform inference or null if the model is not defined */ + @Nullable String modelId(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index fce2c54c535c9..ec9398358d180 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -26,6 +26,7 @@ import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.node.PluginComponentBinding; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.MapperPlugin; @@ -84,8 +85,8 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; -import 
org.elasticsearch.xpack.inference.telemetry.InferenceAPMStats; -import org.elasticsearch.xpack.inference.telemetry.StatsMap; +import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.ArrayList; import java.util.Collection; @@ -196,10 +197,10 @@ public Collection createComponents(PluginServices services) { var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); shardBulkInferenceActionFilter.set(actionFilter); - var statsFactory = new InferenceAPMStats.Factory(services.telemetryProvider().getMeterRegistry()); - var statsMap = new StatsMap<>(InferenceAPMStats::key, statsFactory::newInferenceRequestAPMCounter); + var meterRegistry = services.telemetryProvider().getMeterRegistry(); + var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry)); - return List.of(modelRegistry, registry, httpClientManager, statsMap); + return List.of(modelRegistry, registry, httpClientManager, stats); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 575697b5d0d39..b7fff3b704695 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -21,22 +21,26 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; public class TransportInferenceAction extends HandledTransportAction { private final ModelRegistry modelRegistry; private final 
InferenceServiceRegistry serviceRegistry; + private final InferenceStats inferenceStats; @Inject public TransportInferenceAction( TransportService transportService, ActionFilters actionFilters, ModelRegistry modelRegistry, - InferenceServiceRegistry serviceRegistry + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats ) { super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.modelRegistry = modelRegistry; this.serviceRegistry = serviceRegistry; + this.inferenceStats = inferenceStats; } @Override @@ -76,6 +80,7 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe unparsedModel.settings(), unparsedModel.secrets() ); + inferenceStats.incrementRequestCount(model); inferOnService(model, request, service.get(), delegate); }); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index 538d88a59ca76..fea5226bf9c6f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -28,7 +28,7 @@ public static CohereEmbeddingsModel of(CohereEmbeddingsModel model, Map serviceSettings, @@ -37,7 +37,7 @@ public CohereEmbeddingsModel( ConfigurationParseContext context ) { this( - modelId, + inferenceId, taskType, service, CohereEmbeddingsServiceSettings.fromMap(serviceSettings, context), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java index d474e935fbda7..6ef1f6f0feefe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java @@ -150,7 +150,7 @@ public OpenAiEmbeddingsServiceSettings( @Nullable RateLimitSettings rateLimitSettings ) { this.uri = uri; - this.modelId = modelId; + this.modelId = Objects.requireNonNull(modelId); this.organizationId = organizationId; this.similarity = similarity; this.dimensions = dimensions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStats.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStats.java new file mode 100644 index 0000000000000..ae14a0792dead --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStats.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.telemetry; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.telemetry.metric.LongCounter; +import org.elasticsearch.telemetry.metric.MeterRegistry; + +import java.util.HashMap; +import java.util.Objects; + +public class ApmInferenceStats implements InferenceStats { + private final LongCounter inferenceAPMRequestCounter; + + public ApmInferenceStats(LongCounter inferenceAPMRequestCounter) { + this.inferenceAPMRequestCounter = Objects.requireNonNull(inferenceAPMRequestCounter); + } + + @Override + public void incrementRequestCount(Model model) { + var service = model.getConfigurations().getService(); + var taskType = model.getTaskType(); + var modelId = model.getServiceSettings().modelId(); + + var attributes = new HashMap(5); + attributes.put("service", service); + attributes.put("task_type", taskType.toString()); + if (modelId != null) { + attributes.put("model_id", modelId); + } + + inferenceAPMRequestCounter.incrementBy(1, attributes); + } + + public static ApmInferenceStats create(MeterRegistry meterRegistry) { + return new ApmInferenceStats( + meterRegistry.registerLongCounter( + "es.inference.requests.count.total", + "Inference API request counts for a particular service, task type, model ID", + "operations" + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStats.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStats.java deleted file mode 100644 index 76977fef76045..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceAPMStats.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.telemetry; - -import org.elasticsearch.inference.Model; -import org.elasticsearch.telemetry.metric.LongCounter; -import org.elasticsearch.telemetry.metric.MeterRegistry; - -import java.util.Map; -import java.util.Objects; - -public class InferenceAPMStats extends InferenceStats { - - private final LongCounter inferenceAPMRequestCounter; - - public InferenceAPMStats(Model model, MeterRegistry meterRegistry) { - super(model); - this.inferenceAPMRequestCounter = meterRegistry.registerLongCounter( - "es.inference.requests.count", - "Inference API request counts for a particular service, task type, model ID", - "operations" - ); - } - - @Override - public void increment() { - super.increment(); - inferenceAPMRequestCounter.incrementBy(1, Map.of("service", service, "task_type", taskType.toString(), "model_id", modelId)); - } - - public static final class Factory { - private final MeterRegistry meterRegistry; - - public Factory(MeterRegistry meterRegistry) { - this.meterRegistry = Objects.requireNonNull(meterRegistry); - } - - public InferenceAPMStats newInferenceRequestAPMCounter(Model model) { - return new InferenceAPMStats(model, meterRegistry); - } - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java index d639f9da71f56..d080e818e45fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java @@ -8,52 +8,14 @@ package org.elasticsearch.xpack.inference.telemetry; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import 
org.elasticsearch.xpack.core.inference.InferenceRequestStats; -import java.util.Objects; -import java.util.concurrent.atomic.LongAdder; +public interface InferenceStats { -public class InferenceStats implements Stats { - protected final String service; - protected final TaskType taskType; - protected final String modelId; - protected final LongAdder counter = new LongAdder(); + /** + * Increment the counter for a particular value in a thread safe manner. + * @param model the model to increment request count for + */ + void incrementRequestCount(Model model); - public static String key(Model model) { - StringBuilder builder = new StringBuilder(); - builder.append(model.getConfigurations().getService()); - builder.append(":"); - builder.append(model.getTaskType()); - - if (model.getServiceSettings().modelId() != null) { - builder.append(":"); - builder.append(model.getServiceSettings().modelId()); - } - - return builder.toString(); - } - - public InferenceStats(Model model) { - Objects.requireNonNull(model); - - service = model.getConfigurations().getService(); - taskType = model.getTaskType(); - modelId = model.getServiceSettings().modelId(); - } - - @Override - public void increment() { - counter.increment(); - } - - @Override - public long getCount() { - return counter.sum(); - } - - @Override - public InferenceRequestStats toSerializableForm() { - return new InferenceRequestStats(service, taskType, modelId, getCount()); - } + InferenceStats NOOP = model -> {}; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/Stats.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/Stats.java deleted file mode 100644 index bb1e9c98fc2cb..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/Stats.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.telemetry; - -import org.elasticsearch.xpack.core.inference.SerializableStats; - -public interface Stats { - - /** - * Increase the counter by one. - */ - void increment(); - - /** - * Return the current value of the counter. - * @return the current value of the counter - */ - long getCount(); - - /** - * Convert the object into a serializable form that can be written across nodes and returned in xcontent format. - * @return the serializable format of the object - */ - SerializableStats toSerializableForm(); -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/StatsMap.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/StatsMap.java deleted file mode 100644 index 1cfecfb4507d6..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/StatsMap.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.telemetry; - -import org.elasticsearch.xpack.core.inference.SerializableStats; - -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.function.Function; -import java.util.stream.Collectors; - -/** - * A map to provide tracking incrementing statistics. 
- * - * @param The input to derive the keys and values for the map - * @param The type of the values stored in the map - */ -public class StatsMap { - - private final ConcurrentMap stats = new ConcurrentHashMap<>(); - private final Function keyCreator; - private final Function valueCreator; - - /** - * @param keyCreator a function for creating a key in the map based on the input provided - * @param valueCreator a function for creating a value in the map based on the input provided - */ - public StatsMap(Function keyCreator, Function valueCreator) { - this.keyCreator = Objects.requireNonNull(keyCreator); - this.valueCreator = Objects.requireNonNull(valueCreator); - } - - /** - * Increment the counter for a particular value in a thread safe manner. - * @param input the input to derive the appropriate key in the map - */ - public void increment(Input input) { - var value = stats.computeIfAbsent(keyCreator.apply(input), key -> valueCreator.apply(input)); - value.increment(); - } - - /** - * Build a map that can be serialized. This takes a snapshot of the current state. Any concurrent calls to increment may or may not - * be represented in the resulting serializable map. - * @return a map that is more easily serializable - */ - public Map toSerializableMap() { - return stats.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toSerializableForm())); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStatsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStatsTests.java new file mode 100644 index 0000000000000..1a5aba5f89ad2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStatsTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.telemetry; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.telemetry.metric.LongCounter; +import org.elasticsearch.telemetry.metric.MeterRegistry; +import org.elasticsearch.test.ESTestCase; + +import java.util.Map; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ApmInferenceStatsTests extends ESTestCase { + + public void testRecordWithModel() { + var longCounter = mock(LongCounter.class); + + var stats = new ApmInferenceStats(longCounter); + + stats.incrementRequestCount(model("service", TaskType.ANY, "modelId")); + + verify(longCounter).incrementBy( + eq(1L), + eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId")) + ); + } + + public void testRecordWithoutModel() { + var longCounter = mock(LongCounter.class); + + var stats = new ApmInferenceStats(longCounter); + + stats.incrementRequestCount(model("service", TaskType.ANY, null)); + + verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString()))); + } + + public void testCreation() { + assertNotNull(ApmInferenceStats.create(MeterRegistry.NOOP)); + } + + private Model model(String service, TaskType taskType, String modelId) { + var configuration = mock(ModelConfigurations.class); + when(configuration.getService()).thenReturn(service); + var settings = mock(ServiceSettings.class); + if (modelId != null) { + when(settings.modelId()).thenReturn(modelId); + } + + var model = mock(Model.class); + when(model.getTaskType()).thenReturn(taskType); + 
when(model.getConfigurations()).thenReturn(configuration); + when(model.getServiceSettings()).thenReturn(settings); + + return model; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/StatsMapTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/StatsMapTests.java deleted file mode 100644 index fcd8d3d7cefbc..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/StatsMapTests.java +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.telemetry; - -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests; -import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettingsTests; - -import java.util.Map; - -import static org.hamcrest.Matchers.is; - -public class StatsMapTests extends ESTestCase { - public void testAddingEntry_InitializesTheCountToOne() { - var stats = new StatsMap<>(InferenceStats::key, InferenceStats::new); - - stats.increment( - new OpenAiEmbeddingsModel( - "inference_id", - TaskType.TEXT_EMBEDDING, - 
"openai", - OpenAiEmbeddingsServiceSettingsTests.getServiceSettingsMap("modelId", null, null), - OpenAiEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), - null, - ConfigurationParseContext.REQUEST - ) - ); - - var converted = stats.toSerializableMap(); - - assertThat( - converted, - is( - Map.of( - "openai:text_embedding:modelId", - new org.elasticsearch.xpack.core.inference.InferenceRequestStats("openai", TaskType.TEXT_EMBEDDING, "modelId", 1) - ) - ) - ); - } - - public void testIncrementingWithSeparateModels_IncrementsTheCounterToTwo() { - var stats = new StatsMap<>(InferenceStats::key, InferenceStats::new); - - var model1 = new OpenAiEmbeddingsModel( - "inference_id", - TaskType.TEXT_EMBEDDING, - "openai", - OpenAiEmbeddingsServiceSettingsTests.getServiceSettingsMap("modelId", null, null), - OpenAiEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), - null, - ConfigurationParseContext.REQUEST - ); - - var model2 = new OpenAiEmbeddingsModel( - "inference_id", - TaskType.TEXT_EMBEDDING, - "openai", - OpenAiEmbeddingsServiceSettingsTests.getServiceSettingsMap("modelId", null, null), - OpenAiEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), - null, - ConfigurationParseContext.REQUEST - ); - - stats.increment(model1); - stats.increment(model2); - - var converted = stats.toSerializableMap(); - - assertThat( - converted, - is( - Map.of( - "openai:text_embedding:modelId", - new org.elasticsearch.xpack.core.inference.InferenceRequestStats("openai", TaskType.TEXT_EMBEDDING, "modelId", 2) - ) - ) - ); - } - - public void testNullModelId_ResultsInKeyWithout() { - var stats = new StatsMap<>(InferenceStats::key, InferenceStats::new); - - stats.increment( - new CohereEmbeddingsModel( - "inference_id", - TaskType.TEXT_EMBEDDING, - "cohere", - CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null), - CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(null, null), - null, - ConfigurationParseContext.REQUEST - ) - ); - - var converted = 
stats.toSerializableMap(); - - assertThat( - converted, - is( - Map.of( - "cohere:text_embedding", - new org.elasticsearch.xpack.core.inference.InferenceRequestStats("cohere", TaskType.TEXT_EMBEDDING, null, 1) - ) - ) - ); - } -}