Skip to content

Commit

Permalink
Refactor: STT 도메인 연관관계 재구성 및 STT 페이지별 상태체크 기능 구현 (#38)
Browse files Browse the repository at this point in the history
* Refactor: STT 도메인 연관관계 재구성

- Stt와 Recording 간의 직접 연관관계를 제거
- SttTask를 중심으로 하는 새로운 구조로 변경
  - Recording <- SttTask -> Stt
- 관련 DB 스키마 변경
  - stt 테이블: recording_id 제거, stt_task_id 추가
  - stt_task 테이블: stt_id 제거, recording_id 추가

* Feat: Stt 페이지별 상태체크 기능 구현

* Test: STT 도메인 구조 변경에 따른 테스트 수정
  • Loading branch information
yugyeom-ghim authored Nov 14, 2024
1 parent 15577cf commit eac0b88
Show file tree
Hide file tree
Showing 19 changed files with 231 additions and 142 deletions.
15 changes: 8 additions & 7 deletions src/main/java/notai/stt/application/SttService.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,21 @@ public class SttService {
* AI 서버로부터 받은 STT 결과를 처리하여 페이지별 STT 데이터를 생성하고 저장합니다.
* 1. STT 테스크와 관련 엔티티들을 조회
* 2. 음성 인식된 단어들을 페이지와 매칭
* 3. 매칭 결과를 저장하고 테스크를 완료 처리
* 3. 테스크를 완료처리하고 매칭 결과 저장
*/
public void updateSttResult(UpdateSttResultCommand command) {
SttTask sttTask = sttTaskRepository.getById(command.taskId());
Stt stt = sttTask.getStt();
Recording recording = stt.getRecording();
Recording recording = sttTask.getRecording();

List<PageRecording> pageRecordings = pageRecordingRepository.findAllByRecordingIdOrderByStartTime(recording.getId());
List<PageRecording> pageRecordings =
pageRecordingRepository.findAllByRecordingIdOrderByStartTime(recording.getId());

SttPageMatchedDto matchedResult = stt.matchWordsWithPages(command.words(), pageRecordings);
List<Stt> pageMatchedSttResults = Stt.createFromMatchedResult(recording, matchedResult);
sttRepository.saveAll(pageMatchedSttResults);
SttPageMatchedDto matchedResult = Stt.matchWordsWithPages(command.words(), pageRecordings);
List<Stt> pageMatchedSttResults = Stt.createFromMatchedResult(sttTask, matchedResult);

sttTask.complete();
sttTaskRepository.save(sttTask);

sttRepository.saveAll(pageMatchedSttResults);
}
}
10 changes: 3 additions & 7 deletions src/main/java/notai/stt/application/SttTaskService.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import notai.recording.domain.Recording;
import notai.recording.domain.RecordingRepository;
import notai.stt.application.command.SttRequestCommand;
import notai.stt.domain.Stt;
import notai.stt.domain.SttRepository;
import notai.sttTask.domain.SttTask;
import notai.sttTask.domain.SttTaskRepository;
Expand All @@ -38,14 +37,14 @@ public void submitSttTask(SttRequestCommand command) {

try {
byte[] audioBytes = Files.readAllBytes(audioFile.toPath());

ByteArrayResource resource = new ByteArrayResource(audioBytes) {
@Override
public String getFilename() {
return audioFile.getName();
}
};

TaskResponse response = aiClient.submitSttTask(resource);
createAndSaveSttTask(recording, response);
} catch (IOException e) {
Expand All @@ -62,10 +61,7 @@ private File validateAudioFile(String audioFilePath) {
}

private void createAndSaveSttTask(Recording recording, TaskResponse response) {
Stt stt = new Stt(recording);
sttRepository.save(stt);

SttTask sttTask = new SttTask(response.taskId(), stt, TaskStatus.PENDING);
SttTask sttTask = new SttTask(response.taskId(), TaskStatus.PENDING, recording);
sttTaskRepository.save(sttTask);
}
}
60 changes: 34 additions & 26 deletions src/main/java/notai/stt/domain/Stt.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import lombok.NoArgsConstructor;
import notai.common.domain.RootEntity;
import notai.pageRecording.domain.PageRecording;
import notai.recording.domain.Recording;
import notai.stt.application.command.UpdateSttResultCommand;
import notai.stt.application.dto.SttPageMatchedDto;
import notai.sttTask.domain.SttTask;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -34,8 +34,8 @@ public class Stt extends RootEntity<Long> {

@NotNull
@ManyToOne(fetch = LAZY)
@JoinColumn(name = "recording_id")
private Recording recording;
@JoinColumn(name = "stt_task_id")
private SttTask sttTask;

private Integer pageNumber;

Expand All @@ -46,12 +46,12 @@ public class Stt extends RootEntity<Long> {

private Integer endTime;

public Stt(Recording recording) {
this.recording = recording;
public Stt(SttTask sttTask) {
this.sttTask = sttTask;
}

public Stt(Recording recording, Integer pageNumber, String content, Integer startTime, Integer endTime) {
this.recording = recording;
public Stt(SttTask sttTask, Integer pageNumber, String content, Integer startTime, Integer endTime) {
this.sttTask = sttTask;
this.pageNumber = pageNumber;
this.content = content;
this.startTime = startTime;
Expand All @@ -62,20 +62,20 @@ public Stt(Recording recording, Integer pageNumber, String content, Integer star
* 페이지별 STT 결과로부터 새로운 STT 엔티티를 생성합니다.
* 시작/종료 시간은 페이지 내 첫/마지막 단어의 시간으로 설정합니다.
*/
public static Stt createFromPageContent(Recording recording, SttPageMatchedDto.PageMatchedContent content) {
public static Stt createFromPageContent(SttTask sttTask, SttPageMatchedDto.PageMatchedContent content) {
return new Stt(
recording,
content.pageNumber(),
content.content(),
content.words().get(0).startTime(),
content.words().get(content.words().size() - 1).endTime()
sttTask,
content.pageNumber(),
content.content(),
content.words().get(0).startTime(),
content.words().get(content.words().size() - 1).endTime()
);
}

/**
* 음성 인식된 단어들을 페이지 기록과 매칭하여 페이지별 STT 결과를 생성합니다.
*/
public SttPageMatchedDto matchWordsWithPages(
public static SttPageMatchedDto matchWordsWithPages(
List<UpdateSttResultCommand.Word> words,
List<PageRecording> pageRecordings
) {
Expand All @@ -85,8 +85,8 @@ public SttPageMatchedDto matchWordsWithPages(

// 페이지 번호 순으로 자동 정렬됨
Map<Integer, List<SttPageMatchedDto.PageMatchedWord>> pageWordMap = new TreeMap<>();
int wordIndex = 0;
PageRecording lastPage = pageRecordings.get(pageRecordings.size() - 1);
int wordIndex = 0;
PageRecording lastPage = pageRecordings.get(pageRecordings.size() - 1);

// 각 페이지별로 매칭되는 단어들을 찾아 처리
for (PageRecording page : pageRecordings) {
Expand All @@ -97,7 +97,7 @@ public SttPageMatchedDto matchWordsWithPages(
// 현재 페이지의 시간 범위에 속하는 단어들을 찾아 매칭
while (wordIndex < words.size()) {
UpdateSttResultCommand.Word word = words.get(wordIndex);

// 페이지 시작 시간보다 이른 단어는 건너뛰기
if (word.start() + TIME_THRESHOLD < pageStart) {
wordIndex++;
Expand All @@ -124,14 +124,22 @@ public SttPageMatchedDto matchWordsWithPages(
}

// 페이지별로 단어들을 하나의 텍스트로 합치는 과정
List<SttPageMatchedDto.PageMatchedContent> pageContents = pageWordMap.entrySet().stream()
List<SttPageMatchedDto.PageMatchedContent> pageContents = pageWordMap
.entrySet().stream()
.map(entry -> {
Integer pageNumber = entry.getKey();
List<SttPageMatchedDto.PageMatchedWord> pageWords = entry.getValue();
String combinedContent = pageWords.stream()
.map(SttPageMatchedDto.PageMatchedWord::word)
.collect(Collectors.joining(" "));
return new SttPageMatchedDto.PageMatchedContent(pageNumber, combinedContent, pageWords);
List<SttPageMatchedDto.PageMatchedWord>
pageWords = entry.getValue();
String combinedContent =
pageWords.stream()
.map(SttPageMatchedDto.PageMatchedWord::word)
.collect(Collectors.joining(
" "));
return new SttPageMatchedDto.PageMatchedContent(
pageNumber,
combinedContent,
pageWords
);
})
.toList();

Expand All @@ -141,9 +149,9 @@ public SttPageMatchedDto matchWordsWithPages(
/**
* 페이지 매칭 결과로부터 STT 엔티티들을 생성하고 저장합니다.
*/
public static List<Stt> createFromMatchedResult(Recording recording, SttPageMatchedDto matchedResult) {
public static List<Stt> createFromMatchedResult(SttTask sttTask, SttPageMatchedDto matchedResult) {
return matchedResult.pageContents().stream()
.map(content -> createFromPageContent(recording, content))
.toList();
.map(content -> createFromPageContent(sttTask, content))
.toList();
}
}
2 changes: 1 addition & 1 deletion src/main/java/notai/stt/presentation/SttController.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@RequestMapping("/api/stt")
@RequiredArgsConstructor
public class SttController {
private final SttQueryService sttQueryService;
private final SttQueryService sttQueryService;

@GetMapping("/documents/{documentId}/pages/{pageNumber}")
public SttPageResponse getSttByPage(
Expand Down
31 changes: 16 additions & 15 deletions src/main/java/notai/stt/presentation/response/SttPageResponse.java
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
package notai.stt.presentation.response;

import notai.stt.domain.Stt;

import java.util.List;

public record SttPageResponse(
Integer pageNumber,
List<SttContent> contents
Integer pageNumber,
List<SttContent> contents
) {
public static SttPageResponse of(Integer pageNumber, List<Stt> sttList) {
List<SttContent> contents = sttList.stream()
.map(SttContent::from)
.toList();
return new SttPageResponse(pageNumber, contents);
}

public record SttContent(
String content,
Integer startTime,
Integer endTime
String content,
Integer startTime,
Integer endTime
) {
public static SttContent from(Stt stt) {
return new SttContent(
stt.getContent(),
stt.getStartTime(),
stt.getEndTime()
stt.getContent(),
stt.getStartTime(),
stt.getEndTime()
);
}
}

public static SttPageResponse of(Integer pageNumber, List<Stt> sttList) {
List<SttContent> contents = sttList.stream()
.map(SttContent::from)
.toList();
return new SttPageResponse(pageNumber, contents);
}
}
1 change: 1 addition & 0 deletions src/main/java/notai/stt/query/SttQueryRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@

/**
 * Lookup methods for {@link Stt} entities, keyed by the id of the owning document.
 */
public interface SttQueryRepository {
/** Returns the STT entries of the given document that belong to the given page number. */
List<Stt> findAllByDocumentIdAndPageNumber(Long documentId, Integer pageNumber);

/** Returns all STT entries of the given document, regardless of page. */
List<Stt> findAllByDocumentId(Long documentId);
}
13 changes: 7 additions & 6 deletions src/main/java/notai/stt/query/SttQueryRepositoryImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import com.querydsl.jpa.impl.JPAQueryFactory;
import lombok.RequiredArgsConstructor;
import notai.stt.domain.Stt;
import static notai.stt.domain.QStt.stt;
import notai.stt.domain.Stt;
import static notai.sttTask.domain.QSttTask.sttTask;

import java.util.List;

Expand All @@ -15,18 +16,18 @@ public class SttQueryRepositoryImpl implements SttQueryRepository {
public List<Stt> findAllByDocumentIdAndPageNumber(Long documentId, Integer pageNumber) {
return queryFactory
.selectFrom(stt)
.join(stt.recording).fetchJoin()
.where(stt.recording.document.id.eq(documentId)
.and(stt.pageNumber.eq(pageNumber)))
.join(stt.sttTask, sttTask).fetchJoin()
.where(stt.sttTask.recording.document.id.eq(documentId)
.and(stt.pageNumber.eq(pageNumber)))
.fetch();
}

@Override
public List<Stt> findAllByDocumentId(Long documentId) {
return queryFactory
.selectFrom(stt)
.join(stt.recording).fetchJoin()
.where(stt.recording.document.id.eq(documentId))
.join(stt.sttTask, sttTask).fetchJoin()
.where(stt.sttTask.recording.document.id.eq(documentId))
.fetch();
}
}
32 changes: 26 additions & 6 deletions src/main/java/notai/sttTask/application/SttTaskQueryService.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
import notai.member.domain.MemberRepository;
import notai.stt.domain.Stt;
import notai.stt.domain.SttRepository;
import notai.sttTask.application.command.SttTaskPageStatusCommand;
import notai.sttTask.application.result.SttTaskOverallStatusResult;
import notai.sttTask.application.result.SttTaskPageStatusResult;
import notai.sttTask.domain.SttTask;
import notai.sttTask.domain.SttTaskRepository;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.util.Collections;
import java.util.List;
import java.util.Objects;

@Transactional(readOnly = true)
@RequiredArgsConstructor
Expand All @@ -25,8 +28,8 @@ public class SttTaskQueryService {

private final DocumentRepository documentRepository;
private final MemberRepository memberRepository;
private final SttTaskRepository sttTaskRepository;
private final SttRepository sttRepository;
private final SttTaskRepository sttTaskRepository;

public SttTaskOverallStatusResult fetchOverallStatus(Long memberId, Long documentId) {
Document foundDocument = documentRepository.getById(documentId);
Expand All @@ -38,16 +41,33 @@ public SttTaskOverallStatusResult fetchOverallStatus(Long memberId, Long documen
if (sttResults.isEmpty()) {
return SttTaskOverallStatusResult.of(documentId, NOT_REQUESTED, 0, 0);
}
List<TaskStatus> taskStatuses =
sttTaskRepository.findAllBySttIn(sttResults).stream().map(SttTask::getStatus).toList();

List<TaskStatus> taskStatuses = sttResults.stream()
.map(stt -> stt.getSttTask().getStatus())
.distinct()
.toList();

int totalPages = taskStatuses.size();
int totalPages = foundDocument.getTotalPages();
int completedPages = Collections.frequency(taskStatuses, COMPLETED);

if (totalPages == completedPages) {
return SttTaskOverallStatusResult.of(documentId, COMPLETED, totalPages, completedPages);
if (taskStatuses.size() == 1 && taskStatuses.get(0) == COMPLETED) {
return SttTaskOverallStatusResult.of(documentId, COMPLETED, totalPages, totalPages);
}
return SttTaskOverallStatusResult.of(documentId, IN_PROGRESS, totalPages, completedPages);
}

/**
 * Looks up the STT task status for a single page of a document.
 *
 * <p>The document and the member are loaded first; ownership and the page number are then
 * validated on the document before the status is queried.
 *
 * @param memberId id of the member making the request
 * @param command  carries the document id and the page number to look up
 * @return the requested page number paired with the status returned by the repository, or
 *         {@code IN_PROGRESS} when the repository returns {@code null}
 */
public SttTaskPageStatusResult fetchPageStatus(Long memberId, SttTaskPageStatusCommand command) {
    Long documentId = command.documentId();
    Integer pageNumber = command.pageNumber();

    Document document = documentRepository.getById(documentId);
    Member requester = memberRepository.getById(memberId);
    document.validateOwner(requester);
    document.validatePageNumber(pageNumber);

    TaskStatus found = sttTaskRepository.getTaskStatusByDocumentIdAndPageNumber(documentId, pageNumber);

    // The status of a per-page STT result can only be judged by whether the result exists,
    // so a missing status is uniformly reported as IN_PROGRESS.
    TaskStatus effective = (found != null) ? found : IN_PROGRESS;
    return SttTaskPageStatusResult.of(pageNumber, effective);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package notai.sttTask.application.command;

/**
 * Input for a per-page STT task status query.
 *
 * @param documentId id of the document the page belongs to
 * @param pageNumber number of the page whose status is requested
 */
public record SttTaskPageStatusCommand(Long documentId, Integer pageNumber) {

    /**
     * Static factory that forwards both arguments, unchanged, to the canonical constructor.
     *
     * @param documentId id of the document the page belongs to
     * @param pageNumber number of the page whose status is requested
     * @return a new command holding the two values
     */
    public static SttTaskPageStatusCommand of(Long documentId, Integer pageNumber) {
        return new SttTaskPageStatusCommand(documentId, pageNumber);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package notai.sttTask.application.result;

import notai.llm.domain.TaskStatus;

/**
 * Result of a per-page STT task status query.
 *
 * @param pageNumber number of the page the status refers to
 * @param status     task status reported for that page
 */
public record SttTaskPageStatusResult(Integer pageNumber, TaskStatus status) {

    /**
     * Static factory that forwards both arguments, unchanged, to the canonical constructor.
     *
     * @param pageNumber number of the page the status refers to
     * @param status     task status reported for that page
     * @return a new result holding the two values
     */
    public static SttTaskPageStatusResult of(Integer pageNumber, TaskStatus status) {
        return new SttTaskPageStatusResult(pageNumber, status);
    }
}
Loading

0 comments on commit eac0b88

Please sign in to comment.