diff --git a/src/main/java/nextstep/subway/applicaion/PathService.java b/src/main/java/nextstep/subway/applicaion/PathService.java new file mode 100644 index 0000000000..589bfb022b --- /dev/null +++ b/src/main/java/nextstep/subway/applicaion/PathService.java @@ -0,0 +1,35 @@ +package nextstep.subway.applicaion; + +import java.util.List; +import nextstep.subway.applicaion.dto.PathResponse; +import nextstep.subway.domain.Line; +import nextstep.subway.domain.LineRepository; +import nextstep.subway.domain.Path; +import nextstep.subway.domain.PathFinder; +import nextstep.subway.domain.Station; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@Transactional +public class PathService { + private final LineRepository lineRepository; + private final StationService stationService; + + public PathService(LineRepository lineRepository, + StationService stationService) { + this.lineRepository = lineRepository; + this.stationService = stationService; + } + + public PathResponse searchPath(Long sourceId, Long targetId) { + Station source = stationService.findById(sourceId); + Station target = stationService.findById(targetId); + List lines = lineRepository.findAll(); + + PathFinder pathFinder = PathFinder.from(lines); + Path shortestPath = pathFinder.searchPath(source, target); + return PathResponse.from(shortestPath); + } + +} diff --git a/src/main/java/nextstep/subway/applicaion/dto/PathResponse.java b/src/main/java/nextstep/subway/applicaion/dto/PathResponse.java new file mode 100644 index 0000000000..5cb477e74a --- /dev/null +++ b/src/main/java/nextstep/subway/applicaion/dto/PathResponse.java @@ -0,0 +1,31 @@ +package nextstep.subway.applicaion.dto; + +import java.util.List; +import java.util.stream.Collectors; +import nextstep.subway.domain.Path; + +public class PathResponse { + private List stations; + private int distance; + + private PathResponse(List stations, int distance) { + this.stations = stations; + this.distance = distance; + } + + public static PathResponse from(Path path) { + List stationResponses = path.getStations().stream() + .map(StationResponse::from) + .collect(Collectors.toList()); + + return new PathResponse(stationResponses, path.getDistance()); + } + + public List getStations() { + return stations; + } + + public int getDistance() { + return distance; + } +} diff --git a/src/main/java/nextstep/subway/domain/Path.java b/src/main/java/nextstep/subway/domain/Path.java new file mode 100644 index 0000000000..b0730d06a1 --- /dev/null +++ b/src/main/java/nextstep/subway/domain/Path.java @@ -0,0 +1,22 @@ +package nextstep.subway.domain; + +import java.util.Collections; +import java.util.List; + +public class Path { + private List stations; + private int distance; + + public Path(List stations, int distance) { + this.stations = stations; + this.distance = distance; + } + + public List getStations() { + return Collections.unmodifiableList(stations); + } + + public int getDistance() { + return distance; + } +} diff --git a/src/main/java/nextstep/subway/domain/PathFinder.java b/src/main/java/nextstep/subway/domain/PathFinder.java new file mode 100644 index 0000000000..ed158bbb17 --- /dev/null +++ b/src/main/java/nextstep/subway/domain/PathFinder.java @@ -0,0 +1,67 @@ +package nextstep.subway.domain; + +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import org.jgrapht.GraphPath; +import org.jgrapht.alg.shortestpath.DijkstraShortestPath; +import org.jgrapht.graph.DefaultWeightedEdge; +import org.jgrapht.graph.WeightedMultigraph; + +public class PathFinder { + private WeightedMultigraph graph = new WeightedMultigraph(DefaultWeightedEdge.class); + + private PathFinder() { + } + + public static PathFinder from(List lines) { + PathFinder pathFinder = new PathFinder(); + pathFinder.init(lines); + return pathFinder; + } + + public void init(List lines) { + lines.stream() + .map(Line::getSections) + .map(Sections::getSectionList) + .flatMap(Collection::stream) + .forEach( + section -> { + Station upStation = section.getUpStation(); + Station downStation = section.getDownStation(); + graph.addVertex(upStation); + graph.addVertex(downStation); + graph.setEdgeWeight(graph.addEdge(upStation, downStation), section.getDistance()); + } + ); + } + + public Path searchPath(Station source, Station target) { + validationSearchPathParams(source, target); + GraphPath shortestPath = searchShortestPath(source, target); + + List stations = shortestPath.getVertexList(); + int distance = (int) shortestPath.getWeight(); + return new Path(stations, distance); + } + + public GraphPath searchShortestPath(Station source, Station target) { + DijkstraShortestPath dijkstraShortestPath = new DijkstraShortestPath<>(graph); + GraphPath graphPath = dijkstraShortestPath.getPath(source, target); + + return Optional.ofNullable(graphPath) + .orElseThrow(() -> { + throw new IllegalArgumentException("출발역과 도착역이 연결되어 있지 않습니다."); + }); + } + + private void validationSearchPathParams(Station source, Station target) { + if (source.equals(target)) { + throw new IllegalArgumentException("출발역과 도착역이 동일합니다."); + } + + if (!graph.containsVertex(source) || !graph.containsVertex(target)) { + throw new IllegalArgumentException("노선에 포함된 역의 경로만 조회 가능합니다."); + } + } +} diff --git a/src/main/java/nextstep/subway/domain/Section.java b/src/main/java/nextstep/subway/domain/Section.java index 8b274354d8..5db1e01401 100644 --- a/src/main/java/nextstep/subway/domain/Section.java +++ b/src/main/java/nextstep/subway/domain/Section.java @@ -63,7 +63,7 @@ public void update(Station newUpStation, int minusDistance) { } public Section merge(Section section) { - if (!isDownStation(section.upStation)) { + if (!hasDownStationAs(section.upStation)) { throw new IllegalArgumentException("합치려는 구간의 상행역이 하행역과 같아야 합니다."); } @@ -71,11 +71,11 @@ public Section merge(Section section) { } - public boolean isUpStation(Station station) { + public boolean hasUpStationAs(Station station) { return upStation.equals(station); } - public boolean isDownStation(Station station) { + public boolean hasDownStationAs(Station station) { return downStation.equals(station); } diff --git a/src/main/java/nextstep/subway/domain/Sections.java b/src/main/java/nextstep/subway/domain/Sections.java index 6e069542ec..2578c31c5b 100644 --- a/src/main/java/nextstep/subway/domain/Sections.java +++ b/src/main/java/nextstep/subway/domain/Sections.java @@ -1,6 +1,7 @@ package nextstep.subway.domain; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -34,23 +35,30 @@ public void add(Section section) { public void remove(Station station) { validationRemoveStation(station); - Optional
firstSection = sections.stream() - .filter(s -> s.isDownStation(station)) - .findAny(); - - Optional
secondSection = sections.stream() - .filter(s -> s.isUpStation(station)) - .findAny(); + Optional
sectionHasUpStation = findSectionHasUpStationAs(station); + Optional
sectionHasDownStation = findSectionHasDownStationAs(station); - firstSection.ifPresent(section -> sections.remove(section)); - secondSection.ifPresent(section -> sections.remove(section)); + sectionHasUpStation.ifPresent(section -> sections.remove(section)); + sectionHasDownStation.ifPresent(section -> sections.remove(section)); - if (firstSection.isPresent() && secondSection.isPresent()) { - mergeExistingSections(firstSection.get(), secondSection.get()); + if (sectionHasUpStation.isPresent() && sectionHasDownStation.isPresent()) { + mergeExistingSections(sectionHasUpStation.get(), sectionHasDownStation.get()); } } + private Optional
findSectionHasDownStationAs(Station station) { + return sections.stream() + .filter(s -> s.hasUpStationAs(station)) + .findAny(); + } + + private Optional
findSectionHasUpStationAs(Station station) { + return sections.stream() + .filter(s -> s.hasDownStationAs(station)) + .findAny(); + } + private List getStations() { List stations = new ArrayList<>(); stations.add(sections.get(0).getUpStation()); @@ -152,4 +160,7 @@ private void mergeExistingSections(Section firstSection, Section secondSection) sections.add(firstSection.merge(secondSection)); } + public List
getSectionList() { + return Collections.unmodifiableList(sections); + } } diff --git a/src/main/java/nextstep/subway/ui/PathController.java b/src/main/java/nextstep/subway/ui/PathController.java new file mode 100644 index 0000000000..f2d8433565 --- /dev/null +++ b/src/main/java/nextstep/subway/ui/PathController.java @@ -0,0 +1,25 @@ +package nextstep.subway.ui; + +import nextstep.subway.applicaion.PathService; +import nextstep.subway.applicaion.dto.PathResponse; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; + +@RestController +@RequestMapping("/path") +public class PathController { + private final PathService pathService; + + public PathController(PathService pathService) { + this.pathService = pathService; + } + + @GetMapping + public ResponseEntity getPaths(@RequestParam Long source, @RequestParam Long target) { + PathResponse pathResponse = pathService.searchPath(source, target); + return ResponseEntity.ok().body(pathResponse); + } +} diff --git a/src/test/java/nextstep/subway/acceptance/PathAcceptanceTest.java b/src/test/java/nextstep/subway/acceptance/PathAcceptanceTest.java new file mode 100644 index 0000000000..e203d6dedd --- /dev/null +++ b/src/test/java/nextstep/subway/acceptance/PathAcceptanceTest.java @@ -0,0 +1,120 @@ +package nextstep.subway.acceptance; + +import static nextstep.subway.acceptance.LineSteps.지하철_노선_생성_요청; +import static nextstep.subway.acceptance.LineSteps.지하철_노선에_지하철_구간_생성_요청; +import static nextstep.subway.acceptance.PathSteps.경로_조회_요청; +import static nextstep.subway.acceptance.StationSteps.지하철역_생성_요청; +import static org.assertj.core.api.Assertions.assertThat; + +import io.restassured.response.ExtractableResponse; +import io.restassured.response.Response; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.http.HttpStatus; + +@DisplayName("지하철 경로 검색") +public class PathAcceptanceTest extends AcceptanceTest { + private Long 교대역; + private Long 강남역; + private Long 양재역; + private Long 남부터미널역; + private Long 청량리역; + private Long 회기역; + + private Long 일호선; + private Long 이호선; + private Long 신분당선; + private Long 삼호선; + + private int 교대역_강남역_거리 = 10; + private int 강남역_양재역_거리 = 5; + private int 교대역_남부터미널역_거리 = 2; + private int 남부터미널역_양재역_거리 = 3; + private int 청량리역_회기역_거리 = 7; + + /** (10) + * 교대역 --- *2호선* --- 강남역 + * | | + * *3호선* (2) *신분당선* (5) + * | | + * 남부터미널역 --- *3호선* --- 양재 + (3) **/ + + @BeforeEach + public void setUp() { + super.setUp(); + + 교대역 = 지하철역_생성_요청("교대역").jsonPath().getLong("id"); + 강남역 = 지하철역_생성_요청("강남역").jsonPath().getLong("id"); + 양재역 = 지하철역_생성_요청("양재역").jsonPath().getLong("id"); + 남부터미널역 = 지하철역_생성_요청("남부터미널역").jsonPath().getLong("id"); + 청량리역 = 지하철역_생성_요청("청량리역").jsonPath().getLong("id"); + 회기역 = 지하철역_생성_요청("회기역").jsonPath().getLong("id"); + + 일호선 = 지하철_노선_생성_요청("1호선", "blue", 청량리역, 회기역, 청량리역_회기역_거리).jsonPath().getLong("id"); + 이호선 = 지하철_노선_생성_요청("2호선", "green", 교대역, 강남역, 교대역_강남역_거리).jsonPath().getLong("id"); + 신분당선 = 지하철_노선_생성_요청("신분당선", "red", 강남역, 양재역, 강남역_양재역_거리).jsonPath().getLong("id"); + 삼호선 = 지하철_노선_생성_요청("3호선", "orange", 교대역, 남부터미널역, 교대역_남부터미널역_거리).jsonPath().getLong("id"); + + 지하철_노선에_지하철_구간_생성_요청(삼호선, createSectionCreateParams(남부터미널역, 양재역, 남부터미널역_양재역_거리)); + } + + /** + * Given 지하철 노선 (2호선, 3호선, 신분당선) 을 생성하고 역, 구간을 생성한다. + * When 서로 다른 두 역의 최단 거리를 조회하면 + * Then 최단 거리 조회에 성공한다. + */ + @DisplayName("최단 거리 조회하기") + @Test + void searchShortestPath() { + // when + ExtractableResponse response = 경로_조회_요청(남부터미널역, 강남역); + + // then + assertThat(response.statusCode()).isEqualTo(HttpStatus.OK.value()); + assertThat(response.jsonPath().getList("stations.id", Long.class)).containsExactly(남부터미널역, 양재역, 강남역); + assertThat(response.jsonPath().getInt("distance")).isEqualTo(남부터미널역_양재역_거리 + 강남역_양재역_거리); + } + + /** + * Given 지하철 노선 (2호선, 3호선, 신분당선) 을 생성하고 역, 구간을 생성한다. + * 기존 노선과 연결되지 않는 새로운 노선 (1호선)을 생성한다. + * When 연결 되지 않은 두 역의 최단 거리 조회를 요청 하면 + * Then 최단 거리 조회에 실패한다. + */ + @DisplayName("최단 거리 조회하기 - 연결되지 않은 역을 조회 할 경우") + @Test + void searchShortestPathDoesNotExistPath() { + // when + ExtractableResponse response = 경로_조회_요청(강남역, 청량리역); + + // then + assertThat(response.statusCode()).isEqualTo(HttpStatus.BAD_REQUEST.value());; + } + + /** + * Given 지하철 노선 (2호선, 3호선, 신분당선) 을 생성하고 역, 구간을 생성한다. + * When 출발역과 도착역이 동일한데 최단 거리 조회를 요청하면 + * Then 최단 거리 조회에 실패한다. + */ + @DisplayName("최단 거리 조회하기 - 출발역과 도착역이 동일한 경우") + @Test + void searchShortestPathSourceEqualsTarget() { + // when + ExtractableResponse response = 경로_조회_요청(강남역, 강남역); + + // then + assertThat(response.statusCode()).isEqualTo(HttpStatus.BAD_REQUEST.value());; + } + + private Map createSectionCreateParams(Long upStationId, Long downStationId, int distance) { + Map params = new HashMap<>(); + params.put("upStationId", upStationId + ""); + params.put("downStationId", downStationId + ""); + params.put("distance", distance + ""); + return params; + } +} diff --git a/src/test/java/nextstep/subway/acceptance/PathSteps.java b/src/test/java/nextstep/subway/acceptance/PathSteps.java new file mode 100644 index 0000000000..6c06c0c83a --- /dev/null +++ b/src/test/java/nextstep/subway/acceptance/PathSteps.java @@ -0,0 +1,24 @@ +package nextstep.subway.acceptance; + +import io.restassured.RestAssured; +import io.restassured.response.ExtractableResponse; +import io.restassured.response.Response; +import java.util.HashMap; +import java.util.Map; +import org.springframework.http.MediaType; + +public class PathSteps { + public static ExtractableResponse 경로_조회_요청(Long sourceId, Long targetId) { + Map params = new HashMap<>(); + params.put("source", sourceId.toString()); + params.put("target", targetId.toString()); + + return RestAssured.given().log().all() + .params(params) + .contentType(MediaType.APPLICATION_JSON_VALUE) + .when() + .get("/path") + .then().log().all() + .extract(); + } +} diff --git a/src/test/java/nextstep/subway/unit/PathFinderTest.java b/src/test/java/nextstep/subway/unit/PathFinderTest.java new file mode 100644 index 0000000000..6382cc2e14 --- /dev/null +++ b/src/test/java/nextstep/subway/unit/PathFinderTest.java @@ -0,0 +1,66 @@ +package nextstep.subway.unit; + +import static nextstep.subway.unit.PathFixture.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +import java.util.Arrays; +import nextstep.subway.domain.Line; +import nextstep.subway.domain.Path; +import nextstep.subway.domain.PathFinder; +import nextstep.subway.domain.Station; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +public class PathFinderTest { + + @DisplayName("최단 거리 조회하기") + @Test + void searchPath() { + // when + PathFinder pathFinder = PathFinder.from(노선_목록); + Path path = pathFinder.searchPath(교대역, 양재역); + + // then + assertThat(path.getStations()).containsExactly(교대역, 남부터미널역, 양재역); + assertThat(path.getDistance()).isEqualTo(교대역_남부터미널역_거리 + 남부터미널역_양재역_거리); + } + + @DisplayName("출발역과 도착역이 동일 할 경우") + @Test + void searchPathSourceEqualsTarget() { + // when + // then + PathFinder pathFinder = PathFinder.from(노선_목록); + assertThatIllegalArgumentException().isThrownBy(() -> pathFinder.searchPath(교대역, 교대역)) + .withMessage("출발역과 도착역이 동일합니다."); + } + + @DisplayName("출발역과 도착역이 연결되어 있지 않을 경우") + @Test + void searchPathDoseNotExistShortestPath() { + // given + Station 사당역 = new Station("사당역"); + Station 이수역 = new Station("이수역"); + Line 사호선 = Line.of("4호선", "bg-blue-600", 사당역, 이수역, 10); + + // when + // then + PathFinder pathFinder = PathFinder.from(Arrays.asList(이호선, 신분당선, 삼호선, 사호선)); + assertThatIllegalArgumentException().isThrownBy(() -> pathFinder.searchPath(교대역, 이수역)) + .withMessage("출발역과 도착역이 연결되어 있지 않습니다."); + } + + @DisplayName("노선에 존재하지 않은 출발역이나 도착역을 조회 할 경우") + @Test + void searchPathSourceDoesNotExistStation() { + // given + Station 시청역 = new Station("시청역"); + // when + // then + PathFinder pathFinder = PathFinder.from(노선_목록); + assertThatIllegalArgumentException().isThrownBy(() -> pathFinder.searchPath(교대역, 시청역)) + .withMessage("노선에 포함된 역의 경로만 조회 가능합니다."); + } + +} diff --git a/src/test/java/nextstep/subway/unit/PathFixture.java b/src/test/java/nextstep/subway/unit/PathFixture.java new file mode 100644 index 0000000000..afffd2c8f6 --- /dev/null +++ b/src/test/java/nextstep/subway/unit/PathFixture.java @@ -0,0 +1,48 @@ +package nextstep.subway.unit; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import nextstep.subway.domain.Line; +import nextstep.subway.domain.Station; +import org.springframework.test.util.ReflectionTestUtils; + +public class PathFixture { + /** (10) + * 교대역 --- *2호선* --- 강남역 + * | | + * *3호선* (2) *신분당선* (5) + * | | + * 남부터미널역 --- *3호선* --- 양재 + (3) **/ + + + public static Station 강남역 = new Station("강남역"); + public static Station 교대역 = new Station("교대역"); + public static Station 양재역 = new Station("양재역"); + public static Station 남부터미널역 = new Station("남부터미널역"); + + public static int 교대역_강남역_거리 = 10; + public static int 강남역_양재역_거리 = 5; + public static int 교대역_남부터미널역_거리 = 2; + public static int 남부터미널역_양재역_거리 = 3; + + public static Line 이호선 = Line.of("2호선", "bg-green-600", 교대역, 강남역, 교대역_강남역_거리); + public static Line 삼호선 = Line.of("3호선", "bg-orange-500", 교대역, 남부터미널역, 교대역_남부터미널역_거리); + public static Line 신분당선 = Line.of("신분당선", "bg-red-500", 강남역, 양재역, 강남역_양재역_거리); + + public static List 노선_목록 = new ArrayList<>(); + + static { + 삼호선.addSection(남부터미널역, 양재역, 남부터미널역_양재역_거리); + + ReflectionTestUtils.setField(강남역, "id", 1L); + ReflectionTestUtils.setField(교대역, "id", 2L); + ReflectionTestUtils.setField(양재역, "id", 3L); + ReflectionTestUtils.setField(남부터미널역, "id", 4L); + + 노선_목록.addAll(Arrays.asList(이호선, 삼호선, 신분당선)); + } + + +} diff --git a/src/test/java/nextstep/subway/unit/PathServiceMockTest.java b/src/test/java/nextstep/subway/unit/PathServiceMockTest.java new file mode 100644 index 0000000000..1da47058c7 --- /dev/null +++ b/src/test/java/nextstep/subway/unit/PathServiceMockTest.java @@ -0,0 +1,49 @@ +package nextstep.subway.unit; +import static nextstep.subway.unit.PathFixture.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; + +import java.util.List; +import java.util.stream.Collectors; +import nextstep.subway.applicaion.PathService; +import nextstep.subway.applicaion.StationService; +import nextstep.subway.applicaion.dto.PathResponse; +import nextstep.subway.applicaion.dto.StationResponse; +import nextstep.subway.domain.LineRepository; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class PathServiceMockTest { + @Mock + private StationService stationService; + + @Mock + private LineRepository lineRepository; + + @InjectMocks + private PathService pathService; + + @Test + void searchPath() { + // given + given(stationService.findById(강남역.getId())).willReturn(강남역); + given(stationService.findById(남부터미널역.getId())).willReturn(남부터미널역); + given(lineRepository.findAll()).willReturn(노선_목록); + + // when + PathResponse pathResponse = pathService.searchPath(강남역.getId(), 남부터미널역.getId()); + + // then + List stationIdList = pathResponse.getStations().stream() + .map(StationResponse::getId) + .collect(Collectors.toList()); + + assertThat(stationIdList).containsExactly(강남역.getId(), 양재역.getId(), 남부터미널역.getId()); + assertThat(pathResponse.getDistance()).isEqualTo(남부터미널역_양재역_거리 + 강남역_양재역_거리); + + } +}