
Commit 0efc880

Feat: pems_data caching (#194)

2 parents: 3fc85c5 + 1f4d38e

File tree

18 files changed: +663 −25 lines

.devcontainer/devcontainer.json

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
   "name": "caltrans/pems",
   "dockerComposeFile": ["../compose.yml"],
   "service": "dev",
-  "runServices": ["dev", "pgweb"],
+  "runServices": ["dev", "pgweb", "redis"],
   "forwardPorts": ["docs:8000"],
   "workspaceFolder": "/caltrans/app",
   "postStartCommand": ["/bin/bash", "bin/setup.sh"],

.env.sample

Lines changed: 4 additions & 0 deletions
@@ -27,3 +27,7 @@ STREAMLIT_NAV=hidden

 # AWS
 AWS_PROFILE=pems
+
+# Redis
+REDIS_PORT=6379
+REDIS_HOSTNAME=redis

compose.yml

Lines changed: 15 additions & 0 deletions
@@ -75,6 +75,21 @@ services:
     ports:
       - "${STREAMLIT_LOCAL_PORT:-8501}:8501"

+  redis:
+    image: redis:8
+    ports:
+      - "${REDIS_PORT:-6379}:6379"
+    command: redis-server --save 60 1 --loglevel notice
+    healthcheck:
+      test: redis-cli ping | grep PONG
+      interval: 1s
+      timeout: 5s
+      retries: 5
+    volumes:
+      - redisdata:/data
+
 volumes:
   pgdata:
     driver: local
+  redisdata:
+    driver: local

pems_data/pyproject.toml

Lines changed: 9 additions & 1 deletion
@@ -3,7 +3,15 @@ name = "pems_data"
 description = "Common data access library for PeMS."
 dynamic = ["version"]
 requires-python = ">=3.12"
-dependencies = ["boto3==1.39.7", "pandas==2.3.0"]
+dependencies = [
+    "boto3==1.39.7",
+    "pandas==2.3.0",
+    "pyarrow==21.0.0",
+    "redis==6.2.0",
+]
+
+[project.scripts]
+pems-cache = "pems_data.cli:cache"

 [build-system]
 requires = ["setuptools>=75", "setuptools_scm>=8"]

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from pems_data.cache import Cache
+from pems_data.services.stations import StationsService
+from pems_data.sources.cache import CachingDataSource
+from pems_data.sources.s3 import S3DataSource
+
+
+class ServiceFactory:
+    """
+    A factory class to create and configure various services.
+
+    Shared dependencies are created once during initialization.
+    """
+
+    def __init__(self):
+        self.cache = Cache()
+        self.s3_source = S3DataSource()
+        self.caching_s3_source = CachingDataSource(data_source=self.s3_source, cache=self.cache)
+
+    def stations_service(self) -> StationsService:
+        """Creates a fully-configured `StationsService`."""
+        return StationsService(data_source=self.caching_s3_source)
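
A minimal usage sketch of the factory above. The import path for ServiceFactory is an assumption (the file name for this new module was not captured on this page), and the district number is a hypothetical example:

from pems_data.factory import ServiceFactory  # assumed module path; not shown on this page

factory = ServiceFactory()
stations = factory.stations_service()

# reads flow through the shared CachingDataSource: hits come from redis,
# misses fall through to S3 and populate the cache
df = stations.get_district_metadata("11")  # hypothetical district number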

pems_data/src/pems_data/cache.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+import logging
+import os
+from typing import Any, Callable
+
+import pandas as pd
+import redis
+
+from pems_data.serialization import arrow_bytes_to_df, df_to_arrow_bytes
+
+logger = logging.getLogger(__name__)
+
+
+def redis_connection(host: str = None, port: int = None, **kwargs) -> redis.Redis | None:
+    """Try to create a new connection to a redis backend. Return None if the connection fails.
+
+    Uses the `REDIS_HOSTNAME` and `REDIS_PORT` environment variables as fallbacks.
+
+    Args:
+        host (str): The redis hostname
+        port (int): The port to connect on
+    """
+
+    host = host or os.environ.get("REDIS_HOSTNAME", "redis")
+    port = int(port or os.environ.get("REDIS_PORT", "6379"))
+
+    logger.debug(f"connecting to redis @ {host}:{port}")
+
+    kwargs["host"] = host
+    kwargs["port"] = port
+
+    try:
+        return redis.Redis(**kwargs)
+    except redis.ConnectionError as ce:
+        logger.error(f"connection failed for redis @ {host}:{port}", exc_info=ce)
+        return None
+
+
+class Cache:
+    """Basic caching interface for `pems_data`."""
+
+    @classmethod
+    def build_key(cls, *args) -> str:
+        """Build a cache key from the given parts."""
+        return ":".join([str(a).lower() for a in args])
+
+    def __init__(self, host: str = None, port: int = None):
+        """Create a new instance of the Cache interface.
+
+        Args:
+            host (str): (Optional) The hostname of the cache backend.
+            port (int): (Optional) The port to connect on the cache backend.
+        """
+
+        self.host = host
+        self.port = port
+        self.c = None
+
+    def _connect(self):
+        """Establish a connection to the cache backend if necessary."""
+        if not isinstance(self.c, redis.Redis):
+            self.c = redis_connection(self.host, self.port)
+
+    def is_available(self) -> bool:
+        """Return a bool indicating if the cache backend is available or not."""
+        self._connect()
+        available = self.c is not None and self.c.ping() is True
+        logger.debug(f"cache is available: {available}")
+        return available
+
+    def get(self, key: str, mutate_func: Callable[[Any], Any] = None) -> Any:
+        """Get a raw value from the cache, or None if the key doesn't exist.
+
+        Args:
+            key (str): The item's cache key.
+            mutate_func (callable): If provided, call this on the cached value and return its result.
+        """
+        if self.is_available():
+            logger.debug(f"read from cache: {key}")
+            value = self.c.get(key)
+            if value is not None and mutate_func:
+                logger.debug(f"mutating cached value: {key}")
+                return mutate_func(value)
+            return value
+        logger.warning(f"cache unavailable to get: {key}")
+        return None
+
+    def get_df(self, key: str) -> pd.DataFrame:
+        """Get a `pandas.DataFrame` from the cache, or None if the key doesn't exist."""
+        return self.get(key, mutate_func=arrow_bytes_to_df)
+
+    def set(self, key: str, value: Any, ttl: int = None, mutate_func: Callable[[Any], Any] = None) -> None:
+        """Set a value in the cache.
+
+        Args:
+            key (str): The item's cache key.
+            value (Any): The item's value to store in the cache.
+            ttl (int): Seconds until expiration.
+            mutate_func (callable): If provided, call this on the value and insert the result in the cache.
+        """
+        if self.is_available():
+            if mutate_func:
+                logger.debug(f"mutating value for cache: {key}")
+                value = mutate_func(value)
+            logger.debug(f"store in cache: {key}")
+            self.c.set(key, value, ex=ttl)
+        else:
+            logger.warning(f"cache unavailable to set: {key}")

+    def set_df(self, key: str, value: pd.DataFrame, ttl: int = None) -> None:
+        """Set a `pandas.DataFrame` in the cache, with an optional TTL (seconds until expiration)."""
+        self.set(key, value, mutate_func=df_to_arrow_bytes, ttl=ttl)
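
A short sketch of the Cache interface in use against a reachable redis backend; the key, value, and TTL are hypothetical examples:

from pems_data.cache import Cache

cache = Cache()  # host/port fall back to REDIS_HOSTNAME/REDIS_PORT

if cache.is_available():
    cache.set("example:key", "hello", ttl=60)  # expires after 60 seconds
    print(cache.get("example:key"))            # b"hello" -- redis returns bytes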

pems_data/src/pems_data/cli.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import argparse
+import sys
+
+from pems_data.cache import Cache
+
+
+def cache():  # pragma: no cover
+    parser = argparse.ArgumentParser("pems-cache", description="Simple CLI for the cache")
+    parser.add_argument("op", choices=("check", "get", "set"), default="check", nargs="?", help="the operation to perform")
+    parser.add_argument("--key", "-k", required=False, type=str, help="the item's key, required for get/set")
+    parser.add_argument("--value", "-v", required=False, type=str, help="the item's value, required for set")
+    parsed_args = parser.parse_args(sys.argv[1:])
+
+    c = Cache()
+
+    match parsed_args.op:
+        case "get":
+            if parsed_args.key:
+                print(f"[{parsed_args.key}]: {c.get(parsed_args.key)}")
+            else:
+                parser.print_usage()
+                raise SystemExit(1)
+        case "set":
+            if parsed_args.key and parsed_args.value:
+                print(f"[{parsed_args.key}] = '{parsed_args.value}'")
+                c.set(parsed_args.key, parsed_args.value)
+            else:
+                parser.print_usage()
+                raise SystemExit(1)
+        case _:
+            print(f"cache is available: {c.is_available()}")
pems_data/src/pems_data/serialization.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+import pandas as pd
+import pyarrow as pa
+import pyarrow.ipc as ipc
+
+
+def arrow_bytes_to_df(arrow_buffer: bytes) -> pd.DataFrame:
+    """Deserializes Arrow IPC format `bytes` back to a `pandas.DataFrame`."""
+    if not arrow_buffer:
+        return pd.DataFrame()
+    # deserialize the Arrow IPC stream
+    with pa.BufferReader(arrow_buffer) as buffer:
+        # the reader reconstructs the Arrow Table from the buffer
+        reader = ipc.RecordBatchStreamReader(buffer)
+        arrow_table = reader.read_all()
+        return arrow_table.to_pandas()
+
+
+def df_to_arrow_bytes(df: pd.DataFrame) -> bytes:
+    """Serializes a `pandas.DataFrame` to Arrow IPC format `bytes`."""
+    if df.empty:
+        return b""
+    # convert DataFrame to an Arrow Table
+    arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+    # serialize the Arrow Table to bytes using the IPC stream format
+    sink = pa.BufferOutputStream()
+    with ipc.RecordBatchStreamWriter(sink, arrow_table.schema) as writer:
+        writer.write_table(arrow_table)
+    # get the buffer from the stream
+    return sink.getvalue().to_pybytes()
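
A quick round-trip sketch of the two helpers, using a tiny hypothetical frame:

import pandas as pd

from pems_data.serialization import arrow_bytes_to_df, df_to_arrow_bytes

df = pd.DataFrame({"STATION_ID": ["401001"], "LANE": [1]})
payload = df_to_arrow_bytes(df)        # Arrow IPC stream bytes, ready to store in redis
restored = arrow_bytes_to_df(payload)  # reconstructs the frame from the cached bytes
assert restored.equals(df)             # values round-trip; the index is not preserved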

pems_data/src/pems_data/services/stations.py

Lines changed: 10 additions & 2 deletions
@@ -1,5 +1,6 @@
 import pandas as pd

+from pems_data.cache import Cache
 from pems_data.sources import IDataSource

@@ -12,9 +13,13 @@ class StationsService:
     def __init__(self, data_source: IDataSource):
         self.data_source = data_source

+    def _build_cache_key(self, *args):
+        return Cache.build_key("stations", *args)
+
     def get_district_metadata(self, district_number: str) -> pd.DataFrame:
         """Loads metadata for all stations in the selected District from S3."""

+        cache_opts = {"key": self._build_cache_key("metadata", "district", district_number), "ttl": 3600}  # 1 hour
         columns = [
             "STATION_ID",
             "NAME",
@@ -33,11 +38,12 @@ def get_district_metadata(self, district_number: str) -> pd.DataFrame:
         ]
         filters = [("DISTRICT", "=", district_number)]

-        return self.data_source.read(self.metadata_file, columns=columns, filters=filters)
+        return self.data_source.read(self.metadata_file, cache_opts=cache_opts, columns=columns, filters=filters)

     def get_imputed_agg_5min(self, station_id: str) -> pd.DataFrame:
         """Loads imputed aggregate 5 minute data for a specific station."""

+        cache_opts = {"key": self._build_cache_key("imputed", "agg", "5m", "station", station_id), "ttl": 300}  # 5 minutes
         columns = [
             "STATION_ID",
             "LANE",
@@ -48,4 +54,6 @@ def get_imputed_agg_5min(self, station_id: str) -> pd.DataFrame:
         ]
         filters = [("STATION_ID", "=", station_id)]

-        return self.data_source.read(self.imputation_detector_agg_5min, columns=columns, filters=filters)
+        return self.data_source.read(
+            self.imputation_detector_agg_5min, cache_opts=cache_opts, columns=columns, filters=filters
+        )
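
For illustration, the keys these two methods build via Cache.build_key look like this (the district and station values are hypothetical):

from pems_data.cache import Cache

Cache.build_key("stations", "metadata", "district", "11")
# -> "stations:metadata:district:11" (cached for 1 hour)

Cache.build_key("stations", "imputed", "agg", "5m", "station", "401001")
# -> "stations:imputed:agg:5m:station:401001" (cached for 5 minutes)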
pems_data/src/pems_data/sources/cache.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+import pandas as pd
+
+from pems_data.cache import Cache
+from pems_data.sources import IDataSource
+
+
+class CachingDataSource(IDataSource):
+    """
+    A DataSource decorator that adds a caching layer to another data source.
+    """
+
+    def __init__(self, data_source: IDataSource, cache: Cache):
+        self.cache = cache
+        self.data_source = data_source
+
+    def read(self, identifier: str, **kwargs) -> pd.DataFrame:
+        # get cache options from kwargs
+        cache_opts = kwargs.pop("cache_opts", {})
+        # use cache key from options, fallback to identifier
+        cache_key = cache_opts.get("key", identifier)
+        ttl = cache_opts.get("ttl")
+
+        # try to get df from cache
+        cached_df = self.cache.get_df(cache_key)
+        if cached_df is not None:
+            return cached_df
+
+        # on miss, call the wrapped source
+        df = self.data_source.read(identifier, **kwargs)
+        # store the result in the cache
+        self.cache.set_df(cache_key, df, ttl=ttl)
+
+        return df
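
A sketch of the decorator's read-through behavior, wired up the same way ServiceFactory does it; the parquet identifier and cache options are hypothetical:

from pems_data.cache import Cache
from pems_data.sources.cache import CachingDataSource
from pems_data.sources.s3 import S3DataSource

source = CachingDataSource(data_source=S3DataSource(), cache=Cache())

opts = {"key": "example:stations", "ttl": 60}
# first read misses the cache, hits S3, and stores the frame for 60 seconds
df = source.read("stations/metadata.parquet", cache_opts=opts)
# an identical second read is served from redis without touching S3
df = source.read("stations/metadata.parquet", cache_opts=opts)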
