Skip to content

Commit

Permalink
Merge pull request #7857 from ita-social-projects/openAI
Browse files Browse the repository at this point in the history
Integrate AI for Predictive Analysis
  • Loading branch information
Maryna-511750 authored Dec 8, 2024
2 parents a81cc19 + 55132b3 commit e47ea89
Show file tree
Hide file tree
Showing 19 changed files with 474 additions and 77 deletions.
4 changes: 4 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@
<artifactId>jsoup</artifactId>
<version>1.18.1</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/java/greencity/config/SecurityConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Excepti
FRIENDS + "/user/{userId}",
"/habit/assign/confirm/{habitAssignId}",
"/database/backup",
"/database/backupFiles")
"/database/backupFiles",
"/ai/**")
.permitAll()
.requestMatchers(HttpMethod.POST,
SUBSCRIPTIONS,
Expand Down
46 changes: 46 additions & 0 deletions core/src/main/java/greencity/controller/AIController.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package greencity.controller;

import greencity.annotations.ApiLocale;
import greencity.annotations.CurrentUser;
import greencity.constant.HttpStatuses;
import greencity.dto.user.UserVO;
import greencity.service.AIService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.ExampleObject;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import lombok.AllArgsConstructor;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.Locale;

@RestController
@RequestMapping("/ai")
@AllArgsConstructor
public class AIController {
private final AIService aiService;

@Operation(summary = "Makes predictions about the environmental impact of the current user "
+ "based on the analysis of their habits and habit duration.")
@ApiResponses(value = {
@ApiResponse(responseCode = "200", description = HttpStatuses.OK),
@ApiResponse(responseCode = "400", description = HttpStatuses.BAD_REQUEST,
content = @Content(examples = @ExampleObject(HttpStatuses.BAD_REQUEST))),
@ApiResponse(responseCode = "401", description = HttpStatuses.UNAUTHORIZED,
content = @Content(examples = @ExampleObject(HttpStatuses.UNAUTHORIZED))),
@ApiResponse(responseCode = "404", description = HttpStatuses.NOT_FOUND,
content = @Content(examples = @ExampleObject(HttpStatuses.NOT_FOUND)))
})
@ApiLocale
@GetMapping("/forecast")
public ResponseEntity<String> forecast(@Parameter(hidden = true) @CurrentUser UserVO userVO,
@Parameter(hidden = true) Locale locale) {
return ResponseEntity.status(HttpStatus.OK)
.body(aiService.getForecast(userVO.getId(), locale.getDisplayLanguage()));
}
}
6 changes: 5 additions & 1 deletion core/src/main/resources/application-dev.properties
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,8 @@ pg.dump.path=pg_dump
google.maps.api.key=${GOOGLE_MAP_API_KEY:AIzaSyCU0ArzZlZ3n0pLq4o9MJy29LPT5DBMk4Y}

# Cron
cron.sendContentToSubscribers=${CRON_SEND_CONTENT_TO_SUBSCRIBERS:0 0 20 * * SAT}
cron.sendContentToSubscribers=${CRON_SEND_CONTENT_TO_SUBSCRIBERS:0 0 20 * * SAT}

# OpenAI
openai.api.key=${OPEN_AI_API_KEY:sk-proj-nCJktK9QTYEogyNa6mGghwftDa4E6kA}
openai.api.url=https://api.openai.com/v1/chat/completions
19 changes: 0 additions & 19 deletions core/src/test/java/greencity/IntegrationTestBase.java

This file was deleted.

55 changes: 55 additions & 0 deletions core/src/test/java/greencity/controller/AIControllerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package greencity.controller;

import greencity.service.AIService;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.springframework.context.annotation.Import;
import org.springframework.data.web.PageableHandlerMethodArgumentResolver;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import java.security.Principal;
import java.util.Locale;
import static greencity.ModelUtils.getPrincipal;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
@Import(AIController.class)
class AIControllerTest {
@Mock
private AIService aiService;
@InjectMocks
private AIController aiController;
private MockMvc mockMvc;
private Principal principal = getPrincipal();

@BeforeEach
void setup() {
this.mockMvc = MockMvcBuilders.standaloneSetup(aiController)
.setCustomArgumentResolvers(new PageableHandlerMethodArgumentResolver())
.build();
}

@Test
void forecast_ReturnsForecast_FromAIService() throws Exception {
Locale testLocale = Locale.forLanguageTag("en");

mockMvc.perform(get("/ai/forecast")
.principal(principal)
.locale(testLocale))
.andExpect(status().isOk());

verify(aiService, times(1)).getForecast(any(), eq(testLocale.getDisplayLanguage()));
}
}
35 changes: 0 additions & 35 deletions core/src/test/java/greencity/repository/LanguageRepoTest.java

This file was deleted.

21 changes: 0 additions & 21 deletions core/src/test/java/greencity/repository/ModelUtils.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,5 @@ public class ErrorMessage {
public static final String INVITATION_ALREADY_EXIST = "Invitation already exist";
public static final String INVALID_DURATION_BETWEEN_START_AND_FINISH = "Invalid duration between start and finish";
public static final String PAGE_NOT_FOUND_MESSAGE = "Requested page %d exceeds total pages %d.";
public static final String OPEN_AI_IS_NOT_RESPONDING = "Could not get a response from OpenAI.";
}
11 changes: 11 additions & 0 deletions service-api/src/main/java/greencity/constant/OpenAIRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package greencity.constant;

import lombok.experimental.UtilityClass;

@UtilityClass
public class OpenAIRequest {
public static final String FORECAST =
"respond in a personalized manner, concisely, and accurately. Use numbers to present approximate data. "
+ "Provide a forecast of this person's impact on the global environment, considering their habits "
+ "over the specified number of days:";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package greencity.dto.habit;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;

@Data
@Builder
@AllArgsConstructor
public class DurationHabitDto {
String description;
Long duration;
}
15 changes: 15 additions & 0 deletions service-api/src/main/java/greencity/service/AIService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package greencity.service;

/**
* Interface for interacting with an AI-based forecasting service.
*/
public interface AIService {
/**
* Retrieves a forecast for a user based on their ID and preferred language.
*
* @param userId The ID of the user for whom the forecast is being requested.
* @param language The preferred language for the forecast response.
* @return The forecast as a string in the specified language.
*/
String getForecast(Long userId, String language);
}
15 changes: 15 additions & 0 deletions service-api/src/main/java/greencity/service/OpenAIService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package greencity.service;

/**
* Interface for interacting with the OpenAI API. The purpose of this interface
* is to send requests to the OpenAI service and receive responses.
*/
public interface OpenAIService {
/**
* Sends a request to the OpenAI API and returns the response as a string.
*
* @param request The request as a string to be sent to the OpenAI API.
* @return The response from the service as a string.
*/
String makeRequest(String request);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package greencity.mapping;

import greencity.dto.habit.DurationHabitDto;
import greencity.entity.HabitAssign;
import greencity.enums.HabitAssignStatus;
import org.modelmapper.AbstractConverter;
import org.springframework.stereotype.Component;
import java.time.LocalDate;
import java.time.temporal.ChronoUnit;

@Component
public class DurationHabitDtoMapper extends AbstractConverter<HabitAssign, DurationHabitDto> {
@Override
protected DurationHabitDto convert(HabitAssign habitAssign) {
return DurationHabitDto.builder()
.duration(habitAssign.getStatus() == HabitAssignStatus.INPROGRESS ? habitAssign.getWorkingDays()
: ChronoUnit.DAYS.between(habitAssign.getCreateDate().toLocalDate(), LocalDate.now()))
.description(habitAssign.getHabit().getHabitTranslations().getFirst().getName())
.build();
}
}
28 changes: 28 additions & 0 deletions service/src/main/java/greencity/service/AIServiceImpl.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package greencity.service;

import greencity.constant.OpenAIRequest;
import greencity.dto.habit.DurationHabitDto;
import greencity.entity.HabitAssign;
import greencity.repository.HabitAssignRepo;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.modelmapper.ModelMapper;
import org.springframework.stereotype.Service;
import java.util.List;

@Slf4j
@Service
@AllArgsConstructor
public class AIServiceImpl implements AIService {
private final OpenAIService openAIService;
private final HabitAssignRepo habitAssignRepo;
private final ModelMapper modelMapper;

@Override
public String getForecast(Long userId, String language) {
List<HabitAssign> habitAssigns = habitAssignRepo.findAllByUserId(userId);
List<DurationHabitDto> durationHabitDtos = habitAssigns.stream()
.map(habitAssign -> modelMapper.map(habitAssign, DurationHabitDto.class)).toList();
return openAIService.makeRequest(language + OpenAIRequest.FORECAST + durationHabitDtos);
}
}
69 changes: 69 additions & 0 deletions service/src/main/java/greencity/service/OpenAIServiceImpl.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package greencity.service;

import greencity.constant.ErrorMessage;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

@Setter
@Slf4j
@Service
public class OpenAIServiceImpl implements OpenAIService {
@Value("${openai.api.key}")
private String apiKey;
@Value("${openai.api.url}")
private String apiUrl;
private final RestTemplate restTemplate;

public OpenAIServiceImpl(RestTemplate restTemplate) {
this.restTemplate = restTemplate;
}

@Override
public String makeRequest(String prompt) {
HttpHeaders headers = new HttpHeaders();
headers.add("Authorization", "Bearer " + apiKey);
headers.add("Content-Type", "application/json");

Map<String, Object> body = new HashMap<>();
body.put("model", "gpt-4o-mini");
List<Map<String, String>> messages = new ArrayList<>();
messages.add(Map.of("role", "user", "content", prompt));
body.put("messages", messages);
body.put("max_tokens", 450);
body.put("temperature", 0.8);
HttpEntity<Map<String, Object>> request = new HttpEntity<>(body, headers);
try {
ResponseEntity<Map<String, Object>> response = restTemplate.exchange(
apiUrl,
HttpMethod.POST,
request,
new ParameterizedTypeReference<>() {
});

return Optional.ofNullable(response)
.map(ResponseEntity::getBody)
.filter(responseBody -> responseBody.containsKey("choices"))
.map(responseBody -> (List<Map<String, Object>>) responseBody.get("choices"))
.filter(choices -> !choices.isEmpty())
.map(choices -> choices.get(0))
.map(choice -> (Map<String, Object>) choice.get("message"))
.map(message -> (String) message.get("content"))
.orElse(ErrorMessage.OPEN_AI_IS_NOT_RESPONDING);
} catch (Exception e) {
return ErrorMessage.OPEN_AI_IS_NOT_RESPONDING;
}
}
}
Loading

0 comments on commit e47ea89

Please sign in to comment.