diff --git a/cli/src/main/java/com/box/l10n/mojito/cli/command/AIRepositoryLocaleOverrideCommand.java b/cli/src/main/java/com/box/l10n/mojito/cli/command/AIRepositoryLocaleOverrideCommand.java new file mode 100644 index 0000000000..448b21b2a8 --- /dev/null +++ b/cli/src/main/java/com/box/l10n/mojito/cli/command/AIRepositoryLocaleOverrideCommand.java @@ -0,0 +1,71 @@ +package com.box.l10n.mojito.cli.command; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; +import com.box.l10n.mojito.rest.client.AIServiceClient; +import com.box.l10n.mojito.rest.entity.AITranslationLocalePromptOverridesRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Scope; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; + +@Component +@Scope("prototype") +@Parameters( + commandNames = {"ai-repository-locale-prompt-override"}, + commandDescription = + "Create/Update/Delete locale translation AI prompt overrides for a given repository") +public class AIRepositoryLocaleOverrideCommand extends Command { + + static Logger logger = LoggerFactory.getLogger(CreateAIPromptCommand.class); + + @Autowired AIServiceClient aiServiceClient; + + @Parameter( + names = {"--repository-name", "-r"}, + required = true, + description = "Repository name") + String repository; + + @Parameter( + names = {"--ai-prompt-id", "-aid"}, + required = true, + description = "AI prompt id") + Long aiPromptId; + + @Parameter( + names = {"--locales", "-l"}, + required = true, + description = "Locale BCP-47 tags provided in a comma separated list") + String locales; + + @Parameter( + names = {"--disabled"}, + required = false, + description = + "Indicates if the locales are disabled for AI translation. Setting to false means AI translation will be skipped for the relevant locales. Default is false") + boolean disabled = false; + + @Parameter( + names = {"--delete"}, + required = false, + description = "Delete the AI prompt overrides for the given locales") + boolean isDelete = false; + + @Override + protected void execute() throws CommandException { + AITranslationLocalePromptOverridesRequest aiTranslationLocalePromptOverridesRequest = + new AITranslationLocalePromptOverridesRequest( + repository, StringUtils.commaDelimitedListToSet(locales), aiPromptId, disabled); + + if (isDelete) { + aiServiceClient.deleteRepositoryLocalePromptOverrides( + aiTranslationLocalePromptOverridesRequest); + } else { + aiServiceClient.createOrUpdateRepositoryLocalePromptOverrides( + aiTranslationLocalePromptOverridesRequest); + } + } +} diff --git a/restclient/src/main/java/com/box/l10n/mojito/rest/client/AIServiceClient.java b/restclient/src/main/java/com/box/l10n/mojito/rest/client/AIServiceClient.java index f71b47fa1c..7f0ea00ac6 100644 --- a/restclient/src/main/java/com/box/l10n/mojito/rest/client/AIServiceClient.java +++ b/restclient/src/main/java/com/box/l10n/mojito/rest/client/AIServiceClient.java @@ -4,9 +4,13 @@ import com.box.l10n.mojito.rest.entity.AICheckResponse; import com.box.l10n.mojito.rest.entity.AIPromptContextMessageCreateRequest; import com.box.l10n.mojito.rest.entity.AIPromptCreateRequest; +import com.box.l10n.mojito.rest.entity.AITranslationLocalePromptOverridesRequest; import com.box.l10n.mojito.rest.entity.OpenAIPrompt; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; import org.springframework.stereotype.Component; @Component @@ -64,4 +68,25 @@ public void addPromptToRepository(Long promptId, String repositoryName, String p null, Void.class); } + + public void createOrUpdateRepositoryLocalePromptOverrides( + AITranslationLocalePromptOverridesRequest aiTranslationLocalePromptOverridesRequest) { + logger.debug("Received request to create or update repository locale prompt overrides"); + authenticatedRestTemplate.postForObject( + getBasePathForEntity() + "/prompts/translation/locale/overrides", + aiTranslationLocalePromptOverridesRequest, + Void.class); + } + + public void deleteRepositoryLocalePromptOverrides( + AITranslationLocalePromptOverridesRequest aiTranslationLocalePromptOverridesRequest) { + logger.debug("Received request to delete repository locale prompt overrides"); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + + HttpEntity entity = + new HttpEntity<>(aiTranslationLocalePromptOverridesRequest, headers); + authenticatedRestTemplate.deleteForObject( + getBasePathForEntity() + "/prompts/translation/locale/overrides", entity, Void.class); + } } diff --git a/restclient/src/main/java/com/box/l10n/mojito/rest/entity/AITranslationLocalePromptOverridesRequest.java b/restclient/src/main/java/com/box/l10n/mojito/rest/entity/AITranslationLocalePromptOverridesRequest.java new file mode 100644 index 0000000000..9b868edbd9 --- /dev/null +++ b/restclient/src/main/java/com/box/l10n/mojito/rest/entity/AITranslationLocalePromptOverridesRequest.java @@ -0,0 +1,51 @@ +package com.box.l10n.mojito.rest.entity; + +import java.util.Set; + +public class AITranslationLocalePromptOverridesRequest { + + private String repositoryName; + private Set locales; + private Long aiPromptId; + private boolean disabled; + + public AITranslationLocalePromptOverridesRequest( + String repositoryName, Set locales, Long aiPromptId, boolean disabled) { + this.repositoryName = repositoryName; + this.locales = locales; + this.aiPromptId = aiPromptId; + this.disabled = disabled; + } + + public String getRepositoryName() { + return repositoryName; + } + + public void setRepositoryName(String repositoryName) { + this.repositoryName = repositoryName; + } + + public Set getLocales() { + return locales; + } + + public void setLocales(Set locales) { + this.locales = locales; + } + + public Long getAiPromptId() { + return aiPromptId; + } + + public void setAiPromptId(Long aiPromptId) { + this.aiPromptId = aiPromptId; + } + + public boolean isDisabled() { + return disabled; + } + + public void setDisabled(boolean disabled) { + this.disabled = disabled; + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIPromptWS.java b/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIPromptWS.java index ef0f238a0f..389f41a767 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIPromptWS.java +++ b/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIPromptWS.java @@ -8,6 +8,8 @@ import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; @@ -91,6 +93,33 @@ public void deletePromptMessage(@PathVariable("context_message_id") Long context promptService.deletePromptContextMessage(contextMessageId); } + @RequestMapping( + value = "/api/ai/prompts/translation/locale/overrides", + method = RequestMethod.POST) + @Timed("AIWS.createOrUpdateRepositoryLocalePromptOverrides") + public ResponseEntity createOrUpdateRepositoryLocalePromptOverrides( + @RequestBody AITranslationLocalePromptOverridesRequest request) { + logger.debug("Received request to create or update repository locale prompt overrides"); + promptService.createOrUpdateRepositoryLocaleTranslationPromptOverrides( + request.getRepositoryName(), + request.getLocales(), + request.getAiPromptId(), + request.isDisabled()); + return new ResponseEntity<>(HttpStatus.CREATED); + } + + @RequestMapping( + value = "/api/ai/prompts/translation/locale/overrides", + method = RequestMethod.DELETE) + @Timed("AIWS.deleteRepositoryLocalePromptOverrides") + public ResponseEntity deleteRepositoryLocalePromptOverrides( + @RequestBody AITranslationLocalePromptOverridesRequest request) { + logger.debug("Received request to delete repository locale prompt overrides"); + promptService.deleteRepositoryLocaleTranslationPromptOverride( + request.getRepositoryName(), request.getLocales()); + return new ResponseEntity<>(HttpStatus.NO_CONTENT); + } + private static AIPrompt buildOpenAIPromptDTO(com.box.l10n.mojito.entity.AIPrompt prompt) { AIPrompt AIPrompt = new AIPrompt(); AIPrompt.setSystemPrompt(prompt.getSystemPrompt()); diff --git a/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AITranslationLocalePromptOverridesRequest.java b/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AITranslationLocalePromptOverridesRequest.java new file mode 100644 index 0000000000..45ec8759ca --- /dev/null +++ b/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AITranslationLocalePromptOverridesRequest.java @@ -0,0 +1,45 @@ +package com.box.l10n.mojito.rest.ai; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import java.util.Set; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class AITranslationLocalePromptOverridesRequest { + + private String repositoryName; + private Set locales; + private Long aiPromptId; + private boolean disabled; + + public String getRepositoryName() { + return repositoryName; + } + + public void setRepositoryName(String repositoryName) { + this.repositoryName = repositoryName; + } + + public Set getLocales() { + return locales; + } + + public void setLocales(Set locales) { + this.locales = locales; + } + + public Long getAiPromptId() { + return aiPromptId; + } + + public void setAiPromptId(Long aiPromptId) { + this.aiPromptId = aiPromptId; + } + + public boolean isDisabled() { + return disabled; + } + + public void setDisabled(boolean disabled) { + this.disabled = disabled; + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMPromptService.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMPromptService.java index 8c072cf0b9..3f8d1ca702 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMPromptService.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMPromptService.java @@ -4,18 +4,25 @@ import com.box.l10n.mojito.entity.AIPrompt; import com.box.l10n.mojito.entity.AIPromptContextMessage; import com.box.l10n.mojito.entity.AIPromptType; +import com.box.l10n.mojito.entity.Locale; import com.box.l10n.mojito.entity.PromptType; import com.box.l10n.mojito.entity.Repository; +import com.box.l10n.mojito.entity.RepositoryLocale; import com.box.l10n.mojito.entity.RepositoryLocaleAIPrompt; import com.box.l10n.mojito.rest.ai.AIException; import com.box.l10n.mojito.rest.ai.AIPromptContextMessageCreateRequest; import com.box.l10n.mojito.rest.ai.AIPromptCreateRequest; import com.box.l10n.mojito.service.ai.openai.OpenAIPromptContextMessageType; +import com.box.l10n.mojito.service.locale.LocaleRepository; import com.box.l10n.mojito.service.repository.RepositoryRepository; import io.micrometer.core.annotation.Timed; import jakarta.transaction.Transactional; import java.time.ZonedDateTime; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -38,17 +45,13 @@ public class LLMPromptService implements PromptService { @Autowired AIPromptContextMessageRepository aiPromptContextMessageRepository; + @Autowired LocaleRepository localeRepository; + @Timed("LLMPromptService.createPrompt") @Transactional public Long createPrompt(AIPromptCreateRequest AIPromptCreateRequest) { - Repository repository = - repositoryRepository.findByName(AIPromptCreateRequest.getRepositoryName()); - - if (repository == null) { - logger.error("Repository not found: {}", AIPromptCreateRequest.getRepositoryName()); - throw new AIException("Repository not found: " + AIPromptCreateRequest.getRepositoryName()); - } + Repository repository = getRepository(AIPromptCreateRequest.getRepositoryName()); AIPromptType aiPromptType = aiPromptTypeRepository.findByName(AIPromptCreateRequest.getPromptType()); @@ -82,12 +85,7 @@ public Long createPrompt(AIPromptCreateRequest AIPromptCreateRequest) { @Timed("LLMPromptService.addPromptToRepository") public void addPromptToRepository(Long promptId, String repositoryName, String promptType) { - Repository repository = repositoryRepository.findByName(repositoryName); - - if (repository == null) { - logger.error("Repository not found: {}", repositoryName); - throw new AIException("Repository not found: " + repositoryName); - } + Repository repository = getRepository(repositoryName); AIPromptType aiPromptType = aiPromptTypeRepository.findByName(promptType); if (aiPromptType == null) { @@ -188,11 +186,103 @@ public void deletePromptContextMessage(Long promptMessageId) { @Timed("LLMPromptService.getAllActivePromptsForRepository") public List getAllActivePromptsForRepository(String repositoryName) { + Repository repository = getRepository(repositoryName); + return aiPromptRepository.findByRepositoryIdAndDeletedFalse(repository.getId()); + } + + @Override + @Timed("LLMPromptService.handleRepositoryLocalePromptOverrides") + @Transactional + public void createOrUpdateRepositoryLocaleTranslationPromptOverrides( + String repositoryName, Set bcp47Tags, Long aiPromptId, boolean disabled) { + Repository repository = getRepository(repositoryName); + AIPrompt aiPrompt = + aiPromptRepository + .findById(aiPromptId) + .orElseThrow(() -> new AIException("Prompt not found: " + aiPromptId)); + + validateProvidedLocaleTags(repositoryName, bcp47Tags, getLocalesForRepository(repository)); + bcp47Tags.forEach( + bcp47Tag -> + createOrUpdateRepositoryLocaleAiPrompt( + bcp47Tag, + getBcp47TagRepositoryLocaleAIPromptOverridesMap(bcp47Tags, repository), + repository, + aiPrompt, + disabled)); + } + + @Override + @Transactional + @Timed("LLMPromptService.deleteRepositoryLocalePromptOverride") + public void deleteRepositoryLocaleTranslationPromptOverride( + String repositoryName, Set bcp47Tags) { + Repository repository = getRepository(repositoryName); + validateProvidedLocaleTags(repositoryName, bcp47Tags, getLocalesForRepository(repository)); + repositoryLocaleAIPromptRepository.deleteAll( + getBcp47TagRepositoryLocaleAIPromptOverridesMap(bcp47Tags, repository).values()); + } + + private Repository getRepository(String repositoryName) { Repository repository = repositoryRepository.findByName(repositoryName); if (repository == null) { logger.error("Repository not found: {}", repositoryName); throw new AIException("Repository not found: " + repositoryName); } - return aiPromptRepository.findByRepositoryIdAndDeletedFalse(repository.getId()); + return repository; + } + + private Set getLocalesForRepository(Repository repository) { + return repository.getRepositoryLocales().stream() + .map(RepositoryLocale::getLocale) + .collect(Collectors.toSet()); + } + + private void validateProvidedLocaleTags( + String repositoryName, Set bcp47Tags, Set locales) { + List invalidTags = + bcp47Tags.stream() + .filter( + bcp47Tag -> + locales.stream().noneMatch(locale -> locale.getBcp47Tag().equals(bcp47Tag))) + .toList(); + if (!invalidTags.isEmpty()) { + throw new AIException( + String.format( + "Repository %s is not configured for the following BCP-47 tags: %s", + repositoryName, invalidTags)); + } + } + + private Map getBcp47TagRepositoryLocaleAIPromptOverridesMap( + Set bcp47Tags, Repository repository) { + return repositoryLocaleAIPromptRepository + .getRepositoryLocaleTranslationPromptOverrides(repository) + .orElseGet(List::of) + .stream() + .filter( + repositoryLocaleAIPrompt -> + bcp47Tags.contains(repositoryLocaleAIPrompt.getLocale().getBcp47Tag())) + .collect( + Collectors.toMap( + repoLocaleAiPrompt -> repoLocaleAiPrompt.getLocale().getBcp47Tag(), + Function.identity())); + } + + private void createOrUpdateRepositoryLocaleAiPrompt( + String bcp47Tag, + Map existingLocaleOverridesMap, + Repository repository, + AIPrompt aiPrompt, + boolean disabled) { + RepositoryLocaleAIPrompt repositoryLocaleAIPrompt = existingLocaleOverridesMap.get(bcp47Tag); + if (repositoryLocaleAIPrompt == null) { + repositoryLocaleAIPrompt = new RepositoryLocaleAIPrompt(); + repositoryLocaleAIPrompt.setRepository(repository); + repositoryLocaleAIPrompt.setLocale(localeRepository.findByBcp47Tag(bcp47Tag)); + } + repositoryLocaleAIPrompt.setAiPrompt(aiPrompt); + repositoryLocaleAIPrompt.setDisabled(disabled); + repositoryLocaleAIPromptRepository.save(repositoryLocaleAIPrompt); } } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/PromptService.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/PromptService.java index 38498f055d..5f6232ee9f 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/ai/PromptService.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/PromptService.java @@ -4,6 +4,7 @@ import com.box.l10n.mojito.rest.ai.AIPromptContextMessageCreateRequest; import com.box.l10n.mojito.rest.ai.AIPromptCreateRequest; import java.util.List; +import java.util.Set; public interface PromptService { @@ -23,4 +24,10 @@ Long createPromptContextMessage( void deletePromptContextMessage(Long promptContextMessageId); void addPromptToRepository(Long promptId, String repositoryName, String promptType); + + void createOrUpdateRepositoryLocaleTranslationPromptOverrides( + String repositoryName, Set bcp47Tags, Long aiPromptId, boolean disabled); + + void deleteRepositoryLocaleTranslationPromptOverride( + String repositoryName, Set bcp47Tags); } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/RepositoryLocaleAIPromptRepository.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/RepositoryLocaleAIPromptRepository.java index c09643ebd8..6bd88052f6 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/ai/RepositoryLocaleAIPromptRepository.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/RepositoryLocaleAIPromptRepository.java @@ -1,7 +1,9 @@ package com.box.l10n.mojito.service.ai; +import com.box.l10n.mojito.entity.Repository; import com.box.l10n.mojito.entity.RepositoryLocaleAIPrompt; import java.util.List; +import java.util.Optional; import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.Query; import org.springframework.data.repository.query.Param; @@ -24,4 +26,12 @@ Long findCountOfActiveRepositoryPromptsByType( + "WHERE rlap.repository.id = :repositoryId AND aip.deleted = false AND aipt.name = :promptType") List getActivePromptsByRepositoryAndPromptType( @Param("repositoryId") Long repositoryId, @Param("promptType") String promptType); + + @Query( + "SELECT rlap from RepositoryLocaleAIPrompt rlap " + + "JOIN rlap.aiPrompt aip " + + "JOIN aip.promptType aipt " + + "WHERE rlap.repository = :repository AND rlap.locale IS NOT NULL AND aipt.name = 'TRANSLATION'") + Optional> getRepositoryLocaleTranslationPromptOverrides( + @Param("repository") Repository repository); } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslateCronJob.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslateCronJob.java index d7abbb9d45..f96a7e8a05 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslateCronJob.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslateCronJob.java @@ -178,13 +178,22 @@ private void translateLocales( executeTranslationPrompt( tmTextUnit, repository, targetLocale, repositoryLocaleAIPrompt)); } else { - logger.debug( - "No active translation prompt found for locale: {}, skipping AI translation.", - targetLocale.getBcp47Tag()); - meterRegistry.counter( - "AITranslateCronJob.translate.noActivePrompt", - Tags.of( - "repository", repository.getName(), "locale", targetLocale.getBcp47Tag())); + if (repositoryLocaleAIPrompt != null && repositoryLocaleAIPrompt.isDisabled()) { + logger.debug( + "AI translation is disabled for locale " + + repositoryLocaleAIPrompt.getLocale().getBcp47Tag() + + " in repository " + + repository.getName() + + ", skipping AI translation."); + } else { + logger.debug( + "No active translation prompt found for locale: {}, skipping AI translation.", + targetLocale.getBcp47Tag()); + meterRegistry.counter( + "AITranslateCronJob.translate.noActivePrompt", + Tags.of( + "repository", repository.getName(), "locale", targetLocale.getBcp47Tag())); + } } } catch (Exception e) { logger.error( diff --git a/webapp/src/test/java/com/box/l10n/mojito/service/ai/LLMPromptServiceTest.java b/webapp/src/test/java/com/box/l10n/mojito/service/ai/LLMPromptServiceTest.java index 717c058e4d..b121f62595 100644 --- a/webapp/src/test/java/com/box/l10n/mojito/service/ai/LLMPromptServiceTest.java +++ b/webapp/src/test/java/com/box/l10n/mojito/service/ai/LLMPromptServiceTest.java @@ -1,5 +1,6 @@ package com.box.l10n.mojito.service.ai; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -11,12 +12,18 @@ import com.box.l10n.mojito.entity.AIPrompt; import com.box.l10n.mojito.entity.AIPromptType; +import com.box.l10n.mojito.entity.Locale; import com.box.l10n.mojito.entity.Repository; +import com.box.l10n.mojito.entity.RepositoryLocale; import com.box.l10n.mojito.entity.RepositoryLocaleAIPrompt; import com.box.l10n.mojito.rest.ai.AIException; import com.box.l10n.mojito.rest.ai.AIPromptCreateRequest; +import com.box.l10n.mojito.service.locale.LocaleRepository; import com.box.l10n.mojito.service.repository.RepositoryRepository; +import java.util.Collection; +import java.util.List; import java.util.Optional; +import java.util.Set; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; @@ -35,8 +42,12 @@ public class LLMPromptServiceTest { @Mock RepositoryRepository repositoryRepository; + @Mock LocaleRepository localeRepository; + @Captor ArgumentCaptor repositoryAIPromptCaptor; + @Captor ArgumentCaptor> repositoryAIPromptCollectionCaptor; + @InjectMocks LLMPromptService LLMPromptService; @BeforeEach @@ -154,4 +165,113 @@ void testAddPromptToRepository() { assertEquals(1L, repositoryAIPromptCaptor.getValue().getAiPrompt().getId()); assertEquals(2L, repositoryAIPromptCaptor.getValue().getRepository().getId()); } + + @Test + void testCreateAiPromptOverride() { + Repository repository = new Repository(); + repository.setId(1L); + AIPrompt aiPrompt = new AIPrompt(); + aiPrompt.setId(1L); + Locale frLocale = new Locale(); + frLocale.setId(1L); + frLocale.setBcp47Tag("fr-FR"); + Locale deLocale = new Locale(); + deLocale.setId(2L); + deLocale.setBcp47Tag("de-DE"); + RepositoryLocale repositoryLocaleFr = new RepositoryLocale(); + repositoryLocaleFr.setLocale(frLocale); + RepositoryLocale repositoryLocaleDe = new RepositoryLocale(); + repositoryLocaleDe.setLocale(deLocale); + repository.setRepositoryLocales(Set.of(repositoryLocaleFr, repositoryLocaleDe)); + when(repositoryRepository.findByName("testRepo")).thenReturn(repository); + when(aiPromptRepository.findById(1L)).thenReturn(Optional.of(aiPrompt)); + when(localeRepository.findByBcp47Tag("fr-FR")).thenReturn(frLocale); + when(localeRepository.findByBcp47Tag("de-DE")).thenReturn(deLocale); + + LLMPromptService.createOrUpdateRepositoryLocaleTranslationPromptOverrides( + "testRepo", Set.of("fr-FR", "de-DE"), 1L, true); + + verify(repositoryLocaleAIPromptRepository, times(2)).save(repositoryAIPromptCaptor.capture()); + assertThat(repositoryAIPromptCaptor.getAllValues()) + .extracting(RepositoryLocaleAIPrompt::getAiPrompt) + .extracting(AIPrompt::getId) + .containsExactlyInAnyOrder(1L, 1L); + assertThat(repositoryAIPromptCaptor.getAllValues()) + .extracting(RepositoryLocaleAIPrompt::getLocale) + .extracting(Locale::getId) + .containsExactlyInAnyOrder(1L, 2L); + assertThat(repositoryAIPromptCaptor.getAllValues()) + .extracting(RepositoryLocaleAIPrompt::isDisabled) + .containsExactlyInAnyOrder(true, true); + } + + @Test + void testDeleteLocaleOverrides() { + Repository repository = new Repository(); + repository.setId(1L); + Locale frLocale = new Locale(); + frLocale.setId(1L); + frLocale.setBcp47Tag("fr-FR"); + Locale deLocale = new Locale(); + deLocale.setId(2L); + deLocale.setBcp47Tag("de-DE"); + Locale itLocale = new Locale(); + itLocale.setId(3L); + itLocale.setBcp47Tag("it-IT"); + RepositoryLocale repositoryLocaleFr = new RepositoryLocale(); + repositoryLocaleFr.setLocale(frLocale); + RepositoryLocale repositoryLocaleDe = new RepositoryLocale(); + repositoryLocaleDe.setLocale(deLocale); + RepositoryLocale repositoryLocaleIt = new RepositoryLocale(); + repositoryLocaleIt.setLocale(itLocale); + repository.setRepositoryLocales( + Set.of(repositoryLocaleFr, repositoryLocaleDe, repositoryLocaleIt)); + RepositoryLocaleAIPrompt repositoryLocaleAIPromptFr = new RepositoryLocaleAIPrompt(); + repositoryLocaleAIPromptFr.setId(1L); + repositoryLocaleAIPromptFr.setRepository(repository); + repositoryLocaleAIPromptFr.setLocale(frLocale); + RepositoryLocaleAIPrompt repositoryLocaleAIPromptDe = new RepositoryLocaleAIPrompt(); + repositoryLocaleAIPromptDe.setId(2L); + repositoryLocaleAIPromptDe.setRepository(repository); + repositoryLocaleAIPromptDe.setLocale(deLocale); + RepositoryLocaleAIPrompt repositoryLocaleAIPromptIt = new RepositoryLocaleAIPrompt(); + repositoryLocaleAIPromptIt.setId(3L); + repositoryLocaleAIPromptIt.setRepository(repository); + repositoryLocaleAIPromptIt.setLocale(itLocale); + + when(repositoryRepository.findByName("testRepo")).thenReturn(repository); + when(repositoryLocaleAIPromptRepository.getRepositoryLocaleTranslationPromptOverrides( + repository)) + .thenReturn( + Optional.of( + List.of( + repositoryLocaleAIPromptDe, + repositoryLocaleAIPromptFr, + repositoryLocaleAIPromptIt))); + + LLMPromptService.deleteRepositoryLocaleTranslationPromptOverride( + "testRepo", Set.of("fr-FR", "de-DE")); + + verify(repositoryLocaleAIPromptRepository, times(1)) + .deleteAll(repositoryAIPromptCollectionCaptor.capture()); + assertThat(repositoryAIPromptCollectionCaptor.getValue()) + .extracting(RepositoryLocaleAIPrompt::getId) + .containsExactlyInAnyOrder(1L, 2L); + } + + @Test + void testAiExceptionThrownIfLocalesNotConfiguredForRepositoryLocalePromptOverride() { + AIPrompt aiPrompt = new AIPrompt(); + aiPrompt.setId(1L); + Repository repository = new Repository(); + repository.setId(2L); + AIPromptType promptType = new AIPromptType(); + promptType.setId(3L); + when(repositoryRepository.findByName("testRepo")).thenReturn(repository); + assertThrows( + AIException.class, + () -> + LLMPromptService.deleteRepositoryLocaleTranslationPromptOverride( + "testRepo", Set.of("fr-FR", "de-DE"))); + } }