Skip to content

Commit

Permalink
Feat: AI Client기능 구현 (#18)
Browse files Browse the repository at this point in the history
* Feat: AI Client기능 구현

* Test: AI Client 통합테스트 작성

* build: 빌드시 테스트에서 제외해야 할 테스트 태그 추가

* Test: AI Client 단위테스트 작성

* Feat: LLMService에 AI 서버 요청 기능 추가

* Test: LLMService AI_기능_요청 메서드에 클라이언트 단위테스트 코드 추가

* Chore: 불필요한 주석 제거

* Chore: 더미 AI서버 주소 추가

* Chore: 불필요한 impl클래스 제거 및 테스트코드 수정
  • Loading branch information
yugyeom-ghim authored and hynseoj committed Oct 11, 2024
1 parent 780b049 commit 751e5cd
Show file tree
Hide file tree
Showing 18 changed files with 277 additions and 26 deletions.
6 changes: 6 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,9 @@ dependencies {
tasks.named('test') {
useJUnitPlatform()
}

test {
useJUnitPlatform {
excludeTags 'exclude-test'
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,8 @@ default Annotation getById(Long annotationId) {
return findById(annotationId)
.orElseThrow(() -> new NotFoundException("주석을 찾을 수 없습니다. ID: " + annotationId));
}

List<Annotation> findByDocumentIdAndPageNumber(Long documentId, Integer pageNumber);

List<Annotation> findByDocumentId(Long documentId);
}
5 changes: 3 additions & 2 deletions src/main/java/notai/auth/Auth.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package notai.auth;

import io.swagger.v3.oas.annotations.Hidden;
import static java.lang.annotation.ElementType.PARAMETER;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

import java.lang.annotation.Retention;
import java.lang.annotation.Target;

import static java.lang.annotation.ElementType.PARAMETER;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

@Hidden
@Target(PARAMETER)
@Retention(RUNTIME)
Expand Down
18 changes: 18 additions & 0 deletions src/main/java/notai/client/ai/AiClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package notai.client.ai;

import notai.client.ai.request.LlmTaskRequest;
import notai.client.ai.response.TaskResponse;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.service.annotation.PostExchange;

public interface AiClient {

@PostExchange(url = "/api/ai/llm")
TaskResponse submitLlmTask(@RequestBody LlmTaskRequest request);

@PostExchange(url = "/api/ai/stt")
TaskResponse submitSttTask(@RequestPart("audio") MultipartFile audioFile);
}

32 changes: 32 additions & 0 deletions src/main/java/notai/client/ai/AiClientConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package notai.client.ai;

import lombok.extern.slf4j.Slf4j;
import static notai.client.HttpInterfaceUtil.createHttpInterface;
import notai.common.exception.type.ExternalApiException;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpStatusCode;
import org.springframework.web.client.RestClient;

@Slf4j
@Configuration
public class AiClientConfig {

@Value("${ai-server-url}")
private String aiServerUrl;

@Bean
public AiClient aiClient() {
RestClient restClient =
RestClient.builder().baseUrl(aiServerUrl).requestInterceptor((request, body, execution) -> {
request.getHeaders().setContentLength(body.length); // Content-Length 설정 안하면 411 에러 발생
return execution.execute(request, body);
}).defaultStatusHandler(HttpStatusCode::isError, (request, response) -> {
String responseBody = new String(response.getBody().readAllBytes());
log.error("Response Status: {}", response.getStatusCode());
throw new ExternalApiException(responseBody, response.getStatusCode().value());
}).build();
return createHttpInterface(restClient, AiClient.class);
}
}
11 changes: 11 additions & 0 deletions src/main/java/notai/client/ai/request/LlmTaskRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package notai.client.ai.request;

public record LlmTaskRequest(
String ocrText,
String stt,
String keyboardNote
) {
public static LlmTaskRequest of(String ocrText, String stt, String keyboardNote) {
return new LlmTaskRequest(ocrText, stt, keyboardNote);
}
}
8 changes: 8 additions & 0 deletions src/main/java/notai/client/ai/request/SttTaskRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package notai.client.ai.request;

import org.springframework.web.multipart.MultipartFile;

public record SttTaskRequest(
MultipartFile audioFile
) {
}
9 changes: 9 additions & 0 deletions src/main/java/notai/client/ai/response/TaskResponse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package notai.client.ai.response;

import java.util.UUID;

public record TaskResponse(
UUID taskId,
String taskType
) {
}
6 changes: 4 additions & 2 deletions src/main/java/notai/common/config/AuthConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ public class AuthConfig implements WebMvcConfigurer {

@Override
public void addInterceptors(InterceptorRegistry registry) {
registry.addInterceptor(authInterceptor).addPathPatterns("/api/**").excludePathPatterns(
"/api/members/oauth/login/**").excludePathPatterns("/api/members/token/refresh");
registry.addInterceptor(authInterceptor)
.addPathPatterns("/api/**")
.excludePathPatterns("/api/members/oauth/login/**")
.excludePathPatterns("/api/members/token/refresh");
}

@Override
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/notai/common/config/AuthInterceptor.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import notai.auth.TokenService;
import static org.springframework.http.HttpHeaders.AUTHORIZATION;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

import static org.springframework.http.HttpHeaders.AUTHORIZATION;

@Component
public class AuthInterceptor implements HandlerInterceptor {
private final TokenService tokenService;
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/notai/common/domain/RootEntity.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import jakarta.persistence.EntityListeners;
import jakarta.persistence.MappedSuperclass;
import static lombok.AccessLevel.PROTECTED;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.springframework.data.annotation.CreatedDate;
Expand All @@ -13,6 +12,8 @@
import java.time.LocalDateTime;
import java.util.Objects;

import static lombok.AccessLevel.PROTECTED;

@Getter
@NoArgsConstructor(access = PROTECTED)
@EntityListeners(AuditingEntityListener.class)
Expand Down
29 changes: 23 additions & 6 deletions src/main/java/notai/llm/application/LLMService.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package notai.llm.application;

import static java.util.stream.Collectors.groupingBy;
import lombok.RequiredArgsConstructor;
import notai.annotation.domain.Annotation;
import notai.annotation.domain.AnnotationRepository;
import notai.client.ai.AiClient;
import notai.client.ai.request.LlmTaskRequest;
import notai.document.domain.Document;
import notai.document.domain.DocumentRepository;
import notai.llm.application.command.LLMSubmitCommand;
Expand All @@ -16,7 +21,10 @@
import org.springframework.transaction.annotation.Transactional;

import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;

/**
* SummaryService 와 ExamService 는 엔티티와 관련된 로직만 처리하고
Expand All @@ -32,12 +40,24 @@ public class LLMService {
private final DocumentRepository documentRepository;
private final SummaryRepository summaryRepository;
private final ProblemRepository problemRepository;
private final AnnotationRepository annotationRepository;
private final AiClient aiClient;

public LLMSubmitResult submitTask(LLMSubmitCommand command) {
Document foundDocument = documentRepository.getById(command.documentId());
List<Annotation> annotations = annotationRepository.findByDocumentId(command.documentId());

Map<Integer, List<Annotation>> annotationsByPage =
annotations.stream().collect(groupingBy(Annotation::getPageNumber));

command.pages().forEach(pageNumber -> {
UUID taskId = sendRequestToAIServer();
String annotationContents = annotationsByPage.getOrDefault(
pageNumber,
List.of()
).stream().map(Annotation::getContent).collect(Collectors.joining(", "));

// Todo OCR, STT 결과 전달
UUID taskId = sendRequestToAIServer("ocrText", "stt", annotationContents);
Summary summary = new Summary(foundDocument, pageNumber);
Problem problem = new Problem(foundDocument, pageNumber);

Expand All @@ -64,10 +84,7 @@ public Integer updateSummaryAndProblem(SummaryAndProblemUpdateCommand command) {
return command.pageNumber();
}

/**
* 임시 값 반환, 추후 AI 서버에서 작업 단위 UUID 가 반환됨.
*/
private UUID sendRequestToAIServer() {
return UUID.randomUUID();
private UUID sendRequestToAIServer(String ocrText, String stt, String keyboardNote) {
return aiClient.submitLlmTask(LlmTaskRequest.of(ocrText, stt, keyboardNote)).taskId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import notai.member.domain.Member;

public record MemberFindResult(
Long id, String nickname
Long id,
String nickname
) {
public static MemberFindResult from(Member member) {
return new MemberFindResult(member.getId(), member.getNickname());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import notai.member.application.result.MemberFindResult;

public record MemberFindResponse(
Long id, String nickname
Long id,
String nickname
) {
public static MemberFindResponse from(MemberFindResult result) {
return new MemberFindResponse(result.id(), result.nickname());
Expand Down
1 change: 1 addition & 0 deletions src/main/resources/application-local.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ server:
force: true

server-url: http://localhost:8080
ai-server-url: http://localhost:5000 # 실제 AI 서버주소는 prod에서만 사용

token: # todo production에서 secretKey 변경
secretKey: "ZGQrT0tuZHZkRWRxeXJCamRYMDFKMnBaR2w5WXlyQm9HU2RqZHNha1gycFlkMWpLc0dObw=="
Expand Down
49 changes: 49 additions & 0 deletions src/test/java/notai/client/ai/AiClientIntegrationTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package notai.client.ai;

import notai.client.ai.request.LlmTaskRequest;
import notai.client.ai.response.TaskResponse;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.mock.web.MockMultipartFile;

@SpringBootTest
@Tag("exclude-test") // 테스트 필요할때 주석
class AiClientIntegrationTest {

@Autowired
private AiClient aiClient;

@Test
void LLM_태스크_제출_통합_테스트() {
// Given
LlmTaskRequest request = LlmTaskRequest.of("OCR 텍스트", "STT 텍스트", "키보드 노트");

// When
TaskResponse response = aiClient.submitLlmTask(request);

// Then
assertNotNull(response);
assertNotNull(response.taskId());
assertEquals("llm", response.taskType());
}

@Test
void STT_태스크_제출_통합_테스트() {
// Given
MockMultipartFile audioFile = new MockMultipartFile(
"audio", "test.mp3", "audio/mpeg", "test audio content".getBytes()
);

// When
TaskResponse response = aiClient.submitSttTask(audioFile);

// Then
assertNotNull(response);
assertNotNull(response.taskId());
assertEquals("llm", response.taskType());
}
}
56 changes: 56 additions & 0 deletions src/test/java/notai/client/ai/AiClientTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package notai.client.ai;

import notai.client.ai.request.LlmTaskRequest;
import notai.client.ai.response.TaskResponse;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import static org.mockito.Mockito.*;
import org.mockito.MockitoAnnotations;
import org.springframework.web.multipart.MultipartFile;

import java.util.UUID;

class AiClientTest {

@Mock
private AiClient aiClient;

@BeforeEach
void setUp() {
MockitoAnnotations.openMocks(this);
}

@Test
void LLM_테스크_전달_테스트() {
// Given
LlmTaskRequest request = LlmTaskRequest.of("OCR 텍스트", "STT 텍스트", "키보드 노트");
UUID expectedTaskId = UUID.randomUUID();
TaskResponse expectedResponse = new TaskResponse(expectedTaskId, "llm");
when(aiClient.submitLlmTask(request)).thenReturn(expectedResponse);

// When
TaskResponse response = aiClient.submitLlmTask(request);

// Then
assertEquals(expectedResponse, response);
verify(aiClient, times(1)).submitLlmTask(request);
}

@Test
void STT_테스크_전달_테스트() {
// Given
MultipartFile mockAudioFile = mock(MultipartFile.class);
UUID expectedTaskId = UUID.randomUUID();
TaskResponse expectedResponse = new TaskResponse(expectedTaskId, "stt");
when(aiClient.submitSttTask(mockAudioFile)).thenReturn(expectedResponse);

// When
TaskResponse response = aiClient.submitSttTask(mockAudioFile);

// Then
assertEquals(expectedResponse, response);
verify(aiClient, times(1)).submitSttTask(mockAudioFile);
}
}
Loading

0 comments on commit 751e5cd

Please sign in to comment.