[ML] Create and inject APM Inference Metrics (elastic#111293)
We are migrating from an in-memory cumulative counter to a Time Series
Data Stream delta counter. The goal is to avoid metrics suddenly
dropping to zero when a node restarts, hopefully increasing the accuracy
of the metric.

Co-authored-by: Jonathan Buttner <[email protected]>
prwhelan and jonathan-buttner authored Jul 29, 2024
1 parent 38f301a commit 1a939e9
Showing 12 changed files with 142 additions and 307 deletions.
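
The diff below replaces the per-key in-memory counters (a StatsMap of InferenceAPMStats) with a single APM LongCounter whose service, task type, and model ID travel as attributes on every increment. A minimal sketch of that pattern, built only from the registerLongCounter and incrementBy calls visible in the diff; MeterRegistry.NOOP and the attribute values are assumptions used purely for illustration, and the real registry comes from the node's TelemetryProvider:

import org.elasticsearch.telemetry.metric.LongCounter;
import org.elasticsearch.telemetry.metric.MeterRegistry;

import java.util.Map;

public class DeltaCounterSketch {
    public static void main(String[] args) {
        // Assumption: MeterRegistry.NOOP stands in for services.telemetryProvider().getMeterRegistry().
        MeterRegistry meterRegistry = MeterRegistry.NOOP;
        LongCounter requests = meterRegistry.registerLongCounter(
            "es.inference.requests.count.total",
            "Inference API request counts for a particular service, task type, model ID",
            "operations"
        );
        // A single counter serves every request; the dimensions ride along as attributes,
        // so the backend can store deltas rather than a per-node cumulative total that
        // resets to zero on restart. The attribute values here are illustrative only.
        requests.incrementBy(1, Map.of("service", "openai", "task_type", "text_embedding", "model_id", "example-model"));
    }
}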

@@ -9,6 +9,7 @@
 package org.elasticsearch.inference;
 
 import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.xcontent.ToXContentObject;
 
@@ -48,5 +49,6 @@ default DenseVectorFieldMapper.ElementType elementType() {
      * be chosen when initializing a deployment within their service. In this situation, return null.
      * @return the model used to perform inference or null if the model is not defined
      */
+    @Nullable
     String modelId();
 }

@@ -26,6 +26,7 @@
 import org.elasticsearch.indices.SystemIndexDescriptor;
 import org.elasticsearch.inference.InferenceServiceExtension;
 import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.node.PluginComponentBinding;
 import org.elasticsearch.plugins.ActionPlugin;
 import org.elasticsearch.plugins.ExtensiblePlugin;
 import org.elasticsearch.plugins.MapperPlugin;
@@ -84,8 +85,8 @@
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
 import org.elasticsearch.xpack.inference.services.mistral.MistralService;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
-import org.elasticsearch.xpack.inference.telemetry.InferenceAPMStats;
-import org.elasticsearch.xpack.inference.telemetry.StatsMap;
+import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
+import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -196,10 +197,10 @@ public Collection<?> createComponents(PluginServices services) {
         var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry);
         shardBulkInferenceActionFilter.set(actionFilter);
 
-        var statsFactory = new InferenceAPMStats.Factory(services.telemetryProvider().getMeterRegistry());
-        var statsMap = new StatsMap<>(InferenceAPMStats::key, statsFactory::newInferenceRequestAPMCounter);
+        var meterRegistry = services.telemetryProvider().getMeterRegistry();
+        var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry));
 
-        return List.of(modelRegistry, registry, httpClientManager, statsMap);
+        return List.of(modelRegistry, registry, httpClientManager, stats);
     }
 
     @Override

@@ -21,22 +21,26 @@
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
 public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
 
     private final ModelRegistry modelRegistry;
     private final InferenceServiceRegistry serviceRegistry;
+    private final InferenceStats inferenceStats;
 
     @Inject
     public TransportInferenceAction(
         TransportService transportService,
         ActionFilters actionFilters,
         ModelRegistry modelRegistry,
-        InferenceServiceRegistry serviceRegistry
+        InferenceServiceRegistry serviceRegistry,
+        InferenceStats inferenceStats
     ) {
         super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE);
         this.modelRegistry = modelRegistry;
         this.serviceRegistry = serviceRegistry;
+        this.inferenceStats = inferenceStats;
     }
 
     @Override
@@ -76,6 +80,7 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe
                 unparsedModel.settings(),
                 unparsedModel.secrets()
             );
+            inferenceStats.incrementRequestCount(model);
             inferOnService(model, request, service.get(), delegate);
         });
 

@@ -28,7 +28,7 @@ public static CohereEmbeddingsModel of(CohereEmbeddingsModel model, Map<String,
     }
 
     public CohereEmbeddingsModel(
-        String modelId,
+        String inferenceId,
         TaskType taskType,
         String service,
         Map<String, Object> serviceSettings,
@@ -37,7 +37,7 @@ public CohereEmbeddingsModel(
         ConfigurationParseContext context
    ) {
         this(
-            modelId,
+            inferenceId,
             taskType,
             service,
             CohereEmbeddingsServiceSettings.fromMap(serviceSettings, context),

@@ -150,7 +150,7 @@ public OpenAiEmbeddingsServiceSettings(
         @Nullable RateLimitSettings rateLimitSettings
     ) {
         this.uri = uri;
-        this.modelId = modelId;
+        this.modelId = Objects.requireNonNull(modelId);
         this.organizationId = organizationId;
         this.similarity = similarity;
         this.dimensions = dimensions;

@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.telemetry;
+
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.telemetry.metric.LongCounter;
+import org.elasticsearch.telemetry.metric.MeterRegistry;
+
+import java.util.HashMap;
+import java.util.Objects;
+
+public class ApmInferenceStats implements InferenceStats {
+    private final LongCounter inferenceAPMRequestCounter;
+
+    public ApmInferenceStats(LongCounter inferenceAPMRequestCounter) {
+        this.inferenceAPMRequestCounter = Objects.requireNonNull(inferenceAPMRequestCounter);
+    }
+
+    @Override
+    public void incrementRequestCount(Model model) {
+        var service = model.getConfigurations().getService();
+        var taskType = model.getTaskType();
+        var modelId = model.getServiceSettings().modelId();
+
+        var attributes = new HashMap<String, Object>(5);
+        attributes.put("service", service);
+        attributes.put("task_type", taskType.toString());
+        if (modelId != null) {
+            attributes.put("model_id", modelId);
+        }
+
+        inferenceAPMRequestCounter.incrementBy(1, attributes);
+    }
+
+    public static ApmInferenceStats create(MeterRegistry meterRegistry) {
+        return new ApmInferenceStats(
+            meterRegistry.registerLongCounter(
+                "es.inference.requests.count.total",
+                "Inference API request counts for a particular service, task type, model ID",
+                "operations"
+            )
+        );
+    }
+}
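
As a usage note, and only as a sketch: the class is created once from the node's MeterRegistry and then invoked per request, mirroring the wiring in InferencePlugin.createComponents and TransportInferenceAction above. The caller class and method name below are hypothetical, and MeterRegistry.NOOP again stands in for the real registry:

package org.elasticsearch.xpack.inference.telemetry;

import org.elasticsearch.inference.Model;
import org.elasticsearch.telemetry.metric.MeterRegistry;

// Hypothetical caller, not part of this commit.
public class ApmInferenceStatsUsageSketch {
    private final InferenceStats stats = ApmInferenceStats.create(MeterRegistry.NOOP); // assumed stand-in registry

    public void onInferenceRequest(Model model) {
        stats.incrementRequestCount(model); // one delta per request, attributed by service, task_type, and model_id
    }
}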

This file was deleted.

@@ -8,52 +8,14 @@
 package org.elasticsearch.xpack.inference.telemetry;
 
 import org.elasticsearch.inference.Model;
-import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.xpack.core.inference.InferenceRequestStats;
 
-import java.util.Objects;
-import java.util.concurrent.atomic.LongAdder;
+public interface InferenceStats {
 
-public class InferenceStats implements Stats {
-    protected final String service;
-    protected final TaskType taskType;
-    protected final String modelId;
-    protected final LongAdder counter = new LongAdder();
+    /**
+     * Increment the counter for a particular value in a thread safe manner.
+     * @param model the model to increment request count for
+     */
+    void incrementRequestCount(Model model);
 
-    public static String key(Model model) {
-        StringBuilder builder = new StringBuilder();
-        builder.append(model.getConfigurations().getService());
-        builder.append(":");
-        builder.append(model.getTaskType());
-
-        if (model.getServiceSettings().modelId() != null) {
-            builder.append(":");
-            builder.append(model.getServiceSettings().modelId());
-        }
-
-        return builder.toString();
-    }
-
-    public InferenceStats(Model model) {
-        Objects.requireNonNull(model);
-
-        service = model.getConfigurations().getService();
-        taskType = model.getTaskType();
-        modelId = model.getServiceSettings().modelId();
-    }
-
-    @Override
-    public void increment() {
-        counter.increment();
-    }
-
-    @Override
-    public long getCount() {
-        return counter.sum();
-    }
-
-    @Override
-    public InferenceRequestStats toSerializableForm() {
-        return new InferenceRequestStats(service, taskType, modelId, getCount());
-    }
+    InferenceStats NOOP = model -> {};
 }
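
Because InferenceStats is now a single-method interface, callers that do not need metrics can use InferenceStats.NOOP, and tests can substitute a lightweight double. A hypothetical recording implementation (not part of this commit) could look like:

import org.elasticsearch.inference.Model;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

import java.util.ArrayList;
import java.util.List;

// Hypothetical test helper: records the models seen instead of emitting APM metrics.
class RecordingInferenceStats implements InferenceStats {
    final List<Model> recorded = new ArrayList<>();

    @Override
    public void incrementRequestCount(Model model) {
        recorded.add(model);
    }
}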

This file was deleted.

This file was deleted.

(Remaining file diffs not loaded in this view.)
