diff --git a/src/main/python/scripts/remove_duplicate_links.py b/src/main/python/scripts/remove_duplicate_links.py new file mode 100644 index 00000000000..ce14ab61034 --- /dev/null +++ b/src/main/python/scripts/remove_duplicate_links.py @@ -0,0 +1,84 @@ +import argparse +import shapefile + +parser = argparse.ArgumentParser() + +parser.add_argument("-i", "--inputshapefile", help="Path to input shapefile", required=True) +parser.add_argument("-o", "--outputshapefile", help="Path to output shapefile", required=True) +parser.add_argument("-c", "--capacity", help="Default capacity value (cars/hr/lane) for zero values in the input shapefile", type=float, required=False, default=1500.0) + +args=parser.parse_args() + +reader = shapefile.Reader(args.inputshapefile) +writer = shapefile.Writer(args.outputshapefile) + +# START +# accumulate adjacent links max_speed +inbound_links = dict() +outbound_links = dict() + +def adjacent_links(links_dict, node_id, link_speed): + if node_id not in links_dict: + links_dict[node_id] = set([link_speed]) + else: + s = links_dict[node_id] + s.add(link_speed) + +for record in reader.iterRecords(): + [from_node, to_node] = record['ID'].split('-') + max_speed = record['DATA2'] + if max_speed > 0: + adjacent_links(outbound_links, from_node, max_speed) + adjacent_links(inbound_links, to_node, max_speed) + +print(f"inbound_links size: {len(inbound_links)}") +print(f"outbound_links size: {len(outbound_links)}") +# END + +writer.field('ID', 'C', 20, 0) +writer.field('MODES', 'C', 64, 0) +writer.field('LANES', 'N', 35, 7) +writer.field('DATA1', 'N', 35, 7) # hourly capacity per lane +writer.field('DATA2', 'C', 35, 7) # add mph + +shapeRecords = reader.shapeRecords() +# to keep non-zero capacity links at the top for processing +shapeRecords.sort(key=lambda x: x.record['DATA1'], reverse=True) +s = set() + +for n in shapeRecords: + record = n.record + shape = n.shape + if record['ID'] not in s: + s.add(f"{record['JNODE']}-{record['INODE']}") + + if record['DATA1'] == 0: + record['DATA1'] = args.capacity + + if record['DATA2'] == 0: + from_node = str(record['INODE']) + has_from_node = from_node in inbound_links + ilinks = set() + if has_from_node: + ilinks = inbound_links[from_node] + + to_node = str(record['JNODE']) + has_to_node = to_node in outbound_links + olinks = set() + if has_to_node: + olinks = outbound_links[to_node] + + if has_from_node or has_to_node: + links = ilinks.union(olinks) + record['DATA2'] = sum(links) / len(links) + + writer.record( + record['ID'], + record['MODES'], + record['LANES'], + record['DATA1'], + str(record['DATA2']) + " mph", + ) + writer.shape(shape) + +writer.close() diff --git a/src/main/scala/scripts/ConsolidateOSMNodes.scala b/src/main/scala/scripts/ConsolidateOSMNodes.scala new file mode 100644 index 00000000000..5a39191a4eb --- /dev/null +++ b/src/main/scala/scripts/ConsolidateOSMNodes.scala @@ -0,0 +1,105 @@ +package scripts + +import java.io.File +import scala.collection.mutable +import scala.xml.transform.{RewriteRule, RuleTransformer} +import scala.xml.{Elem, Node, PrettyPrinter, XML} + +case class LatLon(lat: Double, lon: Double) + +// Usage: +// ./gradlew :execute \ +// -PmaxRAM=10 \ +// -PmainClass=scripts.ConsolidateOSMNodes \ +// -PappArgs="['links0.osm','links_consolidated.osm']" +object ConsolidateOSMNodes { + + private val locationToIds: mutable.Map[LatLon, mutable.Seq[Long]] = + mutable.Map.empty.withDefaultValue(mutable.Seq.empty) + private val idToLocation: mutable.Map[Long, LatLon] = mutable.Map.empty + + private val replaceRedundantId = new RewriteRule { + + override def transform(node: Node): Seq[Node] = { + node match { + case nd: Elem if nd.label == "nd" => + val id = (nd \ "@ref").text.toInt + val latLon = idToLocation(id) + val head :: tail = locationToIds(latLon).toList + if (tail.contains(id)) { + val metaData = + scala.xml.Attribute(key = "ref", value = scala.xml.Text(head.toString), next = scala.xml.Null) + nd % metaData + } else nd + case n => n + } + } + } + + private val removeNode = new RewriteRule { + + override def transform(node: Node): Seq[Node] = { + node match { + case node: Elem if node.label == "node" => + val id = (node \ "@id").text.toInt + val latLon = LatLon( + (node \ "@lat").text.toDouble, + (node \ "@lon").text.toDouble + ) + val _ :: tail = locationToIds(latLon).toList + if (tail.contains(id)) Seq.empty + else node + case n => n + } + } + } + + private def populateState(xml: Node): Unit = { + for { + osm <- xml \\ "osm" + node <- osm \\ "node" + } { + val id = (node \ "@id").text.toLong + val latLon = LatLon( + (node \ "@lat").text.toDouble, + (node \ "@lon").text.toDouble + ) + idToLocation.update(id, latLon) + val seq = locationToIds(latLon) + locationToIds.update(latLon, seq :+ id) + } + } + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + println(""" + |Usage: + |./gradlew :execute \ + | -PmaxRAM=10 \ + | -PmainClass=scripts.ConsolidateOSMNodes \ + | -PappArgs="['links0.osm','links_consolidated.osm']" + |""".stripMargin) + System.exit(1) + } + + val osmFile = new File(args(0)) + println("Loading xml..") + val xml = XML.loadFile(osmFile) + populateState(xml) + + val transformer = new RuleTransformer( + replaceRedundantId, + removeNode + ) + + println("Consolidating network nodes..") + val output = { + val root = transformer.transform(xml) + val printer = new PrettyPrinter(120, 2, true) + XML.loadString(printer.format(root.head)) + } + + println("Writing xml..") + XML.save(args(1), output, "UTF-8", xmlDecl = true, null) + } +}