|
| 1 | +import re |
| 2 | + |
| 3 | +import pytest |
| 4 | + |
| 5 | +from pems_data.s3 import S3Bucket |
| 6 | + |
| 7 | + |
class TestS3Bucket:
    """Unit tests for S3Bucket: prefix listing, parquet reads, and URL building.

    All boto3 and pandas calls are patched out, so no test performs real I/O.
    """

    @pytest.fixture
    def bucket(self) -> S3Bucket:
        """A bucket instance constructed with the default name."""
        return S3Bucket()

    @pytest.fixture(autouse=True)
    def mock_s3(self, mocker):
        """Patch boto3.client and stub list_objects with deliberately unsorted keys."""
        keys = ["path1/file2.json", "path2/file1.json", "path1/file1.json"]
        client = mocker.patch("boto3.client").return_value
        client.list_objects.return_value = {"Contents": [{"Key": key} for key in keys]}
        return client

    @pytest.fixture(autouse=True)
    def mock_read_parquet(self, mocker):
        """Patch pandas.read_parquet so parquet reads never touch storage."""
        return mocker.patch("pandas.read_parquet")

    def test_name_custom(self):
        custom = S3Bucket("name")
        assert custom.name == "name"

    def test_name_default(self):
        default = S3Bucket()
        assert default.name == S3Bucket.prod_marts

    def test_get_prefixes__default(self, bucket: S3Bucket, mock_s3):
        result = bucket.get_prefixes()

        # One listing call against the bucket with no prefix filter.
        mock_s3.list_objects.assert_called_once_with(Bucket=bucket.name, Prefix="")
        # Keys come back sorted even though the stub returns them out of order.
        assert result == ["path1/file1.json", "path1/file2.json", "path2/file1.json"]

    def test_get_prefixes__filter_pattern(self, bucket: S3Bucket):
        pattern = re.compile("path1/.+")

        result = bucket.get_prefixes(pattern)

        # Only keys matching the pattern survive, still sorted.
        assert result == ["path1/file1.json", "path1/file2.json"]

    def test_get_prefixes__initial_prefix(self, bucket: S3Bucket, mock_s3):
        bucket.get_prefixes(initial_prefix="prefix")

        # The initial prefix is forwarded verbatim to the S3 listing call.
        mock_s3.list_objects.assert_called_once_with(Bucket=bucket.name, Prefix="prefix")

    def test_get_prefixes__match_func(self, bucket: S3Bucket):
        pattern = re.compile("path1/(.+)")

        result = bucket.get_prefixes(pattern, match_func=lambda match: match.group(1))

        # match_func transforms each regex match into the returned value.
        assert result == ["file1.json", "file2.json"]

    def test_read_parquet(self, bucket: S3Bucket, mock_read_parquet):
        mock_read_parquet.return_value = "data"
        expected_path = bucket.url("path")
        columns = ["col1", "col2", "col3"]
        filters = [("col1", "=", "val1")]

        result = bucket.read_parquet("path", columns=columns, filters=filters, extra1="extra1", extra2="extra2")

        # The pandas result is returned unchanged and all kwargs pass straight through.
        assert result == "data"
        mock_read_parquet.assert_called_once_with(
            expected_path, columns=columns, filters=filters, extra1="extra1", extra2="extra2"
        )

    def test_url__no_path(self, bucket: S3Bucket):
        assert bucket.url() == f"s3://{bucket.name}"

    def test_url__with_path(self, bucket: S3Bucket):
        assert bucket.url("path1", "path2") == f"s3://{bucket.name}/path1/path2"