Skip to content

Commit

Permalink
Update Elasticsearch spring implementation for 0.34 (#37)
Browse files Browse the repository at this point in the history
This new implementation expects to have an Elasticsearch Rest Client
bean (recommended) or it will create it otherwise.

This update brings:

* Support for the new way to implement the Elasticsearch embedding
store, as we now pass a rest client instead of "just" properties
(deprecated in langchain4j/langchain4j#712)
* The removal of previously needed number of dimensions (deprecated in
langchain4j/langchain4j#712)
* We add a `checkSslCertificates` property which should be only used in
tests.
* We add a `caCertificateAsBase64String` property which can be used in
tests when using a self-signed certificate.

About tests:

* We don't disable security anymore as it's not a good practice. We
should educate users instead.
* We don't wait anymore for 1s for documents to be available, but
instead we call the refresh API to make the documents immediately
visible. So we change the `awaitUntilPersisted()` method to
`awaitUntilPersisted(ApplicationContext)` so we can fetch the
Elasticsearch Rest Client bean.
* We allow testing against a running cluster (local/cloud) and set from
the CLI the URL (`ELASTICSEARCH_URL`), api key
(`ELASTICSEARCH_API_KEY`), username (`ELASTICSEARCH_USERNAME`) and
password (`ELASTICSEARCH_PASSWORD`). Note that when running with a
self-signed certificate, the BASE64 version of the CA should be provided
(`ELASTICSEARCH_CA_CERTIFICATE`).
* We fetch the version of the cluster used to run the tests from the
`pom.xml` file using `elastic.version` property.
* We upgrade testcontainers to 1.20.1
  • Loading branch information
dadoonet authored Sep 2, 2024
1 parent 818b4a5 commit e052f56
Show file tree
Hide file tree
Showing 9 changed files with 355 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.ApplicationContext;

import java.util.List;

Expand Down Expand Up @@ -53,7 +54,7 @@ void should_provide_embedding_store_without_embedding_model() {
String id = embeddingStore.add(embedding, segment);
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(context);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand Down Expand Up @@ -82,7 +83,7 @@ void should_provide_embedding_store_with_embedding_model() {
String id = embeddingStore.add(embedding, segment);
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(context);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand All @@ -95,7 +96,7 @@ void should_provide_embedding_store_with_embedding_model() {
});
}

protected void awaitUntilPersisted() {
protected void awaitUntilPersisted(ApplicationContext context) {

}
}
44 changes: 44 additions & 0 deletions langchian4j-elasticsearch-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@
<name>LangChain4j Spring Boot starter for Elasticsearch</name>
<packaging>jar</packaging>

<properties>
<!-- For tests only -->
<elastic.version>8.14.3</elastic.version>
<!-- You can run the tests using
# Elasticsearch Cloud
mvn verify -DELASTICSEARCH_URL=https://URL.elastic-cloud.com -DELASTICSEARCH_API_KEY=THE_KEY
# Elasticsearch already running locally with a BASE64 encoded SSL CA Certificate
mvn verify -DELASTICSEARCH_URL=https://localhost:9200 -DELASTICSEARCH_PASSWORD=changeme -DELASTICSEARCH_CA_CERTIFICATE=BASE64-CONTENT
# Elasticsearch started by Testcontainers
mvn verify
# Elasticsearch started by Testcontainers with a specific password
mvn verify -DELASTICSEARCH_PASSWORD=changeme
-->
<ELASTICSEARCH_URL />
<ELASTICSEARCH_API_KEY />
<ELASTICSEARCH_USERNAME />
<ELASTICSEARCH_PASSWORD>changeme</ELASTICSEARCH_PASSWORD>
<ELASTICSEARCH_CA_CERTIFICATE />
</properties>

<dependencies>

<dependency>
Expand Down Expand Up @@ -86,4 +106,28 @@
</dependency>
</dependencies>

<build>
<testResources>
<testResource>
<directory>src/test/resources</directory>
<filtering>true</filtering>
</testResource>
</testResources>

<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-failsafe-plugin</artifactId>
<configuration>
<environmentVariables>
<ELASTICSEARCH_URL>${ELASTICSEARCH_URL}</ELASTICSEARCH_URL>
<ELASTICSEARCH_API_KEY>${ELASTICSEARCH_API_KEY}</ELASTICSEARCH_API_KEY>
<ELASTICSEARCH_USERNAME>${ELASTICSEARCH_USERNAME}</ELASTICSEARCH_USERNAME>
<ELASTICSEARCH_PASSWORD>${ELASTICSEARCH_PASSWORD}</ELASTICSEARCH_PASSWORD>
<ELASTICSEARCH_CA_CERTIFICATE>${ELASTICSEARCH_CA_CERTIFICATE}</ELASTICSEARCH_CA_CERTIFICATE>
</environmentVariables>
</configuration>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package dev.langchain4j.store.embedding.elasticsearch.spring;

import org.apache.http.Header;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.message.BasicHeader;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.ssl.SSLContexts;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.ByteArrayInputStream;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;

import static dev.langchain4j.internal.Utils.isNullOrBlank;

/**
 * Helper that builds the Elasticsearch low-level {@link RestClient}.
 * <p>
 * Authentication uses either an API key (sent as an {@code Authorization} header) or
 * basic username/password credentials. The SSL settings are applied in both cases.
 */
final class ElasticsearchClientHelper {

    private static final Logger log = LoggerFactory.getLogger(ElasticsearchClientHelper.class);

    /**
     * Trust manager that accepts every certificate chain. Only meant for tests.
     */
    private static final TrustManager[] TRUST_ALL_CERTS = new TrustManager[]{new X509TrustManager() {
        @Override
        public void checkClientTrusted(X509Certificate[] chain, String authType) {
            // Intentionally empty: every client certificate is accepted.
        }

        @Override
        public void checkServerTrusted(X509Certificate[] chain, String authType) {
            // Intentionally empty: every server certificate is accepted.
        }

        @Override
        public X509Certificate[] getAcceptedIssuers() {
            // The X509TrustManager contract requires a non-null (possibly empty) array.
            return new X509Certificate[0];
        }
    }};

    private ElasticsearchClientHelper() {
        // Static helper, not meant to be instantiated.
    }

    /**
     * Create an Elasticsearch Rest Client. No request is sent to the server by this method.
     *
     * @param address the server url, like <a href="https://localhost:9200">https://localhost:9200</a>
     * @param apiKey the API key if any. If null or blank, we will be using login/password
     * @param username the username to use if apiKey is not set.
     * @param password the password to use if apiKey is not set.
     * @param checkCertificate true if we want to check the certificate. If false, we won't check the certificate (tests only)
     * @param caCertificate the SSL CA certificate to use if provided, otherwise we will use the system ones
     * @return the client instance
     */
    static RestClient getClient(String address, String apiKey, String username, String password,
                                boolean checkCertificate, byte[] caCertificate) {
        log.debug("Trying to connect to {}.", address);

        // Create the low-level client
        RestClientBuilder restClientBuilder = RestClient.builder(HttpHost.create(address));

        final CredentialsProvider credentialsProvider;
        if (!isNullOrBlank(apiKey)) {
            restClientBuilder.setDefaultHeaders(new Header[]{
                    new BasicHeader("Authorization", "Apikey " + apiKey)
            });
            credentialsProvider = null;
        } else {
            CredentialsProvider basicCredentials = new BasicCredentialsProvider();
            basicCredentials.setCredentials(AuthScope.ANY,
                    new UsernamePasswordCredentials(username, password));
            credentialsProvider = basicCredentials;
        }

        // The SSL settings must apply whatever the authentication method is:
        // an API key does not make a self-signed certificate trusted.
        final SSLContext sslContext = getSSLContext(checkCertificate, caCertificate);

        restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> {
            if (credentialsProvider != null) {
                httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
            }
            // Leave the builder's default SSL context alone when there is nothing custom to apply.
            if (sslContext != null) {
                httpClientBuilder.setSSLContext(sslContext);
            }
            return httpClientBuilder;
        });

        return restClientBuilder.build();
    }

    /**
     * @param checkCertificate true if we want to check the certificate. If false, we won't check the certificate (tests only)
     * @param caCertificate the SSL CA certificate to use if provided, otherwise we will use the system ones
     * @return the SSL Context, or null when the default one should be used
     */
    private static SSLContext getSSLContext(boolean checkCertificate, byte[] caCertificate) {
        // If we don't want to check anything (not recommended for production)
        if (!checkCertificate) {
            return createTrustAllCertsContext();
        }

        // If we don't have a self-signed certificate
        if (caCertificate == null) {
            return null;
        }

        // If we have a self-signed certificate we need to check it against the provided Certificate Authority
        return createContextFromCaCert(caCertificate);
    }

    /**
     * Create an SSL Context from a Certificate
     *
     * @param certificate Certificate provided as a byte array
     * @return the SSL Context
     * @throws RuntimeException if the certificate can not be parsed or the context can not be built
     */
    private static SSLContext createContextFromCaCert(byte[] certificate) {
        try {
            CertificateFactory factory = CertificateFactory.getInstance("X.509");
            Certificate trustedCa = factory.generateCertificate(
                    new ByteArrayInputStream(certificate)
            );
            KeyStore trustStore = KeyStore.getInstance("pkcs12");
            trustStore.load(null, null);
            trustStore.setCertificateEntry("ca", trustedCa);
            SSLContextBuilder sslContextBuilder = SSLContexts.custom().loadTrustMaterial(trustStore, null);
            return sslContextBuilder.build();
        } catch (Exception e) {
            throw new RuntimeException("Can not create the SSLContext from the provided CA certificate", e);
        }
    }

    /**
     * Create an SSL Context which trusts every certificate (tests only).
     *
     * @return the SSL Context
     * @throws RuntimeException if the context can not be created
     */
    private static SSLContext createTrustAllCertsContext() {
        try {
            // "TLS" is the standard protocol name; "SSL" is a legacy name.
            SSLContext sslContext = SSLContext.getInstance("TLS");
            sslContext.init(null, TRUST_ALL_CERTS, new SecureRandom());
            return sslContext;
        } catch (NoSuchAlgorithmException | KeyManagementException e) {
            throw new RuntimeException("Can not create the SSLContext", e);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package dev.langchain4j.store.embedding.elasticsearch.spring;

import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.elasticsearch.ElasticsearchEmbeddingStore;
import org.elasticsearch.client.RestClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.lang.Nullable;

import java.util.Base64;
import java.util.Optional;

import static dev.langchain4j.store.embedding.elasticsearch.spring.ElasticsearchEmbeddingStoreProperties.*;
Expand All @@ -18,21 +20,54 @@
@ConditionalOnProperty(prefix = PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true)
public class ElasticsearchEmbeddingStoreAutoConfiguration {

private static final Logger log = LoggerFactory.getLogger(ElasticsearchEmbeddingStoreAutoConfiguration.class);

/**
* Create a bean for the Elasticsearch Rest Client if the application has not defined one yet.
* Ideally, the application should provide its own RestClient bean instead of relying on this one.
* @param properties the properties
* @return a RestClient instance
*/
@Bean
@ConditionalOnMissingBean
public ElasticsearchEmbeddingStore elasticsearchEmbeddingStore(ElasticsearchEmbeddingStoreProperties properties,
@Nullable EmbeddingModel embeddingModel) {
public RestClient elasticsearchRestClient(ElasticsearchEmbeddingStoreProperties properties) {
    String serverUrl = Optional.ofNullable(properties.getServerUrl()).orElse(DEFAULT_SERVER_URL);
    String username = Optional.ofNullable(properties.getUsername()).orElse(DEFAULT_USERNAME);
    boolean checkSslCertificates = Optional.ofNullable(properties.getCheckSslCertificates()).orElse(true);

    if (!checkSslCertificates) {
        log.warn("disabling ssl checks is a bad practice in general and should be done ONLY in the context of tests.");
    }

    // If we have a self-signed certificate, we can provide it.
    // A blank value is treated as "not provided": decoding it would yield an empty
    // byte array, which can not be parsed as a certificate later on.
    byte[] caCertificate = null;
    String caCertificateAsBase64String = properties.getCaCertificateAsBase64String();
    if (caCertificateAsBase64String != null && !caCertificateAsBase64String.trim().isEmpty()) {
        caCertificate = Base64.getDecoder().decode(caCertificateAsBase64String.trim());
    }

    // Never write the API key itself to the logs: only report whether one is configured.
    String apiKey = properties.getApiKey();
    boolean apiKeyProvided = apiKey != null && !apiKey.trim().isEmpty();
    log.debug("create RestClient running at [{}] with api key [{}], username [{}].",
            serverUrl, apiKeyProvided ? "***" : "not set", username);

    return ElasticsearchClientHelper.getClient(
            serverUrl,
            apiKey,
            username,
            properties.getPassword(),
            checkSslCertificates,
            caCertificate);
}

@Bean
@ConditionalOnMissingBean
public ElasticsearchEmbeddingStore elasticsearchEmbeddingStore(ElasticsearchEmbeddingStoreProperties properties,
RestClient elasticsearchRestClient) {
String indexName = Optional.ofNullable(properties.getIndexName()).orElse(DEFAULT_INDEX_NAME);
Integer dimension = Optional.ofNullable(properties.getDimension()).orElseGet(() -> embeddingModel == null ? null : embeddingModel.dimension());

log.debug("create ElasticsearchEmbeddingStore on index [{}].", indexName);
return ElasticsearchEmbeddingStore.builder()
.serverUrl(serverUrl)
.apiKey(properties.getApiKey())
.userName(properties.getUserName())
.password(properties.getPassword())
.restClient(elasticsearchRestClient)
.indexName(indexName)
.dimension(dimension)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
public class ElasticsearchEmbeddingStoreProperties {

static final String PREFIX = "langchain4j.elasticsearch";
static final String DEFAULT_SERVER_URL = "http://localhost:9200";
static final String DEFAULT_SERVER_URL = "https://localhost:9200";
static final String DEFAULT_INDEX_NAME = "langchain4j-index";
static final String DEFAULT_USERNAME = "elastic";

private String serverUrl;
private String apiKey;
private String userName;
private String username;
private String password;
private String indexName;
private Integer dimension;
private Boolean checkSslCertificates;
private String caCertificateAsBase64String;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,38 @@
import dev.langchain4j.store.embedding.elasticsearch.ElasticsearchEmbeddingStore;
import dev.langchain4j.store.embedding.spring.EmbeddingStoreAutoConfigurationIT;
import lombok.SneakyThrows;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RestClient;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.elasticsearch.ElasticsearchContainer;
import org.junit.jupiter.api.BeforeEach;
import org.springframework.context.ApplicationContext;

import java.io.IOException;

import static dev.langchain4j.internal.Utils.randomUUID;
import static org.assertj.core.api.Assertions.assertThat;

class ElasticsearchEmbeddingStoreAutoConfigurationIT extends EmbeddingStoreAutoConfigurationIT {

static ElasticsearchContainer elasticsearch = new ElasticsearchContainer("elasticsearch:8.9.0")
.withEnv("xpack.security.enabled", "false")
.waitingFor(Wait.defaultWaitStrategy());
static ElasticsearchTestContainerHelper elasticsearchTestContainerHelper = new ElasticsearchTestContainerHelper();

String indexName;

@BeforeAll
static void beforeAll() {
elasticsearch.start();
static void startServices() throws IOException {
elasticsearchTestContainerHelper.startServices();
assertThat(elasticsearchTestContainerHelper.restClient).isNotNull();
}

@AfterAll
static void afterAll() {
elasticsearch.stop();
static void stopServices() throws IOException {
elasticsearchTestContainerHelper.stopServices();
}

@BeforeEach
void setIndexName() {
indexName = randomUUID();
}

@Override
Expand All @@ -41,8 +52,10 @@ protected Class<? extends EmbeddingStore<TextSegment>> embeddingStoreClass() {
@Override
protected String[] properties() {
return new String[]{
"langchain4j.elasticsearch.serverUrl=" + elasticsearch.getHttpHostAddress(),
"langchain4j.elasticsearch.indexName=" + randomUUID()
"langchain4j.elasticsearch.serverUrl=https://" + elasticsearchTestContainerHelper.elasticsearch.getHttpHostAddress(),
"langchain4j.elasticsearch.password=changeme",
"langchain4j.elasticsearch.indexName=" + indexName,
"langchain4j.elasticsearch.caCertificateAsBase64String=" + elasticsearchTestContainerHelper.certAsBase64
};
}

Expand All @@ -53,7 +66,8 @@ protected String dimensionPropertyKey() {

@Override
@SneakyThrows
protected void awaitUntilPersisted() {
Thread.sleep(1000);
protected void awaitUntilPersisted(ApplicationContext context) {
RestClient restClient = context.getBean(RestClient.class);
restClient.performRequest(new Request("POST", "/" + indexName + "/_refresh"));
}
}
Loading

0 comments on commit e052f56

Please sign in to comment.