Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate AI for Predictive Analysis #7857

Merged
merged 10 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@
<artifactId>jsoup</artifactId>
<version>1.18.1</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/java/greencity/config/SecurityConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Excepti
FRIENDS + "/user/{userId}",
"/habit/assign/confirm/{habitAssignId}",
"/database/backup",
"/database/backupFiles")
"/database/backupFiles",
"/ai/**")
.permitAll()
.requestMatchers(HttpMethod.POST,
SUBSCRIPTIONS,
Expand Down
46 changes: 46 additions & 0 deletions core/src/main/java/greencity/controller/AIController.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package greencity.controller;

import greencity.annotations.ApiLocale;
import greencity.annotations.CurrentUser;
import greencity.constant.HttpStatuses;
import greencity.dto.user.UserVO;
import greencity.service.AIService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.ExampleObject;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import lombok.AllArgsConstructor;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.Locale;

@RestController
@RequestMapping("/ai")
@AllArgsConstructor
public class AIController {
private final AIService aiService;

@Operation(summary = "Makes predictions about the environmental impact of the current user "
+ "based on the analysis of their habits and habit duration.")
@ApiResponses(value = {
@ApiResponse(responseCode = "200", description = HttpStatuses.OK),
@ApiResponse(responseCode = "400", description = HttpStatuses.BAD_REQUEST,
content = @Content(examples = @ExampleObject(HttpStatuses.BAD_REQUEST))),
@ApiResponse(responseCode = "401", description = HttpStatuses.UNAUTHORIZED,
content = @Content(examples = @ExampleObject(HttpStatuses.UNAUTHORIZED))),
@ApiResponse(responseCode = "404", description = HttpStatuses.NOT_FOUND,
content = @Content(examples = @ExampleObject(HttpStatuses.NOT_FOUND)))
})
@ApiLocale
@GetMapping("/forecast")
public ResponseEntity<String> forecast(@Parameter(hidden = true) @CurrentUser UserVO userVO,
@Parameter(hidden = true) Locale locale) {
return ResponseEntity.status(HttpStatus.OK)
.body(aiService.getForecast(userVO.getId(), locale.getDisplayLanguage()));
}
}
6 changes: 5 additions & 1 deletion core/src/main/resources/application-dev.properties
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,8 @@ pg.dump.path=pg_dump
google.maps.api.key=${GOOGLE_MAP_API_KEY:AIzaSyCU0ArzZlZ3n0pLq4o9MJy29LPT5DBMk4Y}

# Cron
cron.sendContentToSubscribers=${CRON_SEND_CONTENT_TO_SUBSCRIBERS:0 0 20 * * SAT}
cron.sendContentToSubscribers=${CRON_SEND_CONTENT_TO_SUBSCRIBERS:0 0 20 * * SAT}

# OpenAI
openai.api.key=${OPEN_AI_API_KEY:sk-proj-nCJktK9QTYEogyNa6mGghwftDa4E6kA}
openai.api.url=https://api.openai.com/v1/chat/completions
19 changes: 0 additions & 19 deletions core/src/test/java/greencity/IntegrationTestBase.java

This file was deleted.

55 changes: 55 additions & 0 deletions core/src/test/java/greencity/controller/AIControllerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package greencity.controller;

import greencity.service.AIService;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.springframework.context.annotation.Import;
import org.springframework.data.web.PageableHandlerMethodArgumentResolver;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import java.security.Principal;
import java.util.Locale;
import static greencity.ModelUtils.getPrincipal;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
@Import(AIController.class)
class AIControllerTest {
@Mock
private AIService aiService;
@InjectMocks
private AIController aiController;
private MockMvc mockMvc;
private Principal principal = getPrincipal();

@BeforeEach
void setup() {
this.mockMvc = MockMvcBuilders.standaloneSetup(aiController)
.setCustomArgumentResolvers(new PageableHandlerMethodArgumentResolver())
.build();
}

@Test
void forecast_ReturnsForecast_FromAIService() throws Exception {
Locale testLocale = Locale.forLanguageTag("en");

mockMvc.perform(get("/ai/forecast")
.principal(principal)
.locale(testLocale))
.andExpect(status().isOk());

verify(aiService, times(1)).getForecast(any(), eq(testLocale.getDisplayLanguage()));
}
}
35 changes: 0 additions & 35 deletions core/src/test/java/greencity/repository/LanguageRepoTest.java

This file was deleted.

21 changes: 0 additions & 21 deletions core/src/test/java/greencity/repository/ModelUtils.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,5 @@ public class ErrorMessage {
public static final String INVITATION_ALREADY_EXIST = "Invitation already exist";
public static final String INVALID_DURATION_BETWEEN_START_AND_FINISH = "Invalid duration between start and finish";
public static final String PAGE_NOT_FOUND_MESSAGE = "Requested page %d exceeds total pages %d.";
public static final String OPEN_AI_IS_NOT_RESPONDING = "Could not get a response from OpenAI.";
}
11 changes: 11 additions & 0 deletions service-api/src/main/java/greencity/constant/OpenAIRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package greencity.constant;

import lombok.experimental.UtilityClass;

@UtilityClass
public class OpenAIRequest {
public static final String FORECAST =
"respond in a personalized manner, concisely, and accurately. Use numbers to present approximate data. "
+ "Provide a forecast of this person's impact on the global environment, considering their habits "
+ "over the specified number of days:";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package greencity.dto.habit;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;

@Data
@Builder
@AllArgsConstructor
public class DurationHabitDto {
String description;
Long duration;
}
15 changes: 15 additions & 0 deletions service-api/src/main/java/greencity/service/AIService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package greencity.service;

/**
* Interface for interacting with an AI-based forecasting service.
*/
public interface AIService {
/**
* Retrieves a forecast for a user based on their ID and preferred language.
*
* @param userId The ID of the user for whom the forecast is being requested.
* @param language The preferred language for the forecast response.
* @return The forecast as a string in the specified language.
*/
String getForecast(Long userId, String language);
}
15 changes: 15 additions & 0 deletions service-api/src/main/java/greencity/service/OpenAIService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package greencity.service;

/**
* Interface for interacting with the OpenAI API. The purpose of this interface
* is to send requests to the OpenAI service and receive responses.
*/
public interface OpenAIService {
/**
* Sends a request to the OpenAI API and returns the response as a string.
*
* @param request The request as a string to be sent to the OpenAI API.
* @return The response from the service as a string.
*/
String makeRequest(String request);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package greencity.mapping;

import greencity.dto.habit.DurationHabitDto;
import greencity.entity.HabitAssign;
import greencity.enums.HabitAssignStatus;
import org.modelmapper.AbstractConverter;
import org.springframework.stereotype.Component;
import java.time.LocalDate;
import java.time.temporal.ChronoUnit;

@Component
public class DurationHabitDtoMapper extends AbstractConverter<HabitAssign, DurationHabitDto> {
@Override
protected DurationHabitDto convert(HabitAssign habitAssign) {
return DurationHabitDto.builder()
.duration(habitAssign.getStatus() == HabitAssignStatus.INPROGRESS ? habitAssign.getWorkingDays()
: ChronoUnit.DAYS.between(habitAssign.getCreateDate().toLocalDate(), LocalDate.now()))
.description(habitAssign.getHabit().getHabitTranslations().getFirst().getName())
.build();
}
}
28 changes: 28 additions & 0 deletions service/src/main/java/greencity/service/AIServiceImpl.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package greencity.service;

import greencity.constant.OpenAIRequest;
import greencity.dto.habit.DurationHabitDto;
import greencity.entity.HabitAssign;
import greencity.repository.HabitAssignRepo;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.modelmapper.ModelMapper;
import org.springframework.stereotype.Service;
import java.util.List;

@Slf4j
@Service
@AllArgsConstructor
public class AIServiceImpl implements AIService {
private final OpenAIService openAIService;
private final HabitAssignRepo habitAssignRepo;
private final ModelMapper modelMapper;

@Override
public String getForecast(Long userId, String language) {
List<HabitAssign> habitAssigns = habitAssignRepo.findAllByUserId(userId);
List<DurationHabitDto> durationHabitDtos = habitAssigns.stream()
.map(habitAssign -> modelMapper.map(habitAssign, DurationHabitDto.class)).toList();
return openAIService.makeRequest(language + OpenAIRequest.FORECAST + durationHabitDtos);
}
}
69 changes: 69 additions & 0 deletions service/src/main/java/greencity/service/OpenAIServiceImpl.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package greencity.service;

import greencity.constant.ErrorMessage;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

@Setter
@Slf4j
@Service
public class OpenAIServiceImpl implements OpenAIService {
@Value("${openai.api.key}")
private String apiKey;
@Value("${openai.api.url}")
private String apiUrl;
private final RestTemplate restTemplate;

public OpenAIServiceImpl(RestTemplate restTemplate) {
this.restTemplate = restTemplate;
}

@Override
public String makeRequest(String prompt) {
HttpHeaders headers = new HttpHeaders();
headers.add("Authorization", "Bearer " + apiKey);
headers.add("Content-Type", "application/json");

Map<String, Object> body = new HashMap<>();
body.put("model", "gpt-4o-mini");
List<Map<String, String>> messages = new ArrayList<>();
messages.add(Map.of("role", "user", "content", prompt));
body.put("messages", messages);
body.put("max_tokens", 450);
body.put("temperature", 0.8);
HttpEntity<Map<String, Object>> request = new HttpEntity<>(body, headers);
try {
ResponseEntity<Map<String, Object>> response = restTemplate.exchange(
apiUrl,
HttpMethod.POST,
request,
new ParameterizedTypeReference<>() {
});

return Optional.ofNullable(response)
.map(ResponseEntity::getBody)
.filter(responseBody -> responseBody.containsKey("choices"))
.map(responseBody -> (List<Map<String, Object>>) responseBody.get("choices"))
.filter(choices -> !choices.isEmpty())
.map(choices -> choices.get(0))
.map(choice -> (Map<String, Object>) choice.get("message"))
.map(message -> (String) message.get("content"))
.orElse(ErrorMessage.OPEN_AI_IS_NOT_RESPONDING);
} catch (Exception e) {
return ErrorMessage.OPEN_AI_IS_NOT_RESPONDING;
}
}
}
Loading
Loading