Skip to content

Commit

Permalink
feat: 실시간 스트리밍 API를 활용하여 STT 구현
Browse files Browse the repository at this point in the history
기존 CLOVA Speech Recognition 에서 CLOVA Speech 실시간 스트리밍 API를 활용하여 실시간 처리를 가능하도록 수정한다.
  • Loading branch information
kanguk01 committed Oct 24, 2024
1 parent 9f16b66 commit dae3889
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 149 deletions.
39 changes: 39 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ plugins {
id 'java'
id 'org.springframework.boot' version '3.3.3'
id 'io.spring.dependency-management' version '1.1.6'
id 'com.google.protobuf' version '0.9.4'

}

group = 'com.splanet'
Expand Down Expand Up @@ -40,6 +42,17 @@ dependencies {
implementation 'jakarta.validation:jakarta.validation-api:3.0.2'
implementation 'org.apache.httpcomponents.client5:httpclient5:5.2.1'
implementation 'org.springframework.boot:spring-boot-starter-websocket'
// gRPC 및 Protocol Buffers 의존성
implementation 'io.grpc:grpc-netty-shaded:1.56.1'
implementation 'io.grpc:grpc-protobuf:1.56.1'
implementation 'io.grpc:grpc-stub:1.56.1'
implementation 'com.google.protobuf:protobuf-java:3.23.4'

// gRPC 관련 필요한 의존성
implementation 'javax.annotation:javax.annotation-api:1.3.2'
implementation 'com.google.code.gson:gson:2.8.9'


compileOnly 'org.projectlombok:lombok'
runtimeOnly 'com.mysql:mysql-connector-j'
annotationProcessor 'org.springframework.boot:spring-boot-configuration-processor'
Expand All @@ -53,3 +66,29 @@ dependencies {
tasks.named('test') {
useJUnitPlatform()
}

protobuf {
protoc {
artifact = 'com.google.protobuf:protoc:3.23.4'
}
plugins {
grpc {
artifact = 'io.grpc:protoc-gen-grpc-java:1.66.0'
}
}
generateProtoTasks {
all().forEach { task ->
task.plugins {
grpc {}
}
}
}
}

sourceSets {
main {
java {
srcDirs 'build/generated/source/proto/main/java', 'build/generated/source/proto/main/grpc'
}
}
}
10 changes: 0 additions & 10 deletions src/main/java/com/splanet/splanet/config/WebSocketConfig.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package com.splanet.splanet.config;

import com.splanet.splanet.core.handler.SpeechWebSocketHandler;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean;

@Configuration
@EnableWebSocket
Expand All @@ -23,12 +21,4 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(speechWebSocketHandler, "/ws/stt")
.setAllowedOrigins("*");
}

@Bean
public ServletServerContainerFactoryBean configureWebSocketContainer() {
ServletServerContainerFactoryBean factory = new ServletServerContainerFactoryBean();
factory.setMaxBinaryMessageBufferSize(256 * 1024); //바이너리 버퍼 크기 지정 16KB
factory.setMaxTextMessageBufferSize(256 * 1024); //텍스트 버퍼 크기 지정 16KB
return factory;
}
}
Original file line number Diff line number Diff line change
@@ -1,78 +1,100 @@
package com.splanet.splanet.core.handler;

import com.splanet.splanet.stt.service.ClovaSpeechService;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.protobuf.ByteString;
import com.nbp.cdncp.nest.grpc.proto.v1.NestResponse;
import com.splanet.splanet.stt.service.ClovaSpeechGrpcService;
import io.grpc.stub.StreamObserver;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.BinaryWebSocketHandler;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Map;

@Component
public class SpeechWebSocketHandler extends BinaryWebSocketHandler {

private final ClovaSpeechService clovaSpeechService;
private List<byte[]> audioDataBuffer = new ArrayList<>();
private static final int MINIMUM_AUDIO_SIZE = 64000; // 최소 데이터 크기를 96KB로 설정 (약 3초 분량)
private final ClovaSpeechGrpcService clovaSpeechGrpcService;
private final Map<String, StreamObserver<ByteString>> clientObservers = new ConcurrentHashMap<>();

public SpeechWebSocketHandler(ClovaSpeechService clovaSpeechService) {
this.clovaSpeechService = clovaSpeechService;
public SpeechWebSocketHandler(ClovaSpeechGrpcService clovaSpeechGrpcService) {
this.clovaSpeechGrpcService = clovaSpeechGrpcService;
}

@Override
protected synchronized void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
session.setBinaryMessageSizeLimit(256 * 1024); // 메시지 크기 제한을 256KB로 설정
byte[] audioData = message.getPayload().array();
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
// 세션이 열릴 때마다 새로운 gRPC 스트림을 생성
StreamObserver<NestResponse> responseObserver = new StreamObserver<NestResponse>() {
@Override
public void onNext(NestResponse value) {
// 서버로부터 받은 응답 처리
try {
String contents = value.getContents(); // JSON 문자열

// 오디오 데이터를 버퍼에 추가
audioDataBuffer.add(audioData);
// JSON 파싱
JsonParser parser = new JsonParser();
JsonObject jsonObject = parser.parse(contents).getAsJsonObject();

// 누적된 오디오 데이터 크기 계산
int totalSize = audioDataBuffer.stream().mapToInt(arr -> arr.length).sum();
if (jsonObject.has("transcription")) {
JsonObject transcription = jsonObject.getAsJsonObject("transcription");
String text = transcription.get("text").getAsString();
// 클라이언트로 text 필드만 전송
session.sendMessage(new TextMessage(text));
}
} catch (Exception e) {
e.printStackTrace();
}
}

// 현재 누적된 데이터 크기를 로그로 출력
System.out.println("현재 누적된 데이터 크기: " + totalSize + " bytes");
@Override
public void onError(Throwable t) {
t.printStackTrace();
try {
session.sendMessage(new TextMessage("오류 발생: " + t.getMessage()));
} catch (IOException e) {
e.printStackTrace();
}
}

// 오디오 데이터가 충분히 쌓였을 때만 CLOVA API로 전송
if (totalSize >= MINIMUM_AUDIO_SIZE) {
byte[] fullAudioData = mergeAudioData();
try {
// CLOVA API로 전송
String transcript = clovaSpeechService.recognize(fullAudioData);
session.sendMessage(new TextMessage(transcript));
// 인식에 성공했으므로 버퍼를 초기화
audioDataBuffer.clear();
System.out.println("인식 성공: 버퍼를 초기화합니다.");
} catch (Exception e) {
e.printStackTrace();
// STT007 오류 발생 시 버퍼를 유지하고 데이터 수집 계속
if (e.getMessage().contains("STT007")) {
System.err.println("오류 발생: STT007 - 데이터가 너무 작습니다. 더 많은 데이터를 수집 중...");
// 버퍼를 유지하여 다음 데이터를 기다립니다.
} else {
// 다른 오류 발생 시 버퍼를 초기화하고 오류 메시지 전송
audioDataBuffer.clear();
session.sendMessage(new TextMessage("오류 발생: " + e.getMessage()));
System.err.println("오류 발생: " + e.getMessage() + " - 버퍼를 초기화합니다.");
@Override
public void onCompleted() {
// 스트림 완료 처리
try {
session.close();
} catch (IOException e) {
e.printStackTrace();
}
}
} else {
// 아직 데이터가 충분하지 않으면 아무 작업도 하지 않음
System.out.println("데이터가 아직 충분하지 않음");
};

// 오디오 데이터를 전송할 StreamObserver 생성
StreamObserver<ByteString> requestObserver = clovaSpeechGrpcService.recognize(responseObserver);
clientObservers.put(session.getId(), requestObserver);
}

@Override
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
// 클라이언트로부터 받은 오디오 데이터를 gRPC 서비스로 전달
StreamObserver<ByteString> requestObserver = clientObservers.get(session.getId());
if (requestObserver != null) {
byte[] audioData = message.getPayload().array();
ByteString audioChunk = ByteString.copyFrom(audioData);
requestObserver.onNext(audioChunk);
}
}

// 누적된 오디오 데이터를 병합하는 메서드
private byte[] mergeAudioData() {
int totalLength = audioDataBuffer.stream().mapToInt(arr -> arr.length).sum();
byte[] mergedData = new byte[totalLength];
int currentIndex = 0;
for (byte[] data : audioDataBuffer) {
System.arraycopy(data, 0, mergedData, currentIndex, data.length);
currentIndex += data.length;
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
// 세션이 종료되면 gRPC 스트림도 종료
StreamObserver<ByteString> requestObserver = clientObservers.remove(session.getId());
if (requestObserver != null) {
requestObserver.onCompleted();
}
return mergedData;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
@Getter
@Setter
@Configuration
@ConfigurationProperties(prefix = "clova")
@ConfigurationProperties(prefix = "clova.speech")
public class ClovaProperties {
private String clientId;
private String clientSecret;
private String url;
private String language;
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package com.splanet.splanet.stt.service;

import com.google.protobuf.ByteString;
import com.nbp.cdncp.nest.grpc.proto.v1.*;
import com.splanet.splanet.core.properties.ClovaProperties;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import org.springframework.stereotype.Service;

@Service
public class ClovaSpeechGrpcService implements ClovaSpeechService {

private final NestServiceGrpc.NestServiceStub nestServiceStub;
private final ClovaProperties clovaProperties;

public ClovaSpeechGrpcService(ClovaProperties clovaProperties) {
this.clovaProperties = clovaProperties;

// gRPC 채널 생성
ManagedChannel channel = NettyChannelBuilder
.forAddress("clovaspeech-gw.ncloud.com", 50051)
.useTransportSecurity()
.build();

// Stub 생성 및 인증 정보 설정
NestServiceGrpc.NestServiceStub stub = NestServiceGrpc.newStub(channel);
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER), "Bearer " + clovaProperties.getClientSecret());
this.nestServiceStub = MetadataUtils.attachHeaders(stub, metadata);
}

@Override
public StreamObserver<ByteString> recognize(StreamObserver<NestResponse> responseObserver) {
StreamObserver<NestRequest> requestObserver = nestServiceStub.recognize(responseObserver);

// Config 메시지 전송
requestObserver.onNext(createConfigRequest(clovaProperties.getLanguage()));

return new StreamObserver<ByteString>() {
private int sequenceId = 0;

@Override
public void onNext(ByteString audioChunk) {
NestRequest dataRequest = createDataRequest(audioChunk, sequenceId, false);
requestObserver.onNext(dataRequest);
sequenceId++;
}

@Override
public void onError(Throwable t) {
t.printStackTrace();
requestObserver.onError(t);
}

@Override
public void onCompleted() {
requestObserver.onCompleted();
}
};
}

// Config 설정
private NestRequest createConfigRequest(String language) {
NestConfig config = NestConfig.newBuilder()
.setConfig("{\"transcription\":{\"language\":\"" + language + "\"}}")
.build();

return NestRequest.newBuilder()
.setType(RequestType.CONFIG)
.setConfig(config)
.build();
}

// 데이터 구성
private NestRequest createDataRequest(ByteString audioChunk, int sequenceId, boolean epFlag) {
NestData data = NestData.newBuilder()
.setChunk(audioChunk)
.setExtraContents("{\"seqId\":" + sequenceId + ",\"epFlag\":" + epFlag + "}")
.build();

return NestRequest.newBuilder()
.setType(RequestType.DATA)
.setData(data)
.build();
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.splanet.splanet.stt.service;

import org.springframework.web.multipart.MultipartFile;
import com.google.protobuf.ByteString;
import com.nbp.cdncp.nest.grpc.proto.v1.NestResponse;
import io.grpc.stub.StreamObserver;

public interface ClovaSpeechService {
String recognize(byte[] audioBytes);
StreamObserver<ByteString> recognize(StreamObserver<NestResponse> responseObserver);
}
Loading

0 comments on commit dae3889

Please sign in to comment.