diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index bad2dacf5..fdc6bb7af 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -57,7 +57,7 @@ jobs: - name: Build documentation run: | jupyter nbconvert --to rst docs/source/useful_information/validation_benchmarking/IPF_benchmark.ipynb - sphinx-build -M latexpdf docs/source docs/source/_static + # sphinx-build -M latexpdf docs/source docs/source/_static sphinx-build -b html docs/source docs/build python3 -m zipfile -c AequilibraE.zip docs/build cp AequilibraE.zip docs/source/_static diff --git a/aequilibrae/project/database_specification/transit/tables/fare_attributes.sql b/aequilibrae/project/database_specification/transit/tables/fare_attributes.sql index d9ae5b2f9..ec49de2d6 100644 --- a/aequilibrae/project/database_specification/transit/tables/fare_attributes.sql +++ b/aequilibrae/project/database_specification/transit/tables/fare_attributes.sql @@ -1,7 +1,8 @@ --@ The *fare_attributes* table holds information about the fare values. --@ This table information comes from the GTFS file *fare_attributes.txt*. --@ Given that this file is optional in GTFS, it can be empty. ---@ You can check out more information `on fare attributes here `_. +--@ You can check out more information +--@ `on fare attributes here `_. --@ --@ **fare_id** identifies a fare class --@ diff --git a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql index 37c3985e0..25f08abc9 100644 --- a/aequilibrae/project/database_specification/transit/tables/fare_zones.sql +++ b/aequilibrae/project/database_specification/transit/tables/fare_zones.sql @@ -1,12 +1,15 @@ --@ The *fare_zones* table hold information on the transit fare zones and --@ the transit agencies that operate in it. --@ ---@ **transit_fare_zone** identifies the transit fare zones +--@ **fare_zone_id** identifies the fare zone +--@ +--@ **transit_zone** identifies the transit fare zones --@ --@ **agency_id** identifies the agency/agencies for the specified fare zone CREATE TABLE IF NOT EXISTS fare_zones ( - transit_fare_zone TEXT NOT NULL, - agency_id INTEGER NOT NULL, + fare_zone_id INTEGER PRIMARY KEY, + transit_zone TEXT, + agency_id INTEGER, FOREIGN KEY(agency_id) REFERENCES agencies(agency_id) deferrable initially deferred ); \ No newline at end of file diff --git a/aequilibrae/project/database_specification/transit/tables/routes.sql b/aequilibrae/project/database_specification/transit/tables/routes.sql index 7882637b0..ac98af4cd 100644 --- a/aequilibrae/project/database_specification/transit/tables/routes.sql +++ b/aequilibrae/project/database_specification/transit/tables/routes.sql @@ -1,6 +1,7 @@ --@ The *routes* table holds information on the available transit routes for a --@ specific day. This table information comes from the GTFS file *routes.txt*. ---@ You can find more information about `the routes table here `_. +--@ You can find more information about +--@ `the routes table here `_. --@ --@ **pattern_id** is an unique pattern for the route --@ diff --git a/aequilibrae/project/database_specification/transit/tables/stops.sql b/aequilibrae/project/database_specification/transit/tables/stops.sql index f15118615..850d84b86 100644 --- a/aequilibrae/project/database_specification/transit/tables/stops.sql +++ b/aequilibrae/project/database_specification/transit/tables/stops.sql @@ -21,11 +21,9 @@ --@ --@ **description** provides useful description of the stop location --@ ---@ **street** identifies the address of a stop +--@ **fare_zone_id** identifies the fare zone for a stop --@ ---@ **zone_id** identifies the TAZ for a stop ---@ ---@ **transit_fare_zone** identifies the transit fare zone for a stop +--@ **transit_zone** identifies the transit fare zone for a stop --@ --@ **route_type** indicates the type of transporation used on a route @@ -38,11 +36,11 @@ CREATE TABLE IF NOT EXISTS stops ( name TEXT, parent_station TEXT, description TEXT, - street TEXT, - zone_id INTEGER, - transit_fare_zone TEXT, - route_type INTEGER NOT NULL DEFAULT -1, - FOREIGN KEY(agency_id) REFERENCES agencies(agency_id) + fare_zone_id INTEGER, + transit_zone TEXT, + route_type INTEGER NOT NULL DEFAULT -1, + FOREIGN KEY(agency_id) REFERENCES agencies(agency_id), + FOREIGN KEY("fare_zone_id") REFERENCES fare_zones("fare_zone_id") ); --# diff --git a/aequilibrae/transit/column_order.py b/aequilibrae/transit/column_order.py index 2680b7427..60f54cec7 100644 --- a/aequilibrae/transit/column_order.py +++ b/aequilibrae/transit/column_order.py @@ -7,10 +7,10 @@ ("agency_name", str), ("agency_url", str), ("agency_timezone", str), - ("agency_lang", str), - ("agency_phone", str), - ("agency_fare_url", str), - ("agency_email", str), + # ("agency_lang", str), + # ("agency_phone", str), + # ("agency_fare_url", str), + # ("agency_email", str), ] ), "routes.txt": OrderedDict( @@ -24,7 +24,7 @@ # ("route_color", str), # ("route_text_color", str), # ("route_sort_order", int), - # ("agency_id", str), + ("agency_id", str), ] ), "trips.txt": OrderedDict( @@ -76,7 +76,7 @@ ("currency_type", str), ("payment_method", int), ("transfers", int), - # ("agency_id", str), + ("agency_id", str), ("transfer_duration", float), ] ), @@ -113,13 +113,14 @@ ("stop_desc", str), ("stop_lat", float), ("stop_lon", float), - ("stop_street", str), ("zone_id", str), # ("stop_url", str), # ("location_type", int), ("parent_station", str), # ("stop_timezone", str), # ("wheelchair_boarding", int), + # ("level_id", int), + # ("platform_code", str) ] ), "shapes.txt": OrderedDict( diff --git a/aequilibrae/transit/constants.py b/aequilibrae/transit/constants.py index 8ceaecc84..fdfce006e 100644 --- a/aequilibrae/transit/constants.py +++ b/aequilibrae/transit/constants.py @@ -9,6 +9,7 @@ WALK_LINK_RANGE = 30000000 TRANSIT_LINK_RANGE = 20000000 WALK_AGENCY_ID = 1 +STOP_ID = 1 # 1 for right, -1 for wrong (left) DRIVING_SIDE = 1 @@ -21,7 +22,7 @@ class Constants: trips: Dict[int, int] = {} patterns: Dict[int, int] = {} pattern_lookup: Dict[int, int] = {} - stops: Dict[int, int] = {} + stops: Dict[int, Any] = {} fares: Dict[int, int] = {} links: Dict[int, int] = {} transit_links: Dict[int, int] = {} diff --git a/aequilibrae/transit/gtfs_loader.py b/aequilibrae/transit/gtfs_loader.py index 9cfea2525..d5004c4eb 100644 --- a/aequilibrae/transit/gtfs_loader.py +++ b/aequilibrae/transit/gtfs_loader.py @@ -31,7 +31,7 @@ def __init__(self): self.__pces__ = {} self.__max_speeds__ = {} self.feed_date = "" - self.agency = Agency() + self.agency: Dict[int, Agency] = {} self.services = {} self.routes: Dict[int, Route] = {} self.trips: Dict[int, Dict[Route]] = {} @@ -46,6 +46,7 @@ def __init__(self): self.wgs84 = pyproj.Proj("epsg:4326") self.srid = get_srid() self.transformer = Transformer.from_crs("epsg:4326", f"epsg:{self.srid}", always_xy=False) + self.agency_correspondence = {} self.logger = get_logger() def set_feed_path(self, file_path): @@ -75,14 +76,15 @@ def _set_pces(self, pces: dict): def _set_maximum_speeds(self, max_speeds: dict): self.__max_speeds__ = max_speeds - def load_data(self, service_date: str): + def load_data(self, service_date: str, description: str): """Loads the data for a respective service date. :Arguments: **service_date** (:obj:`str`): service date. e.g. "2020-04-01". """ - ag_id = self.agency.agency - self.logger.info(f"Loading data for {service_date} from the {ag_id} GTFS feed. This may take some time") + self.service_date = service_date + self.description = description + self.logger.info(f"Loading data for {self.service_date} from the GTFS feed. This may take some time") self.__load_date() @@ -90,6 +92,8 @@ def __load_date(self): self.logger.debug("Starting __load_date") self.zip_archive = zipfile.ZipFile(self.archive_dir) + self.__load_agencies() + self.signal.emit(["start", 7, "Loading routes"]) self.__load_routes_table() @@ -207,13 +211,19 @@ def __load_fare_data(self): fareatt = parse_csv(file, column_order[fareatttxt]) self.data_arrays[fareatttxt] = fareatt + existing_agencies = np.unique(fareatt["agency_id"]) + if existing_agencies.shape[0] != len(self.agency): + self.logger.debug("agency_id exists on fare_attributes.txt but not in agency.txt") + elif existing_agencies.shape[0] == 1 and existing_agencies[0] == "": + fareatt["agency_id"] = list(self.agency.keys())[0] + for line in range(fareatt.shape[0]): data = tuple(fareatt[line][list(column_order[fareatttxt].keys())]) headers = ["fare_id", "price", "currency", "payment_method", "transfer", "transfer_duration"] - f = Fare(self.agency.agency_id) + f = Fare(fareatt[line]["agency_id"]) f.populate(data, headers) if f.fare in self.fare_attributes: - self.__fail(f"Fare ID {f.fare} for {self.agency.agency} is duplicated") + self.__fail(f"Fare ID {f.fare} for {fareatt[line]['agency_id']} is duplicated") self.fare_attributes[f.fare] = f farerltxt = "fare_rules.txt" @@ -227,8 +237,6 @@ def __load_fare_data(self): farerl = parse_csv(file, column_order[farerltxt]) self.data_arrays[farerltxt] = farerl - corresp = {} - zone_id = self.agency.agency_id * AGENCY_MULTIPLIER + 1 for line in range(farerl.shape[0]): data = tuple(farerl[line][list(column_order[farerltxt].keys())]) fr = FareRule() @@ -236,14 +244,9 @@ def __load_fare_data(self): fr.fare_id = self.fare_attributes[fr.fare].fare_id if fr.route in self.routes: fr.route_id = self.routes[fr.route].route_id - fr.agency_id = self.agency.agency_id - for x in [fr.origin, fr.destination]: - if x not in corresp: - corresp[x] = zone_id - zone_id += 1 - fr.origin_id = corresp[fr.origin] - fr.destination_id = corresp[fr.destination] if fr.destination == "" else fr.destination_id - self.fare_rules.append(fr) if fr.origin == "" else fr.origin_id + fr.origin_id = None if fr.origin == "" else int(fr.origin) + fr.destination_id = None if fr.destination == "" else int(fr.destination) + self.fare_rules.append(fr) def __load_shapes_table(self): self.logger.debug("Starting __load_shapes_table") @@ -331,6 +334,7 @@ def __load_trips_table(self): trip.source_time = list(stop_times.source_time.values) self.logger.debug(f"{trip.trip} has {len(trip.stops)} stops") trip._stop_based_shape = LineString([self.stops[x].geo for x in trip.stops]) + # trip.shape = self.shapes.get(trip.shape) trip.pce = self.routes[trip.route].pce trip.seated_capacity = self.routes[trip.route].seated_capacity @@ -444,8 +448,7 @@ def __load_stops_table(self): stops[:]["stop_lon"][:] = lons[:] for i, line in enumerate(stops): - s = Stop(self.agency.agency_id, line, stops.dtype.names) - s.agency = self.agency.agency + s = Stop(line, stops.dtype.names) s.srid = self.srid s.get_node_id() self.stops[s.stop_id] = s @@ -474,8 +477,11 @@ def __load_routes_table(self): for route_type, pce in self.__pces__.items(): routes.loc[routes.route_type == route_type, ["pce"]] = pce + agency_finder = routes["agency_id"].values.tolist() + routes.drop(columns="agency_id", inplace=True) + for i, line in routes.iterrows(): - r = Route(self.agency.agency_id) + r = Route(self.agency_correspondence[agency_finder[i]]) r.populate(line.values, routes.columns) self.routes[r.route] = r @@ -572,6 +578,25 @@ def __load_feed_calendar(self): if exception_inconsistencies: self.logger.info(" Minor inconsistencies found between calendar.txt and calendar_dates.txt") + def __load_agencies(self): + self.logger.debug("Starting __load_agencies") + agencytxt = "agency.txt" + + self.logger.debug(' Loading "agency" table') + self.agency = {} + with self.zip_archive.open(agencytxt, "r") as file: + agencies = parse_csv(file, column_order[agencytxt]) + self.data_arrays[agencytxt] = agencies + + for i, line in enumerate(agencies): + a = Agency() + a.agency = line["agency_name"] + a.feed_date = self.feed_date + a.service_date = self.service_date + a.description = self.description + self.agency[a.agency_id] = a + self.agency_correspondence[line["agency_id"]] = a.agency_id + def __fail(self, msg: str) -> None: self.logger.error(msg) raise Exception(msg) diff --git a/aequilibrae/transit/lib_gtfs.py b/aequilibrae/transit/lib_gtfs.py index bae1e6c12..625af8f44 100644 --- a/aequilibrae/transit/lib_gtfs.py +++ b/aequilibrae/transit/lib_gtfs.py @@ -22,27 +22,21 @@ class GTFSRouteSystemBuilder: signal = SIGNAL(object) - def __init__( - self, network, agency_identifier, file_path, day="", description="", capacities=None, pces=None - ): # noqa: B006 + def __init__(self, network, file_path, description="", capacities=None, pces=None): # noqa: B006 """Instantiates a transit class for the network :Arguments: **local network** (:obj:`Network`): Supply model to which this GTFS will be imported - **agency_identifier** (:obj:`str`): ID for the agency this feed refers to (e.g. 'CTA') - **file_path** (:obj:`str`): Full path to the GTFS feed (e.g. 'D:/project/my_gtfs_feed.zip') - **day** (:obj:`str`, *Optional*): Service data contained in this field to be imported (e.g. '2019-10-04') - **description** (:obj:`str`, *Optional*): Description for this feed (e.g. 'CTA19 fixed by John after coffee') """ self.__network = network self.project = get_active_project(False) self.archive_dir = None # type: str - self.day = day + self.day = None self.logger = get_logger() self.gtfs_data = GTFSReader() @@ -52,10 +46,9 @@ def __init__( self.trip_by_service = {} self.patterns = {} self.graphs = {} + self.description = description self.transformer = Transformer.from_crs("epsg:4326", f"epsg:{self.srid}", always_xy=False) self.sridproj = pyproj.Proj(f"epsg:{self.srid}") - self.gtfs_data.agency.agency = agency_identifier - self.gtfs_data.agency.description = description self.__default_capacities = {} if capacities is None else capacities self.__default_pces = {} if pces is None else pces self.__do_execute_map_matching = False @@ -153,14 +146,6 @@ def map_match(self, route_types=[3]) -> None: # noqa: B006 self.logger.warning(msg) self.signal.emit(["finished"]) - def set_agency_identifier(self, agency_id: str) -> None: - """Adds agency ID to this GTFS for use on import. - - :Arguments: - **agency_id** (:obj:`str`): ID for the agency this feed refers to (e.g. 'CTA') - """ - self.gtfs_data.agency.agency = agency_id - def set_feed(self, feed_path: str) -> None: """Sets GTFS feed source to be used. @@ -168,7 +153,6 @@ def set_feed(self, feed_path: str) -> None: **file_path** (:obj:`str`): Full path to the GTFS feed (e.g. 'D:/project/my_gtfs_feed.zip') """ self.gtfs_data.set_feed_path(feed_path) - self.gtfs_data.agency.feed_date = self.gtfs_data.feed_date def set_description(self, description: str) -> None: """Adds description to be added to the imported layers metadata @@ -196,11 +180,10 @@ def load_date(self, service_date: str) -> None: raise ValueError("The date chosen is not available in this GTFS feed") self.day = service_date - self.gtfs_data.load_data(service_date) + self.gtfs_data.load_data(service_date, self.description) self.logger.info(" Building data structures") self.__build_data() - self.gtfs_data.agency.service_date = self.day def doWork(self): """Alias for execute_import""" @@ -212,10 +195,10 @@ def execute_import(self): if self.__target_date__ is not None: self.load_date(self.__target_date__) if not self.select_routes: - self.logger.warning(f"Nothing to import for {self.gtfs_data.agency.agency} on {self.day}") + self.logger.warning(f"Nothing to import on {self.day}") return - self.logger.info(f" Importing feed for agency {self.gtfs_data.agency.agency} on {self.day}") + self.logger.info(f" Importing GTFS feed on {self.day}") self.save_to_disk() @@ -227,7 +210,8 @@ def save_to_disk(self): pattern.save_to_database(conn, commit=False) conn.commit() - self.gtfs_data.agency.save_to_database(conn) + for counter, (_, agency) in enumerate(self.gtfs_data.agency.items()): + agency.save_to_database(conn) for counter, trip in enumerate(self.select_trips): trip.save_to_database(conn, commit=False) @@ -254,16 +238,48 @@ def save_to_disk(self): for fare_rule in self.gtfs_data.fare_rules: fare_rule.save_to_database(conn) + sql = """WITH t1 AS ( + SELECT from_stop stop_id, pattern_id FROM route_links + UNION ALL + SELECT to_stop stop_id, pattern_id FROM route_links + ), + t2 AS ( + SELECT route_id, pattern_id, agency_id FROM routes + ), + t3 AS ( + SELECT t1.stop_id, t2.agency_id, COUNT(*) as frequency + FROM t1 + JOIN t2 ON t1.pattern_id = t2.pattern_id + GROUP BY t1.stop_id, t2.agency_id + ) + SELECT t3.stop_id, t3.agency_id + FROM t3 + WHERE t3.frequency = ( + SELECT MAX(frequency) + FROM t3 AS sub + WHERE sub.stop_id = t3.stop_id + );""" + frequent_agency = conn.execute(sql).fetchall() + + zones = [] for counter, (_, stop) in enumerate(self.select_stops.items()): - if stop.zone in zone_ids: - stop.zone_id = zone_ids[stop.zone] if self.__has_taz: closest_zone = self.project.zoning.get_closest_zone(stop.geo) if stop.geo.within(self.project.zoning.get(closest_zone).geometry): stop.taz = closest_zone + stop.agency_id = frequent_agency[counter][1] + if stop.zone_id: + zones.append((stop.zone_id, "", stop.agency_id)) stop.save_to_database(conn, commit=False) conn.commit() + zones = list(set(zones)) + + if zones: + sql = "insert into fare_zones (fare_zone_id, transit_zone, agency_id) values(?,?,?);" + conn.executemany(sql, zones) + conn.commit() + self.__outside_zones = None in [x.taz for x in self.select_stops.values()] if self.__outside_zones: msg = " Some stops are outside the zoning system. Check the result on a map and see the log for info" @@ -390,9 +406,6 @@ def __get_routes_by_date(self): if not routes: self.logger.warning("NO ROUTES OPERATING FOR THIS DATE") - for route_id, route in routes.items(): - route.agency = self.gtfs_data.agency.agency - self.select_routes = routes def _get_trips_by_date_and_route(self, route_id: int, service_date: str) -> list: diff --git a/aequilibrae/transit/map_matching_graph.py b/aequilibrae/transit/map_matching_graph.py index eaafd71f1..f8fca00b9 100644 --- a/aequilibrae/transit/map_matching_graph.py +++ b/aequilibrae/transit/map_matching_graph.py @@ -41,7 +41,7 @@ def __init__(self, lib_gtfs): self.mode_id = -1 self.__mode = "" self.__df_file = "" - self.__agency = lib_gtfs.gtfs_data.agency.agency + self.__agency = "-".join(list(lib_gtfs.gtfs_data.agency_correspondence.keys())) self.__centroids_file = "" self.__mm_graph_file = "" self.node_corresp = [] diff --git a/aequilibrae/transit/parse_csv.py b/aequilibrae/transit/parse_csv.py index 408febda2..2709a1c86 100644 --- a/aequilibrae/transit/parse_csv.py +++ b/aequilibrae/transit/parse_csv.py @@ -35,7 +35,7 @@ def parse_csv(file_name: str, column_order=[]): # noqa B006 missing_cols_names = [x for x in column_order.keys() if x not in data.dtype.names] for col in missing_cols_names: - data = append_fields(data, col, np.array([""] * len(tot))) + data = append_fields(data, col, np.array([""] * len(tot)), usemask=False) if column_order: col_names = [x for x in column_order.keys() if x in data.dtype.names] diff --git a/aequilibrae/transit/transit.py b/aequilibrae/transit/transit.py index fd19a72d9..69927c0d6 100644 --- a/aequilibrae/transit/transit.py +++ b/aequilibrae/transit/transit.py @@ -47,16 +47,12 @@ def __init__(self, project): self.create_transit_database() self.pt_con = database_connection("transit") - def new_gtfs_builder(self, agency, file_path, day="", description="") -> GTFSRouteSystemBuilder: + def new_gtfs_builder(self, file_path, description="") -> GTFSRouteSystemBuilder: """Returns a ``GTFSRouteSystemBuilder`` object compatible with the project :Arguments: - **agency** (:obj:`str`): Name for the agency this feed refers to (e.g. 'CTA') - **file_path** (:obj:`str`): Full path to the GTFS feed (e.g. 'D:/project/my_gtfs_feed.zip') - **day** (:obj:`str`, *Optional*): Service data contained in this field to be imported (e.g. '2019-10-04') - **description** (:obj:`str`, *Optional*): Description for this feed (e.g. 'CTA2019 fixed by John Doe') :Returns: @@ -64,9 +60,7 @@ def new_gtfs_builder(self, agency, file_path, day="", description="") -> GTFSRou """ gtfs = GTFSRouteSystemBuilder( network=self.project_base_path, - agency_identifier=agency, file_path=file_path, - day=day, description=description, capacities=self.default_capacities, pces=self.default_pces, diff --git a/aequilibrae/transit/transit_elements/stop.py b/aequilibrae/transit/transit_elements/stop.py index ccafb1b34..c70fb0ad6 100644 --- a/aequilibrae/transit/transit_elements/stop.py +++ b/aequilibrae/transit/transit_elements/stop.py @@ -4,7 +4,11 @@ from shapely.geometry import Point -from aequilibrae.transit.constants import Constants, AGENCY_MULTIPLIER +from contextlib import closing + +from aequilibrae.project.database_connection import database_connection + +from aequilibrae.transit.constants import Constants, STOP_ID from aequilibrae.transit.transit_elements.basic_element import BasicPTElement @@ -12,7 +16,7 @@ class Stop(BasicPTElement): """Transit stop as read from the GTFS feed""" - def __init__(self, agency_id: int, record: tuple, headers: list): + def __init__(self, record: tuple, headers: list): self.stop_id = -1 self.stop = "" self.stop_code = "" @@ -20,7 +24,6 @@ def __init__(self, agency_id: int, record: tuple, headers: list): self.stop_desc = "" self.stop_lat: float = None self.stop_lon: float = None - self.stop_street = "" self.zone = "" self.zone_id = None self.stop_url = "" @@ -30,8 +33,7 @@ def __init__(self, agency_id: int, record: tuple, headers: list): # Not part of GTFS self.taz = None - self.agency = "" - self.agency_id = agency_id + self.agency_id = None self.link = None self.dir = None self.srid = -1 @@ -49,15 +51,15 @@ def __init__(self, agency_id: int, record: tuple, headers: list): if None not in [self.stop_lon, self.stop_lat]: self.geo = Point(self.stop_lon, self.stop_lat) - if len(str(self.zone_id)) == 0: - self.zone_id = None + if len(self.zone) > 0: + self.zone_id = int(self.zone) def save_to_database(self, conn: Connection, commit=True) -> None: """Saves Transit Stop to the database""" sql = """insert into stops (stop_id, stop, agency_id, link, dir, name, - parent_station, description, street, zone_id, transit_fare_zone, route_type, geometry) - values (?,?,?,?,?,?,?,?,?,?,?,?, GeomFromWKB(?, ?));""" + parent_station, description, fare_zone_id, transit_zone, route_type, geometry) + values (?,?,?,?,?,?,?,?,?,?,?, GeomFromWKB(?, ?));""" dt = self.data conn.execute(sql, dt) @@ -75,7 +77,6 @@ def data(self) -> list: self.stop_name, self.parent_station, self.stop_desc, - self.stop_street, self.zone_id, self.taz, int(self.route_type), @@ -84,8 +85,10 @@ def data(self) -> list: ] def get_node_id(self): - c = Constants() + with closing(database_connection("transit")) as conn: + sql = "Select count(stop_id) from stops;" + max_db = int(conn.execute(sql).fetchone()[0]) - val = 1 + c.stops.get(self.agency_id, AGENCY_MULTIPLIER * self.agency_id) - c.stops[self.agency_id] = val - self.stop_id = c.stops[self.agency_id] + c = Constants() + c.stops["stops"] = max(c.stops.get("stops", 0), max_db) + 1 + self.stop_id = c.stops["stops"] diff --git a/docs/source/examples/creating_models/plot_import_gtfs.py b/docs/source/examples/creating_models/plot_import_gtfs.py index afb4ee776..ade3a3420 100644 --- a/docs/source/examples/creating_models/plot_import_gtfs.py +++ b/docs/source/examples/creating_models/plot_import_gtfs.py @@ -39,8 +39,8 @@ project = create_example(fldr, "coquimbo") # %% -# As the Coquimbo example already has a complete GTFS model, we shall remove its public transport -# database for the sake of this example. +# Since Coquimbo already includes a complete GTFS model, we will remove its public transport +# database for the purposes of this example. remove(join(fldr, "public_transport.sqlite")) # %% @@ -48,35 +48,42 @@ dest_path = join(fldr, "gtfs_coquimbo.zip") # %% -# Now we create our Transit object and import the GTFS feed into our model. -# This will automatically create a new public transport database. - +# Now we create our Transit object. This will automatically create a new public transport database. data = Transit(project) -transit = data.new_gtfs_builder(agency="Lisanco", file_path=dest_path) - # %% -# To load the data, we must choose one date. We're going to continue with 2016-04-13 but feel free -# to experiment with any other available dates. Transit class has a function allowing you to check -# dates for the GTFS feed. It should take approximately 2 minutes to load the data. +# To initialize the GTFS builder, specify the path to the GTFS file. You no longer need to provide the +# name of the transit agency, but you can add a general description of the GTFS feed. If your GTFS file +# includes multiple transit agencies, all data will be loaded simultaneously. However, any description +# you add will apply to all agencies in the file. +transit = data.new_gtfs_builder(file_path=dest_path, description="Wednesday feed by John Doe") +#%% +# Case you want information on the available dates before loading the GTFS data to the database, +# it is possible to use the function ``transit.dates_available()`` to check the available feed dates. + +# %% +# To load the data, we must choose one date using the format ``YYYY-MM-DD``. We're going to build +# our database using the day 2016-04-13 but feel free to experiment with any other available dates. +# +# It shouldn't take long to load the data. transit.load_date("2016-04-13") -# Now we execute the map matching to find the real paths. -# Depending on the GTFS size, this process can be really time-consuming. -transit.set_allow_map_match(True) +# %% +# Now we execute the map matching to find the real paths. Depending on the number or different route +# patterns and/or the project area size, this process can be really time-consuming. transit.map_match() -# Finally, we save our GTFS into our model. +# %% +# Finally, we save our GTFS feed into our model. transit.save_to_disk() # %% -# Now we will plot one of the route's patterns we just imported +# Now we will plot the route's patterns we just imported conn = database_connection("transit") links = pd.read_sql("SELECT pattern_id, ST_AsText(geometry) geom FROM routes;", con=conn) - -stops = pd.read_sql("""SELECT stop_id, ST_X(geometry) X, ST_Y(geometry) Y FROM stops""", con=conn) +stops = pd.read_sql("SELECT stop_id, ST_X(geometry) X, ST_Y(geometry) Y FROM stops;", con=conn) # %% gtfs_links = folium.FeatureGroup("links") diff --git a/tests/aequilibrae/paths/test_transit_graph_builder.py b/tests/aequilibrae/paths/test_transit_graph_builder.py index 230585216..88973b295 100644 --- a/tests/aequilibrae/paths/test_transit_graph_builder.py +++ b/tests/aequilibrae/paths/test_transit_graph_builder.py @@ -30,7 +30,7 @@ def setUp(self) -> None: self.data = Transit(self.project) dest_path = join(self.temp_proj_folder, "gtfs_coquimbo.zip") - self.transit = self.data.new_gtfs_builder(agency="LISANCO", file_path=dest_path) + self.transit = self.data.new_gtfs_builder(file_path=dest_path) self.transit.load_date("2016-04-13") self.transit.save_to_disk() diff --git a/tests/aequilibrae/project/test_transit_tables.py b/tests/aequilibrae/project/test_transit_tables.py index fbb7bf50a..27249aebd 100644 --- a/tests/aequilibrae/project/test_transit_tables.py +++ b/tests/aequilibrae/project/test_transit_tables.py @@ -18,7 +18,7 @@ def create_project(project: Project): ["fare_id", "fare", "agency_id", "price", "currency", "payment_method", "transfer", "transfer_duration"], ), ("fare_rules", ["fare_id", "route_id", "origin", "destination", "contains"]), - ("fare_zones", ["transit_fare_zone", "agency_id"]), + ("fare_zones", ["fare_zone_id", "transit_zone", "agency_id"]), ("pattern_mapping", ["pattern_id", "seq", "link", "dir", "geometry"]), ( "routes", @@ -49,9 +49,8 @@ def create_project(project: Project): "name", "parent_station", "description", - "street", - "zone_id", - "transit_fare_zone", + "fare_zone_id", + "transit_zone", "route_type", "geometry", ], diff --git a/tests/aequilibrae/transit/test_gtfs_loader.py b/tests/aequilibrae/transit/test_gtfs_loader.py index 0af0ea558..c50d79375 100644 --- a/tests/aequilibrae/transit/test_gtfs_loader.py +++ b/tests/aequilibrae/transit/test_gtfs_loader.py @@ -38,4 +38,4 @@ def test_load_data(gtfs_loader, gtfs_fldr): gtfs._set_maximum_speeds(dict_speeds) gtfs.set_feed_path(gtfs_fldr) - gtfs.load_data("2016-04-13") + gtfs.load_data("2016-04-13", "this is a description") diff --git a/tests/aequilibrae/transit/test_gtfs_stop.py b/tests/aequilibrae/transit/test_gtfs_stop.py index bb8d5bdd4..bc2ea9674 100644 --- a/tests/aequilibrae/transit/test_gtfs_stop.py +++ b/tests/aequilibrae/transit/test_gtfs_stop.py @@ -18,8 +18,7 @@ def data(): "stop_desc": randomword(randint(0, 40)), "stop_lat": uniform(0, 30000), "stop_lon": uniform(0, 30000), - "stop_street": randomword(randint(0, 40)), - "zone_id": randomword(randint(0, 40)), + "zone_id": str(randint(0, 40)), "stop_url": randomword(randint(0, 40)), "location_type": choice((0, 1)), "parent_station": randomword(randint(0, 40)), @@ -28,7 +27,7 @@ def data(): def test__populate(data): - s = Stop(1, tuple(data.values()), list(data.keys())) + s = Stop(tuple(data.values()), list(data.keys())) xy = (s.geo.x, s.geo.y) assert xy == (data["stop_lon"], data["stop_lat"]), "Stop built geo wrongly" data["stop"] = data["stop_id"] @@ -42,17 +41,17 @@ def test__populate(data): new_data = deepcopy(data) new_data[randomword(randint(1, 15))] = randomword(randint(1, 20)) with pytest.raises(KeyError): - _ = Stop(1, tuple(new_data.values()), list(new_data.keys())) + _ = Stop(tuple(new_data.values()), list(new_data.keys())) def test_save_to_database(data, transit_conn): line = LineString([[-23.59, -46.64], [-23.43, -46.50]]).wkb tlink_id = randint(10000, 200000044) - s = Stop(1, tuple(data.values()), list(data.keys())) + s = Stop(tuple(data.values()), list(data.keys())) s.link = link = randint(1, 30000) s.dir = direc = choice((0, 1)) - s.agency = randint(5, 100000) s.route_type = randint(0, 13) + s.agency_id = randint(1, 10) s.srid = get_srid() s.get_node_id() s.save_to_database(transit_conn, commit=True) @@ -61,7 +60,7 @@ def test_save_to_database(data, transit_conn): VALUES(?, ?, ?, ?, ?, ?, GeomFromWKB(?, 4326));""" transit_conn.execute(sql_tl, [tlink_id, randint(1, 1000000000), randint(1, 10), s.stop_id, s.stop_id + 1, 0, line]) - sql = "Select agency_id, link, dir, description, street from stops where stop=?" + sql = "Select link, dir, description from stops where stop=?" result = list(transit_conn.execute(sql, [data["stop_id"]]).fetchone()) - expected = [s.agency_id, link, direc, data["stop_desc"], data["stop_street"]] + expected = [link, direc, data["stop_desc"]] assert result == expected, "Saving Stop to the database failed" diff --git a/tests/aequilibrae/transit/test_lib_gtfs.py b/tests/aequilibrae/transit/test_lib_gtfs.py index 9ff39f585..4845c6b77 100644 --- a/tests/aequilibrae/transit/test_lib_gtfs.py +++ b/tests/aequilibrae/transit/test_lib_gtfs.py @@ -11,9 +11,7 @@ def gtfs_file(create_path): @pytest.fixture def system_builder(transit_conn, gtfs_file): - yield GTFSRouteSystemBuilder( - network=transit_conn, agency_identifier="LISERCO, LISANCO, LINCOSUR", file_path=gtfs_file - ) + yield GTFSRouteSystemBuilder(network=transit_conn, file_path=gtfs_file) def test_set_capacities(system_builder): @@ -56,12 +54,6 @@ def test_map_match(transit_conn, system_builder): assert transit_conn.execute("SELECT * FROM pattern_mapping;").fetchone()[0] > 1 -def test_set_agency_identifier(system_builder): - assert system_builder.gtfs_data.agency.agency != "CTA" - system_builder.set_agency_identifier("CTA") - assert system_builder.gtfs_data.agency.agency == "CTA" - - def test_set_feed(gtfs_file, system_builder): system_builder.set_feed(gtfs_file) assert system_builder.gtfs_data.archive_dir == gtfs_file @@ -79,7 +71,6 @@ def test_set_date(system_builder): def test_load_date(system_builder): system_builder.load_date("2016-04-13") - assert system_builder.gtfs_data.agency.service_date == "2016-04-13" assert "101387" in system_builder.select_routes.keys() diff --git a/tests/aequilibrae/transit/test_pattern.py b/tests/aequilibrae/transit/test_pattern.py index ab0b5c066..6d6a4b6a2 100644 --- a/tests/aequilibrae/transit/test_pattern.py +++ b/tests/aequilibrae/transit/test_pattern.py @@ -6,7 +6,7 @@ def pat(create_path, create_gtfs_project): gtfs_fldr = os.path.join(create_path, "gtfs_coquimbo.zip") - transit = create_gtfs_project.new_gtfs_builder(agency="Lisanco", file_path=gtfs_fldr, description="") + transit = create_gtfs_project.new_gtfs_builder(file_path=gtfs_fldr, description="") transit.load_date("2016-04-13") patterns = transit.select_patterns diff --git a/tests/aequilibrae/transit/test_transit.py b/tests/aequilibrae/transit/test_transit.py index fab7ee94d..850b37792 100644 --- a/tests/aequilibrae/transit/test_transit.py +++ b/tests/aequilibrae/transit/test_transit.py @@ -12,30 +12,26 @@ def test_new_gtfs_builder(create_gtfs_project, create_path): existing = conn.execute("SELECT COALESCE(MAX(DISTINCT(agency_id)), 0) FROM agencies;").fetchone()[0] transit = create_gtfs_project.new_gtfs_builder( - agency="Agency_1", - day="2016-04-13", file_path=join(create_path, "gtfs_coquimbo.zip"), ) - + transit.load_date("2016-04-13") + transit.save_to_disk() assert str(type(transit)) == "" transit2 = create_gtfs_project.new_gtfs_builder( - agency="Agency_2", - day="2016-07-19", file_path=join(create_path, "gtfs_coquimbo.zip"), ) - transit.save_to_disk() + transit2.load_date("2016-07-19") transit2.save_to_disk() assert conn.execute("SELECT MAX(DISTINCT(agency_id)) FROM agencies;").fetchone()[0] == existing + 2 transit3 = create_gtfs_project.new_gtfs_builder( - agency="Agency_3", - day="2016-07-19", file_path=join(create_path, "gtfs_coquimbo.zip"), ) + transit3.load_date("2016-06-04") transit3.save_to_disk() assert conn.execute("SELECT MAX(DISTINCT(agency_id)) FROM agencies;").fetchone()[0] == existing + 3