Skip to content

Commit

Permalink
Feature/31 add vector search (#32)
Browse files Browse the repository at this point in the history
* add vector search and tests
* add api request limit per user
* add tests for limiter
* add validation and tests for title and description max length
  • Loading branch information
jonashonecker authored Jul 4, 2024
1 parent 3cd1320 commit 45a7e53
Show file tree
Hide file tree
Showing 36 changed files with 969 additions and 161 deletions.
26 changes: 25 additions & 1 deletion backend/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.3.0</version>
<version>3.2.7</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.github.jonashonecker</groupId>
Expand All @@ -17,6 +17,7 @@
<sonar.organization>jonashonecker</sonar.organization>
<sonar.host.url>https://sonarcloud.io</sonar.host.url>
<java.version>22</java.version>
<mongodb.version>4.11.0</mongodb.version>
</properties>
<dependencies>
<dependency>
Expand Down Expand Up @@ -60,6 +61,29 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-validation</artifactId>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver</artifactId>
<version>4.12.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.giffing.bucket4j.spring.boot.starter</groupId>
<artifactId>bucket4j-spring-boot-starter</artifactId>
<version>0.12.7</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-cache</artifactId>
</dependency>
<dependency>
<groupId>javax.cache</groupId>
<artifactId>cache-api</artifactId>
</dependency>
<dependency>
<groupId>com.hazelcast</groupId>
<artifactId>hazelcast</artifactId>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.cache.annotation.EnableCaching;

@SpringBootApplication
@EnableCaching
public class BackendApplication {

public static void main(String[] args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.github.jonashonecker.backend.error.domain.ApiErrorResponse;
import com.github.jonashonecker.backend.ticket.exception.NoSuchTicketException;
import com.github.jonashonecker.backend.user.exception.UserAuthenticationException;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.bind.annotation.ExceptionHandler;
Expand All @@ -26,15 +25,6 @@ public ApiErrorResponse handleValidationExceptions(MethodArgumentNotValidExcepti
);
}

@ExceptionHandler(UserAuthenticationException.class)
@ResponseStatus(HttpStatus.UNAUTHORIZED)
public ApiErrorResponse handleUserAuthenticationException(UserAuthenticationException ex) {

return new ApiErrorResponse(
"There is an issue with your user login. Please contact support."
);
}

@ExceptionHandler(NoSuchTicketException.class)
@ResponseStatus(HttpStatus.BAD_REQUEST)
public ApiErrorResponse handleNoSuchTicketException(NoSuchTicketException ex) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.github.jonashonecker.backend.ticket;

import com.github.jonashonecker.backend.ticket.domain.embedding.EmbeddingRequestDTO;
import com.github.jonashonecker.backend.ticket.domain.embedding.EmbeddingResponseDTO;
import com.github.jonashonecker.backend.ticket.domain.ticket.TicketWithTitleAndDescription;
import com.github.jonashonecker.backend.ticket.exception.EmbeddingFailedException;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestClient;

import java.util.List;

/**
 * Obtains embedding vectors for ticket content and free-form search text by POSTing
 * an {@link EmbeddingRequestDTO} to the configured embedding endpoint.
 */
@Service
public class EmbeddingService {

    private final RestClient restClient;

    /**
     * Builds the REST client shared by all embedding requests.
     *
     * @param apiKey  value of {@code app.openai-api-key}; sent as a Bearer token in the
     *                {@code Authorization} header of every request
     * @param baseUrl value of {@code app.openai-embedding-baseUrl}; used as the client's base URL
     */
    public EmbeddingService(
            @Value("${app.openai-api-key}") String apiKey,
            @Value("${app.openai-embedding-baseUrl}") String baseUrl
    ) {
        this.restClient = RestClient.builder()
                .baseUrl(baseUrl)
                .defaultHeader("Authorization", "Bearer " + apiKey)
                .build();
    }

    /**
     * Returns the embedding vector for a ticket. The embedded text is the title wrapped in
     * {@code <h1>} tags followed directly by the description.
     *
     * @param ticket source of the title and description to embed
     * @return the embedding values returned by the endpoint
     * @throws EmbeddingFailedException if the endpoint returns no usable response body
     */
    public List<Double> getEmbeddingVectorForTicket(TicketWithTitleAndDescription ticket) {
        return getEmbeddingVector("<h1>" + ticket.title() + "</h1>" + ticket.description());
    }

    /**
     * Returns the embedding vector for a search query; the text is embedded unchanged.
     *
     * @param searchText text to embed
     * @return the embedding values returned by the endpoint
     * @throws EmbeddingFailedException if the endpoint returns no usable response body
     */
    public List<Double> getEmbeddingVectorForSearchText(String searchText) {
        return getEmbeddingVector(searchText);
    }

    /**
     * Posts {@code text} to the embedding endpoint and flattens the values of every
     * response entry whose {@code object} field equals {@code "embedding"}.
     */
    private List<Double> getEmbeddingVector(String text) {
        EmbeddingResponseDTO response = this.restClient.post()
                .body(new EmbeddingRequestDTO(text))
                .retrieve()
                .body(EmbeddingResponseDTO.class);

        // A body without a "data" field deserialises to a DTO whose data() is null; report
        // that as an embedding failure instead of letting a NullPointerException escape.
        if (response == null || response.data() == null) {
            throw new EmbeddingFailedException("Failed to retrieve embedding vector for ticket");
        }

        return response.data()
                .stream()
                .filter(o -> "embedding".equals(o.object()))
                .flatMap(o -> o.embedding().stream())
                .toList();
    }
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package com.github.jonashonecker.backend.ticket;

import com.github.jonashonecker.backend.ticket.domain.NewTicketDTO;
import com.github.jonashonecker.backend.ticket.domain.Ticket;
import com.github.jonashonecker.backend.ticket.domain.UpdateTicketDTO;
import com.github.jonashonecker.backend.ticket.domain.ticket.*;
import jakarta.validation.Valid;
import org.springframework.web.bind.annotation.*;

import java.util.List;

import static com.github.jonashonecker.backend.ticket.utils.Utils.ticketToDtoMapper;

@RestController
@RequestMapping("/api/ticket")
public class TicketController {
Expand All @@ -18,18 +18,23 @@ public TicketController(TicketService ticketService) {
}

@GetMapping
public List<Ticket> getAllTickets() {
return ticketService.getAllTickets();
public List<TicketResponseDTO> getAllTickets(@RequestParam(required = false) String searchText) {
List<Ticket> tickets = searchText == null ? ticketService.getAllTickets() : ticketService.getTicketsByVectorSearch(searchText);
return tickets.stream().map(ticketToDtoMapper).toList();
}

@PostMapping
public Ticket createTicket(@Valid @RequestBody NewTicketDTO newTicketDTO) {
return ticketService.createTicket(newTicketDTO);
public TicketResponseDTO createTicket(@Valid @RequestBody TicketRequestDTO ticketRequestDTO) {
Ticket ticket = ticketService.createTicket(ticketRequestDTO);
return ticketToDtoMapper.apply(ticket);
}

@PutMapping
public Ticket updateTicket(@Valid @RequestBody UpdateTicketDTO updateTicketDTO) {
return ticketService.updateTicket(updateTicketDTO);
@PutMapping("/{id}")
public TicketResponseDTO updateTicket(
@Valid @RequestBody TicketRequestDTO ticketRequestDTO,
@PathVariable String id) {
Ticket ticket = ticketService.updateTicket(ticketRequestDTO, id);
return ticketToDtoMapper.apply(ticket);
}

@DeleteMapping("/{id}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.github.jonashonecker.backend.ticket;

import com.github.jonashonecker.backend.ticket.domain.Ticket;
import com.github.jonashonecker.backend.ticket.domain.ticket.Ticket;
import org.springframework.data.mongodb.repository.MongoRepository;
import org.springframework.stereotype.Repository;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.github.jonashonecker.backend.ticket;

import com.github.jonashonecker.backend.ticket.domain.ticket.Ticket;
import com.mongodb.client.AggregateIterable;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import org.bson.conversions.Bson;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Repository;

import java.util.ArrayList;
import java.util.List;

import static com.mongodb.client.model.Aggregates.vectorSearch;
import static com.mongodb.client.model.search.SearchPath.fieldPath;

/**
 * Repository that queries the ticket collection with a single {@code vectorSearch}
 * aggregation stage, using the {@link MongoClient} directly.
 */
@Repository
public class TicketRepositoryVectorSearch {

    /** Name of the collection that is searched. */
    private static final String COLLECTION_NAME = "tickets";

    /** Value passed as {@code numCandidates} to the vector search stage. */
    private static final int NUM_CANDIDATES = 100;

    /** Value passed as {@code limit} to the vector search stage. */
    private static final int RESULT_LIMIT = 5;

    private final MongoClient mongoClient;

    @Value("${spring.data.mongodb.database}")
    private String databaseName;

    @Value("${vectorSearch.index.name}")
    private String indexName;

    @Value("${vectorSearch.index.fieldPath}")
    private String fieldPath;

    /**
     * @param mongoClient client used to reach the configured database
     */
    public TicketRepositoryVectorSearch(MongoClient mongoClient) {
        this.mongoClient = mongoClient;
    }

    /** Resolves the ticket collection of the configured database, decoded as {@link Ticket}. */
    private MongoCollection<Ticket> getTicketCollection() {
        return mongoClient
                .getDatabase(databaseName)
                .getCollection(COLLECTION_NAME, Ticket.class);
    }

    /**
     * Runs a vector search against the configured index and field path.
     *
     * @param embedding query vector to search with
     * @return the tickets produced by the aggregation, collected into a new list
     */
    public List<Ticket> findTicketsByVector(List<Double> embedding) {
        Bson vectorSearchStage = vectorSearch(
                fieldPath(fieldPath),
                embedding,
                indexName,
                NUM_CANDIDATES,
                RESULT_LIMIT);

        List<Ticket> matches = new ArrayList<>();
        getTicketCollection()
                .aggregate(List.of(vectorSearchStage), Ticket.class)
                .into(matches);
        return matches;
    }
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package com.github.jonashonecker.backend.ticket;

import com.github.jonashonecker.backend.ticket.domain.NewTicketDTO;
import com.github.jonashonecker.backend.ticket.domain.Status;
import com.github.jonashonecker.backend.ticket.domain.Ticket;
import com.github.jonashonecker.backend.ticket.domain.UpdateTicketDTO;
import com.github.jonashonecker.backend.ticket.domain.ticket.*;
import com.github.jonashonecker.backend.ticket.exception.NoSuchTicketException;
import com.github.jonashonecker.backend.user.UserService;
import org.springframework.stereotype.Service;
Expand All @@ -13,13 +10,22 @@
@Service
public class TicketService {
private final TicketRepository ticketRepository;
private final TicketRepositoryVectorSearch ticketRepositoryVectorSearch;
private final IdService idService;
private final UserService userService;
private final EmbeddingService embeddingService;

public TicketService(TicketRepository ticketRepository, IdService idService, UserService userService) {
public TicketService(TicketRepository ticketRepository, IdService idService, UserService userService, EmbeddingService embeddingService, TicketRepositoryVectorSearch ticketRepositoryVectorSearch) {
this.ticketRepository = ticketRepository;
this.idService = idService;
this.userService = userService;
this.embeddingService = embeddingService;
this.ticketRepositoryVectorSearch = ticketRepositoryVectorSearch;
}

public List<Ticket> getTicketsByVectorSearch(String searchText) {
List<Double> embedding = embeddingService.getEmbeddingVectorForSearchText(searchText);
return ticketRepositoryVectorSearch.findTicketsByVector(embedding);
}

public List<Ticket> getAllTickets() {
Expand All @@ -30,29 +36,31 @@ public Ticket getTicketById(String id) {
return ticketRepository.findById(id).orElseThrow(() -> new NoSuchTicketException("Could not find ticket with id: " + id));
}

public Ticket createTicket(NewTicketDTO newTicketDTO) {
public Ticket createTicket(TicketRequestDTO ticketRequestDTO) {
String defaultProjectName = "Default Project";
Status defaultStatus = Status.OPEN;
return ticketRepository.insert(new Ticket(
idService.getUUID(),
defaultProjectName,
newTicketDTO.title(),
newTicketDTO.description(),
ticketRequestDTO.title(),
ticketRequestDTO.description(),
defaultStatus,
userService.getCurrentUser()
));
userService.getCurrentUser(),
embeddingService.getEmbeddingVectorForTicket(ticketRequestDTO))
);
}

public Ticket updateTicket(UpdateTicketDTO updateTicketDTO) {
Ticket existingTicket = getTicketById(updateTicketDTO.id());
public Ticket updateTicket(TicketRequestDTO ticketRequestDTO, String id) {
Ticket existingTicket = getTicketById(id);
return ticketRepository.save(new Ticket(
existingTicket.id(),
existingTicket.projectName(),
updateTicketDTO.title(),
updateTicketDTO.description(),
ticketRequestDTO.title(),
ticketRequestDTO.description(),
existingTicket.status(),
existingTicket.author()
));
existingTicket.author(),
embeddingService.getEmbeddingVectorForTicket(ticketRequestDTO))
);
}

public void deleteTicket(String id) {
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.github.jonashonecker.backend.ticket.domain.embedding;

/**
 * Request body sent to the embedding endpoint.
 *
 * @param input text to be embedded
 * @param model name of the embedding model to use
 */
public record EmbeddingRequestDTO(String input, String model) {

    /** Model applied when the caller supplies only the input text. */
    private static final String DEFAULT_MODEL = "text-embedding-3-large";

    /**
     * Creates a request for {@code input} using the default model.
     *
     * @param input text to be embedded
     */
    public EmbeddingRequestDTO(String input) {
        this(input, DEFAULT_MODEL);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.github.jonashonecker.backend.ticket.domain.embedding;

import java.util.List;

/**
 * Response body returned by the embedding endpoint.
 *
 * @param data entries contained in the response
 */
public record EmbeddingResponseDTO(List<EmbeddingResponseDataDTO> data) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.github.jonashonecker.backend.ticket.domain.embedding;

import java.util.List;

/**
 * One entry of an embedding response.
 *
 * @param object    type tag of the entry as reported by the endpoint
 * @param embedding embedding values carried by the entry
 */
public record EmbeddingResponseDataDTO(String object, List<Double> embedding) {
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.github.jonashonecker.backend.ticket.domain;
package com.github.jonashonecker.backend.ticket.domain.ticket;

public enum Status {
OPEN,
Expand Down
Loading

0 comments on commit 45a7e53

Please sign in to comment.