Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ description = "GTFS feed validator using Python and Polars"
requires-python = ">=3.12"
dependencies = [
"polars>=1.0",
"numpy>=1.24",
"click>=8.0",
"jinja2>=3.0",
"httpx>=0.27",
Expand Down
10 changes: 5 additions & 5 deletions specops/slowest.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
1. shape_to_stop_matching — 3.683s
2. stop_time_travel_speed — 2.857s
3. transfers_in_seat_transfer_type — 0.569s
4. block_trips_overlapping — 0.417s
5. shape_increasing_distance — 0.378s
1. shape_to_stop_matching — 2.631s
2. stop_time_travel_speed — 0.906s
3. transfers_in_seat_transfer_type — 0.531s
4. block_trips_overlapping — 0.382s
5. shape_increasing_distance — 0.365s
157 changes: 122 additions & 35 deletions src/gtfs_validator/validators/shape_to_stop_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

import polars as pl

from gtfs_validator.context import ValidationContext
from gtfs_validator.notices import Notice, Severity
from gtfs_validator.validators.shape_to_stop_matching_util import (
CandidateMatch,
RAIL_ROUTE_TYPE,
MatchSettings,
Problem,
_VEC_MIN_SEGMENTS,
_VEC_MIN_WORK_ITEMS,
build_shape_arrays,
build_shape_points,
build_shape_spatial_index,
build_stop_points,
compute_trip_hash,
match_using_geo_distance,
Expand All @@ -29,9 +35,16 @@
_REQUIRED_TABLES = ["stops", "trips", "routes", "stop_times", "shapes"]


@dataclass(frozen=True)
class TripInfo:
trip_id: str
route_id: str
csv_row_number: int


def _problems_to_notices(
problems: list[Problem],
trip: dict,
trip: TripInfo,
shape_id: str,
stops_by_id: dict,
reported_stop_ids: set[str],
Expand Down Expand Up @@ -71,9 +84,9 @@ def _problems_to_notices(
code=code,
severity=Severity.WARNING,
fields={
"trip_csv_row_number": trip["csv_row_number"],
"trip_csv_row_number": trip.csv_row_number,
"shape_id": shape_id,
"trip_id": trip["trip_id"],
"trip_id": trip.trip_id,
"stop_time_csv_row_number": stop_time_row["csv_row_number"],
"stop_id": stop_id,
"stop_name": stop_name,
Expand Down Expand Up @@ -101,9 +114,9 @@ def _problems_to_notices(
code="stop_has_too_many_matches_for_shape",
severity=Severity.WARNING,
fields={
"trip_csv_row_number": trip["csv_row_number"],
"trip_csv_row_number": trip.csv_row_number,
"shape_id": shape_id,
"trip_id": trip["trip_id"],
"trip_id": trip.trip_id,
"stop_time_csv_row_number": stop_time_row["csv_row_number"],
"stop_id": stop_id,
"stop_name": stop_name,
Expand Down Expand Up @@ -137,9 +150,9 @@ def _problems_to_notices(
code="stops_match_shape_out_of_order",
severity=Severity.WARNING,
fields={
"trip_csv_row_number": trip["csv_row_number"],
"trip_csv_row_number": trip.csv_row_number,
"shape_id": shape_id,
"trip_id": trip["trip_id"],
"trip_id": trip.trip_id,
"stop_time_csv_row_number1": stop_time_row1["csv_row_number"],
"stop_id1": stop_id1,
"stop_name1": stop_name1,
Expand Down Expand Up @@ -191,14 +204,32 @@ def validate_shape_to_stop_matching(

stops_by_id: dict[str, dict] = {
row["stop_id"]: row
for row in stops_df.select(available_stops_cols).to_dicts()
for row in stops_df.select(available_stops_cols).iter_rows(named=True)
}

routes_by_id: dict[str, dict] = {
row["route_id"]: row
for row in routes_df.select(["route_id", "route_type"]).to_dicts()
route_type_by_id: dict[str, int] = {
row["route_id"]: row["route_type"]
for row in routes_df.select(["route_id", "route_type"]).iter_rows(named=True)
}

# Group trips by shape_id first so downstream row processing can filter.
trip_cols = ["trip_id", "route_id", "shape_id", "csv_row_number"]
trips_by_shape_id: dict[str, list[TripInfo]] = defaultdict(list)
relevant_trip_ids: set[str] = set()
for row in trips_df.select(trip_cols).iter_rows(named=True):
shape_id = row.get("shape_id")
if shape_id:
trip = TripInfo(
trip_id=row["trip_id"],
route_id=row["route_id"],
csv_row_number=row["csv_row_number"],
)
trips_by_shape_id[shape_id].append(trip)
relevant_trip_ids.add(trip.trip_id)

if not trips_by_shape_id:
return []

# Group stop_times by trip_id
st_cols = ["trip_id", "stop_id", "stop_sequence", "csv_row_number"]
optional_st = ["shape_dist_traveled"]
Expand All @@ -209,15 +240,13 @@ def validate_shape_to_stop_matching(

# Pre-sort in Polars (vectorized, GIL-released) so Python loops need no further sorting.
st_by_trip_id: dict[str, list[dict]] = defaultdict(list)
for row in stop_times_df.select(available_st_cols).sort(["trip_id", "stop_sequence"]).to_dicts():
st_by_trip_id[row["trip_id"]].append(row)

# Group trips by shape_id
trip_cols = ["trip_id", "route_id", "shape_id", "csv_row_number"]
trips_by_shape_id: dict[str, list[dict]] = defaultdict(list)
for row in trips_df.select(trip_cols).to_dicts():
if row.get("shape_id"):
trips_by_shape_id[row["shape_id"]].append(row)
for row in (
stop_times_df.select(available_st_cols)
.sort(["trip_id", "stop_sequence"])
.iter_rows(named=True)
):
if row["trip_id"] in relevant_trip_ids:
st_by_trip_id[row["trip_id"]].append(row)

# Group shapes by shape_id
shape_cols = ["shape_id", "shape_pt_lat", "shape_pt_lon", "shape_pt_sequence"]
Expand All @@ -227,10 +256,17 @@ def validate_shape_to_stop_matching(
if col in shapes_df.columns:
available_shape_cols.append(col)

relevant_shape_ids = set(trips_by_shape_id.keys())
# Pre-sort by shape_pt_sequence in Polars so build_shape_points needs no further sorting.
shapes_groups: dict[str, list[dict]] = defaultdict(list)
for row in shapes_df.select(available_shape_cols).sort(["shape_id", "shape_pt_sequence"]).to_dicts():
shapes_groups[row["shape_id"]].append(row)
for row in (
shapes_df.select(available_shape_cols)
.sort(["shape_id", "shape_pt_sequence"])
.iter_rows(named=True)
):
shape_id = row["shape_id"]
if shape_id in relevant_shape_ids:
shapes_groups[shape_id].append(row)

notices: list[Notice] = []

Expand All @@ -243,30 +279,76 @@ def validate_shape_to_stop_matching(
if not shape_points:
continue

num_segments = max(0, len(shape_points) - 1)
max_trip_stops = max(
(len(st_by_trip_id.get(trip.trip_id, [])) for trip in trips_for_shape),
default=0,
)
should_build_shape_arrays = (
num_segments >= _VEC_MIN_SEGMENTS
and (num_segments * max_trip_stops) >= _VEC_MIN_WORK_ITEMS
)
shape_arrays = build_shape_arrays(shape_points) if should_build_shape_arrays else None
shape_spatial_index = build_shape_spatial_index(shape_points)
shape_has_user_dist = shape_points[-1].user_distance > 0.0
reported_stop_ids: set[str] = set()
seen_trip_hashes: set[tuple] = set()
# Two-stage dedup: quick signature first, full trip hash only on collisions.
dedup_quick_state: dict[
tuple[int, str, str, float, float], list[dict] | set[tuple]
] = {}
candidates_cache: dict[tuple[str, float, float, float], list[CandidateMatch]] = {}
closest_cache: dict[tuple[str, float, float], CandidateMatch] = {}
need_trip_dedup = len(trips_for_shape) > 1

for trip in trips_for_shape:
trip_id = trip["trip_id"]
stop_times = st_by_trip_id.get(trip_id, [])
stop_times = st_by_trip_id.get(trip.trip_id, [])
if not stop_times:
continue

trip_hash = compute_trip_hash(stop_times)
if trip_hash in seen_trip_hashes:
continue
seen_trip_hashes.add(trip_hash)

route = routes_by_id.get(trip["route_id"])
if route is None:
if need_trip_dedup:
first = stop_times[0]
last = stop_times[-1]
quick_sig = (
len(stop_times),
str(first.get("stop_id") or ""),
str(last.get("stop_id") or ""),
float(first.get("shape_dist_traveled") or 0.0),
float(last.get("shape_dist_traveled") or 0.0),
)
state = dedup_quick_state.get(quick_sig)
if state is None:
dedup_quick_state[quick_sig] = stop_times
elif isinstance(state, list):
first_hash = compute_trip_hash(state)
cur_hash = compute_trip_hash(stop_times)
hash_bucket = {first_hash}
dedup_quick_state[quick_sig] = hash_bucket
if cur_hash in hash_bucket:
continue
hash_bucket.add(cur_hash)
else:
cur_hash = compute_trip_hash(stop_times)
if cur_hash in state:
continue
state.add(cur_hash)

route_type = route_type_by_id.get(trip.route_id)
if route_type is None:
continue

is_large_route = route["route_type"] == RAIL_ROUTE_TYPE
is_large_route = route_type == RAIL_ROUTE_TYPE
stop_points = build_stop_points(stop_times, stops_by_id, is_large_route, resolved_latlng)

# Geo-distance matching (always)
geo_problems = match_using_geo_distance(stop_points, shape_points, settings)
geo_problems = match_using_geo_distance(
stop_points,
shape_points,
settings,
shape_arrays,
shape_spatial_index,
candidates_cache=candidates_cache,
closest_cache=closest_cache,
)
notices.extend(
_problems_to_notices(
geo_problems,
Expand All @@ -282,7 +364,12 @@ def validate_shape_to_stop_matching(
stops_have_user_dist = stop_points[-1].user_distance > 0.0
if stops_have_user_dist and shape_has_user_dist:
user_problems = match_using_user_distance(
stop_points, shape_points, settings
stop_points,
shape_points,
settings,
shape_arrays,
shape_spatial_index,
candidates_cache=candidates_cache,
)
notices.extend(
_problems_to_notices(
Expand Down
Loading