Skip to content


fix(yankee): refactor yankee! make it simpler, and it now works local…
Browse files Browse the repository at this point in the history
…ly for me
  • Loading branch information
rudiejd committed Mar 29, 2024
1 parent cf921b1 commit c72fab5
Showing 1 changed file with 152 additions and 84 deletions.
236 changes: 152 additions & 84 deletions ingestor/chalicelib/
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ddtrace import tracer
from dataclasses import dataclass
from chalicelib import dynamo, s3
from decimal import Decimal

from typing import List
from sqlalchemy.orm import Session
Expand All @@ -22,7 +23,9 @@
BUCKET = "tm-shuttle-positions"
KEY = "yankee/last_shuttle_positions.csv"
BOSTON_COORDS = (-71.057083, 42.361145)
METERS_PER_MILE = 0.000621371
Expand All @@ -38,19 +41,31 @@ class ShuttleTravelTime:
line: str
# route of the trip e.g. Shuttle-AlewifeParkSt
routeId: str
date: datetime
date: str
# distance in miles of the trip
distance_miles: float
distance_miles: Decimal
# time in minutes of the trip
time: float
time: Decimal
# ID of the stop from which the shuttle originated
from_stop_id: str
# ID of the stop the shuttle travelled to
to_stop_id: str
# yankee's identifier for the bus that made the trip
name: str

@dataclass
class ShuttlePosition:
    """Snapshot of one shuttle bus, saved between runs to detect stop-to-stop trips.

    Instances are built from JSON objects via ``ShuttlePosition(**pos)`` and
    written back out through ``__dict__``, so the class must accept its fields
    as keyword arguments and keep them in a regular instance ``__dict__``
    (hence ``@dataclass`` without ``slots``).
    """

    # identifier of the bus as reported by the bus API
    name: str
    # latitude of the bus when the snapshot was taken
    latitude: str
    # longitude of the bus when the snapshot was taken
    longitude: str
    # route the bus was matched to
    detected_route: str
    # ID of the stop the bus was detected at
    detected_stop_id: str
    # timestamp string of the snapshot; presumably formatted with TIME_FORMAT -- confirm against the writer
    last_update_date: str

def load_bus_positions() -> Optional[List[Dict]]:
def load_bus_positions() -> Optional[List[ShuttlePosition]]:
data =, KEY, compressed=False)
return json.loads(data)
return json.loads(data, object_hook=lambda pos: ShuttlePosition(**pos))
except ClientError as ex:
if ex.response["Error"]["Code"] != "NoSuchKey":
Expand All @@ -62,6 +77,54 @@ def load_bus_positions() -> Optional[List[Dict]]:
def get_shuttle_stops(session: Session) -> List[Stop]:
    """Return every Stop row whose platform name contains "Shuttle".

    Args:
        session: SQLAlchemy session used to run the query.

    Returns:
        All matching Stop rows, as returned by ``Query.all()``.
    """
    is_shuttle_platform = Stop.platform_name.contains("Shuttle")
    shuttle_query = session.query(Stop).filter(is_shuttle_platform)
    return shuttle_query.all()

def get_stop_in_radius(coords: Coords, session: Session) -> Optional[Stop]:
result: List[Stop] = []

distance_fn = lambda s: distance.geodesic((s.stop_lon, s.stop_lat), coords) <= STOP_RADIUS_MILES
result = session.query(Stop).filter(

result: List[Stop] = list(filter(distance_fn, result))
except Exception as e:
print(f"Failed to match coords {coords} to stop")
print(f"Exception: {e}")
return None

if len(result) == 0:
return None

return sorted(result, key=distance_fn)[0]

def get_stop_by_id(session: Session, stop_id: Optional[str]):
if stop_id is None:
return None

result = None
result = session.query(Stop).filter(
Stop.stop_id == stop_id
except Exception:
print(f"Failed to find stop with ID {stop_id}")

return result

# TODO(rudiejd): Make types for the yankee API response
def query_yankee_bus_api():
    """Fetch the current bus list from the Yankee (Samsara) bus API.

    Sends an authenticated GET to ``YANKEE_BUS_API`` and returns the value
    stored under the ``"data"`` key of the JSON response body.

    Returns:
        The decoded ``"data"`` member of the response.

    Raises:
        Exception: if the API answers with a non-200 status code, or if the
            body is not JSON / has no ``"data"`` member. Errors raised by
            ``requests`` itself propagate unchanged.
    """
    headers = {"accept": "application/json", "authorization": f"Bearer {YANKEE_API_KEY}"}

    response = requests.get(YANKEE_BUS_API, headers=headers)
    # Checked outside the parsing try-block so this specific message is not
    # swallowed and replaced by the generic "problematic" one below.
    if response.status_code != 200:
        raise Exception(f"Received status code {response.status_code} from Samsara bus API. Body: {response.text}")

    try:
        return json.loads(response.text)["data"]
    except (ValueError, KeyError, TypeError) as ex:
        # Report the body we actually received (the old message interpolated
        # the `json` module object) and keep the root cause chained.
        raise Exception(f"Bus response problematic. We received {response.text}") from ex

def get_shuttle_shapes(
session: Session,
Expand Down Expand Up @@ -108,8 +171,21 @@ def get_session_for_latest_feed() -> Session:
latest_feed = archive.get_latest_feed()
feeds = archive.get_all_feeds()
if not feeds:
raise Exception("Failed to get feeds from MBTA list")

latest_feed = next(feed for feed in reversed(feeds) if feed.exists_remotely())

if not latest_feed:
raise Exception("Unable to find feed in S3, aborting")

print(f"Downloading data from feed with key {latest_feed.key}")


print("Finished downloading data for feed")

return latest_feed.create_sqlite_session()

Expand Down Expand Up @@ -144,19 +220,21 @@ def is_in_shape(coords: Tuple[float, float], shape: List[ShapePoint]):
return in_shape

def save_bus_positions(bus_positions: List[dict]):
def save_bus_positions(bus_positions: List[ShuttlePosition]):
now_str =
print(f"{now_str}: saving bus positions")

s3.upload(BUCKET, KEY, json.dumps(bus_positions), compress=False)
bus_positions_dicts = list(map(lambda pos: pos.__dict__, bus_positions))

s3.upload(BUCKET, KEY, json.dumps(bus_positions_dicts), compress=False)

def write_traveltimes_to_dynamo(travel_times: List[Optional[ShuttleTravelTime]]):
row_dicts = []
for travel_time in travel_times:
if travel_time:

def write_traveltimes_to_dynamo(travel_times: List[ShuttleTravelTime]):
row_dicts = list(map(lambda pos: pos.__dict__, travel_times))

print(f"Writing {len(row_dicts)} travel times to dynamo")
dynamo.dynamo_batch_write(row_dicts, SHUTTLE_TRAVELTIME_TABLE)
print("Finished writing to dynamo")

def get_driving_distance(old_coords: Tuple[float, float], new_coords: Tuple[float, float]) -> Optional[float]:
Expand Down Expand Up @@ -225,23 +303,11 @@ def get_driving_distance(old_coords: Tuple[float, float], new_coords: Tuple[floa

# TODO: this function is doing too much, trying to make it chill
def _update_shuttles(last_bus_positions: List[Dict], shuttle_shapes: ShapeDict, shuttle_stops: List[Stop]):
url = ""
def _update_shuttles(last_bus_positions: List[ShuttlePosition], shuttle_shapes: ShapeDict, session: Session):
buses = query_yankee_bus_api()

headers = {"accept": "application/json", "authorization": f"Bearer {YANKEE_API_KEY}"}

response = requests.get(url, headers=headers)
if response.status_code != 200:
raise Exception(f"Received status code {response.status_code} from Samsara bus API. Body: {response.text}")
buses = json.loads(response.text)["data"]
except Exception:
raise Exception(f"Bus response problematic. We received {json}")

bus_positions = []

travel_times: List[Optional[ShuttleTravelTime]] = []
bus_positions: List[ShuttlePosition] = []
travel_times: List[ShuttleTravelTime] = []

for bus in buses:
name = bus["name"]
Expand All @@ -250,6 +316,10 @@ def _update_shuttles(last_bus_positions: List[Dict], shuttle_shapes: ShapeDict,

coords = (float(long), float(lat))

# skip all of the buses that are far from boston
if distance.geodesic(BOSTON_COORDS, coords).miles > MAX_DIST_FROM_BOSTON:

# skip buses that aren't in a shuttle shape
# TODO(rudiejd): optimize this. There is probably a more efficient way to check if a shape is in
# any one of a list of polygons. Maybe you can use the ray method on all of the polygon points?
Expand All @@ -263,44 +333,43 @@ def _update_shuttles(last_bus_positions: List[Dict], shuttle_shapes: ShapeDict,
if detected_route is None:

last_detected_stop_id = -1
print(f"Detected bus {name} on route {detected_route} at {long}, {lat}")

last_detected_stop_id = None
last_update_date = None
# travel times to write to dynamo

for pos in last_bus_positions:
if pos["name"] == name:
last_detected_stop_id = pos["detected_stop_id"]
last_update_date = pos["last_update_date"]
if == name:
last_detected_stop_id = pos.detected_stop_id
last_update_date = pos.last_update_date

detected_stop_id: int = -1
for stop in shuttle_stops:
stop_coords = (stop.stop_lon, stop.stop_lat)
if distance.geodesic(stop_coords, coords).miles <= STOP_RADIUS_MILES:
detected_stop_id = int(stop.stop_id)
detected_stop: Optional[Stop] = get_stop_in_radius(coords, session)

# if we're not currently near a stop, use the last stop ID we detected
if detected_stop_id == -1:
detected_stop_id = last_detected_stop_id
if detected_stop is None:

print(f"Bus {name} is at stop {detected_stop.stop_name}!")

# here, we've had the bus arrive at a new stop!
if detected_stop_id != last_detected_stop_id and last_detected_stop_id != -1:
if detected_stop.stop_id != last_detected_stop_id and last_detected_stop_id != None:
# insert into table
print(f"Bus {name} arrived at stop {detected_stop_id} from stop {last_detected_stop_id}")
print(f"Bus {name} arrived at stop {detected_stop} from stop {last_detected_stop_id}")
last_detected_stop = get_stop_by_id(session, last_detected_stop_id)
travel_time = maybe_create_travel_time(
name, detected_route, last_detected_stop_id, detected_stop_id, last_update_date, shuttle_stops
name, detected_route, last_detected_stop, detected_stop, last_update_date
if travel_time:

# TODO(rudiejd) use an object to serialize this instead of a dict
# Only save the position when it's at a stop
"name": name,
"latitude": lat,
"longitude": long,
"detected_route": detected_route,
"detected_stop_id": detected_stop_id,


Expand All @@ -310,51 +379,45 @@ def _update_shuttles(last_bus_positions: List[Dict], shuttle_shapes: ShapeDict,
def maybe_create_travel_time(
name: str,
route_id: str,
last_detected_stop_id: int,
detected_stop_id: int,
last_detected_stop: Optional[Stop],
detected_stop: Optional[Stop],
last_update_date: Optional[str],
shuttle_stops: List[Stop],
if last_detected_stop is None or detected_stop is None:
print(f"Unable to create travel time for stop {last_detected_stop} to {detected_stop}")
return None

# don't write travel times with no start date
if last_update_date is None:
f"Position of bus {name} on {route_id} from {last_detected_stop_id} to {detected_stop_id} had no last update date, cannot create travel time"
f"Position of bus {name} on {route_id} from {last_detected_stop.stop_id} to {detected_stop.stop_id} had no last update date, cannot create travel time"
return None

last_update_datetime = datetime.strptime(last_update_date, TIME_FORMAT)
update_datetime =

last_stop_coords: Optional[Coords] = None
stop_coords: Optional[Coords] = None

# TODO(rudiejd) this can be made O(1) if it's slow
for stop in shuttle_stops:
coords = (stop.stop_lon, stop.stop_lat)
if stop.stop_id == last_detected_stop_id:
last_stop_coords = coords
elif stop.stop_id == detected_stop_id:
stop_coords = coords

if stop_coords is not None and last_stop_coords is not None:

if stop_coords is None or last_stop_coords is None:
f"Unable to detect stop ids. Last stop coordinates {last_stop_coords}, current stop coordinates {stop_coords}"
if last_detected_stop is None or detected_stop is None:
return None

# TODO(rudiejd) maybe precompute the stop distances for all the shuttle lines?
dist = get_driving_distance(last_stop_coords, stop_coords)
dist = get_driving_distance((last_detected_stop.stop_lon, last_detected_stop.stop_lat), (detected_stop.stop_lon, detected_stop.stop_lat))

if dist is None:
print(f"Unable calculate driving distnance for stop ids {last_detected_stop_id}, {detected_stop_id}")
print(f"Unable calculate driving distnance for stops {}, {detected_stop.stop_id} ({last_detected_stop.stop_name} to {detected_stop.stop_name})")
return None
# total time in minutes
time_minutes = (update_datetime - last_update_datetime).total_seconds() // 60

return ShuttleTravelTime(SHUTTLE_LINE, route_id,, dist, time_minutes, name)
return ShuttleTravelTime(SHUTTLE_LINE,
# cover your eyes
Decimal(str(round(dist, 2))),
Decimal(str(round(time_minutes, 2))),

def update_shuttles():
Expand All @@ -373,7 +436,12 @@ def update_shuttles():
last_bus_positions = []

session = get_session_for_latest_feed()

print("Finished creating SQLite DB")

shuttle_shapes = get_shuttle_shapes(session)
shuttle_stops = get_shuttle_stops(session)

_update_shuttles(last_bus_positions, shuttle_shapes, shuttle_stops)
last_bus_positions = _update_shuttles(last_bus_positions, shuttle_shapes, session)


0 comments on commit c72fab5

Please sign in to comment.