From 751e5cd111a5065eb9f1ef186cea7dc2f58814c1 Mon Sep 17 00:00:00 2001 From: yugyeom <48901587+rladbrua0207@users.noreply.github.com> Date: Fri, 11 Oct 2024 23:04:31 +0900 Subject: [PATCH] =?UTF-8?q?Feat:=20AI=20Client=EA=B8=B0=EB=8A=A5=20?= =?UTF-8?q?=EA=B5=AC=ED=98=84=20(#18)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Feat: AI Client기능 구현 * Test: AI Client 통합테스트 작성 * build: 빌드시 테스트에서 제외해야 할 테스트 태그 추가 * Test: AI Client 단위테스트 작성 * Feat: LLMService에 AI 서버 요청 기능 추가 * Test: LLMService AI_기능_요청 메서드에 클라이언트 단위테스트 코드 추가 * Chore: 불필요한 주석 제거 * Chore: 더미 AI서버 주소 추가 * Chore: 불필요한 impl클래스 제거 및 테스트코드 수정 --- build.gradle | 6 ++ .../domain/AnnotationRepository.java | 4 ++ src/main/java/notai/auth/Auth.java | 5 +- src/main/java/notai/client/ai/AiClient.java | 18 ++++++ .../java/notai/client/ai/AiClientConfig.java | 32 +++++++++++ .../client/ai/request/LlmTaskRequest.java | 11 ++++ .../client/ai/request/SttTaskRequest.java | 8 +++ .../client/ai/response/TaskResponse.java | 9 +++ .../java/notai/common/config/AuthConfig.java | 6 +- .../notai/common/config/AuthInterceptor.java | 3 +- .../java/notai/common/domain/RootEntity.java | 3 +- .../notai/llm/application/LLMService.java | 29 ++++++++-- .../application/result/MemberFindResult.java | 3 +- .../response/MemberFindResponse.java | 3 +- src/main/resources/application-local.yml | 1 + .../client/ai/AiClientIntegrationTest.java | 49 ++++++++++++++++ .../java/notai/client/ai/AiClientTest.java | 56 ++++++++++++++++++ .../notai/llm/application/LLMServiceTest.java | 57 +++++++++++++++---- 18 files changed, 277 insertions(+), 26 deletions(-) create mode 100644 src/main/java/notai/client/ai/AiClient.java create mode 100644 src/main/java/notai/client/ai/AiClientConfig.java create mode 100644 src/main/java/notai/client/ai/request/LlmTaskRequest.java create mode 100644 src/main/java/notai/client/ai/request/SttTaskRequest.java create mode 100644 src/main/java/notai/client/ai/response/TaskResponse.java create mode 100644 src/test/java/notai/client/ai/AiClientIntegrationTest.java create mode 100644 src/test/java/notai/client/ai/AiClientTest.java diff --git a/build.gradle b/build.gradle index 344ed4f..79dab35 100644 --- a/build.gradle +++ b/build.gradle @@ -64,3 +64,9 @@ dependencies { tasks.named('test') { useJUnitPlatform() } + +test { + useJUnitPlatform { + excludeTags 'exclude-test' + } +} diff --git a/src/main/java/notai/annotation/domain/AnnotationRepository.java b/src/main/java/notai/annotation/domain/AnnotationRepository.java index c05ab3c..5ff2bfa 100644 --- a/src/main/java/notai/annotation/domain/AnnotationRepository.java +++ b/src/main/java/notai/annotation/domain/AnnotationRepository.java @@ -17,4 +17,8 @@ default Annotation getById(Long annotationId) { return findById(annotationId) .orElseThrow(() -> new NotFoundException("주석을 찾을 수 없습니다. ID: " + annotationId)); } + + List findByDocumentIdAndPageNumber(Long documentId, Integer pageNumber); + + List findByDocumentId(Long documentId); } diff --git a/src/main/java/notai/auth/Auth.java b/src/main/java/notai/auth/Auth.java index 62c5c00..d3eedd0 100644 --- a/src/main/java/notai/auth/Auth.java +++ b/src/main/java/notai/auth/Auth.java @@ -1,12 +1,13 @@ package notai.auth; import io.swagger.v3.oas.annotations.Hidden; -import static java.lang.annotation.ElementType.PARAMETER; -import static java.lang.annotation.RetentionPolicy.RUNTIME; import java.lang.annotation.Retention; import java.lang.annotation.Target; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + @Hidden @Target(PARAMETER) @Retention(RUNTIME) diff --git a/src/main/java/notai/client/ai/AiClient.java b/src/main/java/notai/client/ai/AiClient.java new file mode 100644 index 0000000..296787a --- /dev/null +++ b/src/main/java/notai/client/ai/AiClient.java @@ -0,0 +1,18 @@ +package notai.client.ai; + +import notai.client.ai.request.LlmTaskRequest; +import notai.client.ai.response.TaskResponse; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestPart; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.service.annotation.PostExchange; + +public interface AiClient { + + @PostExchange(url = "/api/ai/llm") + TaskResponse submitLlmTask(@RequestBody LlmTaskRequest request); + + @PostExchange(url = "/api/ai/stt") + TaskResponse submitSttTask(@RequestPart("audio") MultipartFile audioFile); +} + diff --git a/src/main/java/notai/client/ai/AiClientConfig.java b/src/main/java/notai/client/ai/AiClientConfig.java new file mode 100644 index 0000000..88b17f9 --- /dev/null +++ b/src/main/java/notai/client/ai/AiClientConfig.java @@ -0,0 +1,32 @@ +package notai.client.ai; + +import lombok.extern.slf4j.Slf4j; +import static notai.client.HttpInterfaceUtil.createHttpInterface; +import notai.common.exception.type.ExternalApiException; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpStatusCode; +import org.springframework.web.client.RestClient; + +@Slf4j +@Configuration +public class AiClientConfig { + + @Value("${ai-server-url}") + private String aiServerUrl; + + @Bean + public AiClient aiClient() { + RestClient restClient = + RestClient.builder().baseUrl(aiServerUrl).requestInterceptor((request, body, execution) -> { + request.getHeaders().setContentLength(body.length); // Content-Length 설정 안하면 411 에러 발생 + return execution.execute(request, body); + }).defaultStatusHandler(HttpStatusCode::isError, (request, response) -> { + String responseBody = new String(response.getBody().readAllBytes()); + log.error("Response Status: {}", response.getStatusCode()); + throw new ExternalApiException(responseBody, response.getStatusCode().value()); + }).build(); + return createHttpInterface(restClient, AiClient.class); + } +} diff --git a/src/main/java/notai/client/ai/request/LlmTaskRequest.java b/src/main/java/notai/client/ai/request/LlmTaskRequest.java new file mode 100644 index 0000000..c44ae23 --- /dev/null +++ b/src/main/java/notai/client/ai/request/LlmTaskRequest.java @@ -0,0 +1,11 @@ +package notai.client.ai.request; + +public record LlmTaskRequest( + String ocrText, + String stt, + String keyboardNote +) { + public static LlmTaskRequest of(String ocrText, String stt, String keyboardNote) { + return new LlmTaskRequest(ocrText, stt, keyboardNote); + } +} diff --git a/src/main/java/notai/client/ai/request/SttTaskRequest.java b/src/main/java/notai/client/ai/request/SttTaskRequest.java new file mode 100644 index 0000000..81e8bfa --- /dev/null +++ b/src/main/java/notai/client/ai/request/SttTaskRequest.java @@ -0,0 +1,8 @@ +package notai.client.ai.request; + +import org.springframework.web.multipart.MultipartFile; + +public record SttTaskRequest( + MultipartFile audioFile +) { +} diff --git a/src/main/java/notai/client/ai/response/TaskResponse.java b/src/main/java/notai/client/ai/response/TaskResponse.java new file mode 100644 index 0000000..3105145 --- /dev/null +++ b/src/main/java/notai/client/ai/response/TaskResponse.java @@ -0,0 +1,9 @@ +package notai.client.ai.response; + +import java.util.UUID; + +public record TaskResponse( + UUID taskId, + String taskType +) { +} diff --git a/src/main/java/notai/common/config/AuthConfig.java b/src/main/java/notai/common/config/AuthConfig.java index c0e8bc9..6e4417f 100644 --- a/src/main/java/notai/common/config/AuthConfig.java +++ b/src/main/java/notai/common/config/AuthConfig.java @@ -17,8 +17,10 @@ public class AuthConfig implements WebMvcConfigurer { @Override public void addInterceptors(InterceptorRegistry registry) { - registry.addInterceptor(authInterceptor).addPathPatterns("/api/**").excludePathPatterns( - "/api/members/oauth/login/**").excludePathPatterns("/api/members/token/refresh"); + registry.addInterceptor(authInterceptor) + .addPathPatterns("/api/**") + .excludePathPatterns("/api/members/oauth/login/**") + .excludePathPatterns("/api/members/token/refresh"); } @Override diff --git a/src/main/java/notai/common/config/AuthInterceptor.java b/src/main/java/notai/common/config/AuthInterceptor.java index 327a0f2..87c168c 100644 --- a/src/main/java/notai/common/config/AuthInterceptor.java +++ b/src/main/java/notai/common/config/AuthInterceptor.java @@ -3,10 +3,11 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import notai.auth.TokenService; -import static org.springframework.http.HttpHeaders.AUTHORIZATION; import org.springframework.stereotype.Component; import org.springframework.web.servlet.HandlerInterceptor; +import static org.springframework.http.HttpHeaders.AUTHORIZATION; + @Component public class AuthInterceptor implements HandlerInterceptor { private final TokenService tokenService; diff --git a/src/main/java/notai/common/domain/RootEntity.java b/src/main/java/notai/common/domain/RootEntity.java index 7fcd71b..c8220c2 100644 --- a/src/main/java/notai/common/domain/RootEntity.java +++ b/src/main/java/notai/common/domain/RootEntity.java @@ -2,7 +2,6 @@ import jakarta.persistence.EntityListeners; import jakarta.persistence.MappedSuperclass; -import static lombok.AccessLevel.PROTECTED; import lombok.Getter; import lombok.NoArgsConstructor; import org.springframework.data.annotation.CreatedDate; @@ -13,6 +12,8 @@ import java.time.LocalDateTime; import java.util.Objects; +import static lombok.AccessLevel.PROTECTED; + @Getter @NoArgsConstructor(access = PROTECTED) @EntityListeners(AuditingEntityListener.class) diff --git a/src/main/java/notai/llm/application/LLMService.java b/src/main/java/notai/llm/application/LLMService.java index d8a395f..6986a3b 100644 --- a/src/main/java/notai/llm/application/LLMService.java +++ b/src/main/java/notai/llm/application/LLMService.java @@ -1,6 +1,11 @@ package notai.llm.application; +import static java.util.stream.Collectors.groupingBy; import lombok.RequiredArgsConstructor; +import notai.annotation.domain.Annotation; +import notai.annotation.domain.AnnotationRepository; +import notai.client.ai.AiClient; +import notai.client.ai.request.LlmTaskRequest; import notai.document.domain.Document; import notai.document.domain.DocumentRepository; import notai.llm.application.command.LLMSubmitCommand; @@ -16,7 +21,10 @@ import org.springframework.transaction.annotation.Transactional; import java.time.LocalDateTime; +import java.util.List; +import java.util.Map; import java.util.UUID; +import java.util.stream.Collectors; /** * SummaryService 와 ExamService 는 엔티티와 관련된 로직만 처리하고 @@ -32,12 +40,24 @@ public class LLMService { private final DocumentRepository documentRepository; private final SummaryRepository summaryRepository; private final ProblemRepository problemRepository; + private final AnnotationRepository annotationRepository; + private final AiClient aiClient; public LLMSubmitResult submitTask(LLMSubmitCommand command) { Document foundDocument = documentRepository.getById(command.documentId()); + List annotations = annotationRepository.findByDocumentId(command.documentId()); + + Map> annotationsByPage = + annotations.stream().collect(groupingBy(Annotation::getPageNumber)); command.pages().forEach(pageNumber -> { - UUID taskId = sendRequestToAIServer(); + String annotationContents = annotationsByPage.getOrDefault( + pageNumber, + List.of() + ).stream().map(Annotation::getContent).collect(Collectors.joining(", ")); + + // Todo OCR, STT 결과 전달 + UUID taskId = sendRequestToAIServer("ocrText", "stt", annotationContents); Summary summary = new Summary(foundDocument, pageNumber); Problem problem = new Problem(foundDocument, pageNumber); @@ -64,10 +84,7 @@ public Integer updateSummaryAndProblem(SummaryAndProblemUpdateCommand command) { return command.pageNumber(); } - /** - * 임시 값 반환, 추후 AI 서버에서 작업 단위 UUID 가 반환됨. - */ - private UUID sendRequestToAIServer() { - return UUID.randomUUID(); + private UUID sendRequestToAIServer(String ocrText, String stt, String keyboardNote) { + return aiClient.submitLlmTask(LlmTaskRequest.of(ocrText, stt, keyboardNote)).taskId(); } } diff --git a/src/main/java/notai/member/application/result/MemberFindResult.java b/src/main/java/notai/member/application/result/MemberFindResult.java index 261e83e..13b2169 100644 --- a/src/main/java/notai/member/application/result/MemberFindResult.java +++ b/src/main/java/notai/member/application/result/MemberFindResult.java @@ -3,7 +3,8 @@ import notai.member.domain.Member; public record MemberFindResult( - Long id, String nickname + Long id, + String nickname ) { public static MemberFindResult from(Member member) { return new MemberFindResult(member.getId(), member.getNickname()); diff --git a/src/main/java/notai/member/presentation/response/MemberFindResponse.java b/src/main/java/notai/member/presentation/response/MemberFindResponse.java index a1b8115..ebd525b 100644 --- a/src/main/java/notai/member/presentation/response/MemberFindResponse.java +++ b/src/main/java/notai/member/presentation/response/MemberFindResponse.java @@ -3,7 +3,8 @@ import notai.member.application.result.MemberFindResult; public record MemberFindResponse( - Long id, String nickname + Long id, + String nickname ) { public static MemberFindResponse from(MemberFindResult result) { return new MemberFindResponse(result.id(), result.nickname()); diff --git a/src/main/resources/application-local.yml b/src/main/resources/application-local.yml index ba990f8..da70f6b 100644 --- a/src/main/resources/application-local.yml +++ b/src/main/resources/application-local.yml @@ -37,6 +37,7 @@ server: force: true server-url: http://localhost:8080 +ai-server-url: http://localhost:5000 # 실제 AI 서버주소는 prod에서만 사용 token: # todo production에서 secretKey 변경 secretKey: "ZGQrT0tuZHZkRWRxeXJCamRYMDFKMnBaR2w5WXlyQm9HU2RqZHNha1gycFlkMWpLc0dObw==" diff --git a/src/test/java/notai/client/ai/AiClientIntegrationTest.java b/src/test/java/notai/client/ai/AiClientIntegrationTest.java new file mode 100644 index 0000000..ebd27d3 --- /dev/null +++ b/src/test/java/notai/client/ai/AiClientIntegrationTest.java @@ -0,0 +1,49 @@ +package notai.client.ai; + +import notai.client.ai.request.LlmTaskRequest; +import notai.client.ai.response.TaskResponse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.mock.web.MockMultipartFile; + +@SpringBootTest +@Tag("exclude-test") // 테스트 필요할때 주석 +class AiClientIntegrationTest { + + @Autowired + private AiClient aiClient; + + @Test + void LLM_태스크_제출_통합_테스트() { + // Given + LlmTaskRequest request = LlmTaskRequest.of("OCR 텍스트", "STT 텍스트", "키보드 노트"); + + // When + TaskResponse response = aiClient.submitLlmTask(request); + + // Then + assertNotNull(response); + assertNotNull(response.taskId()); + assertEquals("llm", response.taskType()); + } + + @Test + void STT_태스크_제출_통합_테스트() { + // Given + MockMultipartFile audioFile = new MockMultipartFile( + "audio", "test.mp3", "audio/mpeg", "test audio content".getBytes() + ); + + // When + TaskResponse response = aiClient.submitSttTask(audioFile); + + // Then + assertNotNull(response); + assertNotNull(response.taskId()); + assertEquals("llm", response.taskType()); + } +} diff --git a/src/test/java/notai/client/ai/AiClientTest.java b/src/test/java/notai/client/ai/AiClientTest.java new file mode 100644 index 0000000..2fab457 --- /dev/null +++ b/src/test/java/notai/client/ai/AiClientTest.java @@ -0,0 +1,56 @@ +package notai.client.ai; + +import notai.client.ai.request.LlmTaskRequest; +import notai.client.ai.response.TaskResponse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import static org.mockito.Mockito.*; +import org.mockito.MockitoAnnotations; +import org.springframework.web.multipart.MultipartFile; + +import java.util.UUID; + +class AiClientTest { + + @Mock + private AiClient aiClient; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + } + + @Test + void LLM_테스크_전달_테스트() { + // Given + LlmTaskRequest request = LlmTaskRequest.of("OCR 텍스트", "STT 텍스트", "키보드 노트"); + UUID expectedTaskId = UUID.randomUUID(); + TaskResponse expectedResponse = new TaskResponse(expectedTaskId, "llm"); + when(aiClient.submitLlmTask(request)).thenReturn(expectedResponse); + + // When + TaskResponse response = aiClient.submitLlmTask(request); + + // Then + assertEquals(expectedResponse, response); + verify(aiClient, times(1)).submitLlmTask(request); + } + + @Test + void STT_테스크_전달_테스트() { + // Given + MultipartFile mockAudioFile = mock(MultipartFile.class); + UUID expectedTaskId = UUID.randomUUID(); + TaskResponse expectedResponse = new TaskResponse(expectedTaskId, "stt"); + when(aiClient.submitSttTask(mockAudioFile)).thenReturn(expectedResponse); + + // When + TaskResponse response = aiClient.submitSttTask(mockAudioFile); + + // Then + assertEquals(expectedResponse, response); + verify(aiClient, times(1)).submitSttTask(mockAudioFile); + } +} diff --git a/src/test/java/notai/llm/application/LLMServiceTest.java b/src/test/java/notai/llm/application/LLMServiceTest.java index 4919144..2523c47 100644 --- a/src/test/java/notai/llm/application/LLMServiceTest.java +++ b/src/test/java/notai/llm/application/LLMServiceTest.java @@ -1,32 +1,40 @@ package notai.llm.application; +import notai.annotation.domain.Annotation; +import notai.annotation.domain.AnnotationRepository; +import notai.client.ai.AiClient; +import notai.client.ai.request.LlmTaskRequest; +import notai.client.ai.response.TaskResponse; import notai.common.exception.type.NotFoundException; import notai.document.domain.Document; import notai.document.domain.DocumentRepository; +import notai.folder.domain.Folder; import notai.llm.application.command.LLMSubmitCommand; import notai.llm.application.command.SummaryAndProblemUpdateCommand; import notai.llm.application.result.LLMSubmitResult; import notai.llm.domain.LLM; import notai.llm.domain.LLMRepository; +import notai.member.domain.Member; +import notai.member.domain.OauthId; +import notai.member.domain.OauthProvider; import notai.problem.domain.Problem; import notai.problem.domain.ProblemRepository; import notai.summary.domain.Summary; import notai.summary.domain.SummaryRepository; +import static org.junit.jupiter.api.Assertions.*; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.BDDMockito.given; import org.mockito.InjectMocks; import org.mockito.Mock; +import static org.mockito.Mockito.*; import org.mockito.junit.jupiter.MockitoExtension; import java.util.List; import java.util.UUID; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.BDDMockito.given; -import static org.mockito.Mockito.*; - @ExtendWith(MockitoExtension.class) class LLMServiceTest { @@ -45,6 +53,12 @@ class LLMServiceTest { @Mock private ProblemRepository problemRepository; + @Mock + private AnnotationRepository annotationRepository; + + @Mock + private AiClient aiClient; + @Test void AI_기능_요청시_존재하지_않는_문서ID로_요청한_경우_예외_발생() { // given @@ -62,22 +76,41 @@ class LLMServiceTest { } @Test - void AI_기능_요청() { + void AI_기능_요청_및_AI_클라이언트_테스트() { // given Long documentId = 1L; - List pages = List.of(1, 2, 3); + List pages = List.of(1, 2); LLMSubmitCommand command = new LLMSubmitCommand(documentId, pages); - Document document = mock(Document.class); + + Member member = new Member(new OauthId("12345", OauthProvider.KAKAO), "test@example.com", "TestUser"); + Folder folder = new Folder(member, "TestFolder"); + Document document = new Document(folder, "TestDocument", "http://example.com/test.pdf"); + + List annotations = List.of(new Annotation(document, 1, 10, 20, 100, 50, "Annotation 1"), + new Annotation(document, 1, 30, 40, 80, 60, "Annotation 2"), + new Annotation(document, 2, 50, 60, 120, 70, "Annotation 3") + ); + + UUID taskId = UUID.randomUUID(); + TaskResponse taskResponse = new TaskResponse(taskId, "llm"); given(documentRepository.getById(anyLong())).willReturn(document); + given(annotationRepository.findByDocumentId(anyLong())).willReturn(annotations); + given(aiClient.submitLlmTask(any(LlmTaskRequest.class))).willReturn(taskResponse); given(llmRepository.save(any(LLM.class))).willAnswer(invocation -> invocation.getArgument(0)); + // when LLMSubmitResult result = llmService.submitTask(command); // then - assertAll(() -> verify(documentRepository, times(1)).getById(anyLong()), - () -> verify(llmRepository, times(3)).save(any(LLM.class)) + assertAll(() -> verify(documentRepository, times(1)).getById(documentId), + () -> verify(annotationRepository, times(1)).findByDocumentId(documentId), + () -> verify(aiClient, times(2)).submitLlmTask(any(LlmTaskRequest.class)), + () -> verify(llmRepository, times(2)).save(any(LLM.class)) ); + + verify(aiClient).submitLlmTask(argThat(request -> request.keyboardNote().equals("Annotation 1, Annotation 2"))); + verify(aiClient).submitLlmTask(argThat(request -> request.keyboardNote().equals("Annotation 3"))); } @Test @@ -122,4 +155,4 @@ class LLMServiceTest { () -> assertEquals(pageNumber, resultPageNumber) ); } -} \ No newline at end of file +}