This repository was archived by the owner on Sep 29, 2025. It is now read-only.

Commit 2cb65f3

refactor(pems_data): more generic sources subpackage
create an abstract base class interface for a data source
refactor S3Bucket to S3DataSource
1 parent ba09852 commit 2cb65f3
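
In practical terms, callers that previously constructed S3Bucket and called read_parquet() now construct S3DataSource and call read(), and the bucket name can be supplied through the S3_BUCKET_NAME environment variable. A minimal before/after sketch based on the diff below (the key and column names are made up for illustration):

# before this commit (hypothetical key and column names)
bucket = S3Bucket()                             # always defaulted to the prod marts bucket
df = bucket.read_parquet("some/key.parquet", columns=["col1"])

# after this commit
source = S3DataSource()                         # S3_BUCKET_NAME if set, else the prod marts bucket
df = source.read("some/key.parquet", columns=["col1"])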

7 files changed: +141 -81 lines changed

pems_data/src/pems_data/sources/__init__.py

Whitespace-only changes.
pems_data/src/pems_data/sources/base.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+from abc import ABC, abstractmethod
+
+import pandas as pd
+
+
+class IDataSource(ABC):
+    """
+    An abstract interface for a generic data source.
+    """
+
+    @abstractmethod
+    def read(self, identifier: str, **kwargs) -> pd.DataFrame:
+        """
+        Reads data identified by a generic identifier from the source.
+
+        Args:
+            identifier (str): The unique identifier for the data, e.g.,
+                an S3 key, a database table name, etc.
+            **kwargs: Additional arguments for the underlying read operation,
+                such as 'columns' or 'filters'.
+        """
+        raise NotImplementedError
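
Since the interface only requires read(), downstream code can be written against IDataSource and stay agnostic about where the data actually lives. A minimal consumer sketch, assuming the S3DataSource defined later in this commit (the function, key, and column names are hypothetical):

# hypothetical consumer; the function name, key, and columns below are illustrative only
import pandas as pd

from pems_data.sources.base import IDataSource
from pems_data.sources.s3 import S3DataSource


def load_frame(source: IDataSource, identifier: str) -> pd.DataFrame:
    # any IDataSource implementation works here; S3DataSource is just one backend
    return source.read(identifier, columns=["col1", "col2"])


df = load_frame(S3DataSource(), "some/key.parquet")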

pems_data/src/pems_data/s3.py renamed to pems_data/src/pems_data/sources/s3.py

Lines changed: 9 additions & 6 deletions
@@ -1,14 +1,18 @@
+import os
 import re
 
 import boto3
 import pandas as pd
 
+from pems_data.sources.base import IDataSource
 
-class S3Bucket:
-    prod_marts = "caltrans-pems-prd-us-west-2-marts"
+
+class S3DataSource(IDataSource):
+    default_bucket = os.environ.get("S3_BUCKET_NAME", "caltrans-pems-prd-us-west-2-marts")
 
     def __init__(self, name: str = None):
-        self.name = name or self.prod_marts
+        self.name = name or self.default_bucket
+        self._client = boto3.client("s3")
 
     def get_prefixes(self, filter_pattern: re.Pattern = re.compile(".+"), initial_prefix: str = "", match_func=None) -> list:
         """
@@ -17,8 +21,7 @@ def get_prefixes(self, filter_pattern: re.Pattern = re.compile(".+"), initial_pr
         When a match is found, if match_func exists, add its result to the output list. Otherwise add the entire match.
         """
 
-        s3 = boto3.client("s3")
-        s3_keys = s3.list_objects(Bucket=self.name, Prefix=initial_prefix)
+        s3_keys = self._client.list_objects(Bucket=self.name, Prefix=initial_prefix)
 
         result = set()
 
@@ -33,7 +36,7 @@ def get_prefixes(self, filter_pattern: re.Pattern = re.compile(".+"), initial_pr
 
         return sorted(result)
 
-    def read_parquet(self, *args, path=None, columns=None, filters=None, **kwargs) -> pd.DataFrame:
+    def read(self, *args: str, path=None, columns=None, filters=None, **kwargs) -> pd.DataFrame:
         """Reads data from the S3 path into a pandas DataFrame. Extra kwargs are pass along to `pandas.read_parquet()`.
 
         Args:
tests/pytest/pems_data/sources/__init__.py

Whitespace-only changes.
New test module for IDataSource

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+import pandas as pd
+import pytest
+
+from pems_data.sources.base import IDataSource
+
+
+class TestIDataSource:
+
+    def test_cannot_instantiate_abstract(self):
+        """Test that IDataSource cannot be instantiated directly"""
+        with pytest.raises(TypeError, match=r"Can't instantiate abstract class IDataSource"):
+            IDataSource()
+
+    def test_must_implement_read(self):
+        """Test that concrete classes must implement read method"""
+
+        class InvalidSource(IDataSource):
+            pass
+
+        with pytest.raises(TypeError, match=r"Can't instantiate abstract class InvalidSource"):
+            InvalidSource()
+
+    def test_valid_implementation(self):
+        """Test that a valid implementation can be instantiated and used"""
+
+        class ValidSource(IDataSource):
+            def read(self, identifier: str, **kwargs) -> pd.DataFrame:
+                return pd.DataFrame({"test": [1, 2, 3]})
+
+        source = ValidSource()
+        result = source.read("test-id", columns=["col1"])
+
+        assert isinstance(result, pd.DataFrame)
+        assert not result.empty
+        assert result.equals(pd.DataFrame({"test": [1, 2, 3]}))
New test module for S3DataSource

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+import re
+
+import pytest
+
+from pems_data.sources.s3 import S3DataSource
+
+
+class TestS3DataSource:
+
+    @pytest.fixture
+    def data_source(self) -> S3DataSource:
+        return S3DataSource()
+
+    @pytest.fixture(autouse=True)
+    def mock_s3(self, mocker):
+        s3 = mocker.patch("boto3.client").return_value
+        s3.list_objects.return_value = {
+            "Contents": [
+                {"Key": "path1/file2.json"},
+                {"Key": "path2/file1.json"},
+                {"Key": "path1/file1.json"},
+            ]
+        }
+        return s3
+
+    @pytest.fixture(autouse=True)
+    def mock_read_parquet(self, mocker):
+        return mocker.patch("pandas.read_parquet")
+
+    def test_name_custom(self):
+        assert S3DataSource("name").name == "name"
+
+    def test_name_default(self):
+        assert S3DataSource().name == S3DataSource.default_bucket
+
+    def test_get_prefixes__default(self, data_source: S3DataSource, mock_s3):
+        result = data_source.get_prefixes()
+
+        mock_s3.list_objects.assert_called_once_with(Bucket=data_source.name, Prefix="")
+        assert result == ["path1/file1.json", "path1/file2.json", "path2/file1.json"]
+
+    def test_get_prefixes__filter_pattern(self, data_source: S3DataSource):
+        result = data_source.get_prefixes(re.compile("path1/.+"))
+
+        assert result == ["path1/file1.json", "path1/file2.json"]
+
+    def test_get_prefixes__initial_prefix(self, data_source: S3DataSource, mock_s3):
+        data_source.get_prefixes(initial_prefix="prefix")
+
+        mock_s3.list_objects.assert_called_once_with(Bucket=data_source.name, Prefix="prefix")
+
+    def test_get_prefixes__match_func(self, data_source: S3DataSource):
+        result = data_source.get_prefixes(re.compile("path1/(.+)"), match_func=lambda m: m.group(1))
+
+        assert result == ["file1.json", "file2.json"]
+
+    def test_read(self, data_source: S3DataSource, mock_read_parquet):
+        mock_read_parquet.return_value = "data"
+        expected_path = data_source.url("path")
+
+        columns = ["col1", "col2", "col3"]
+        filters = [("col1", "=", "val1")]
+
+        result = data_source.read("path", columns=columns, filters=filters, extra1="extra1", extra2="extra2")
+
+        assert result == "data"
+        mock_read_parquet.assert_called_once_with(
+            expected_path, columns=columns, filters=filters, extra1="extra1", extra2="extra2"
+        )
+
+    def test_url__no_path(self, data_source: S3DataSource):
+        assert data_source.url() == f"s3://{data_source.name}"
+
+    def test_url__with_path(self, data_source: S3DataSource):
+        assert data_source.url("path1", "path2") == f"s3://{data_source.name}/path1/path2"

tests/pytest/pems_data/test_s3.py

Lines changed: 0 additions & 75 deletions
This file was deleted.
