Skip to content

Commit da183c0

Browse files
committed
upgrade to numpy 2.0 and remove imgaug
1 parent 834d570 commit da183c0

File tree

4 files changed

+267
-48
lines changed

4 files changed

+267
-48
lines changed

ppocr/data/imaug/iaa_augment.py

+87-43
Original file line numberDiff line numberDiff line change
@@ -19,37 +19,88 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121
from __future__ import unicode_literals
22+
import os
2223

2324
import numpy as np
24-
import imgaug
25-
import imgaug.augmenters as iaa
2625

26+
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
27+
import albumentations as A
2728

28-
class AugmenterBuilder(object):
29+
30+
class AugmenterBuilder:
2931
def __init__(self):
3032
pass
3133

32-
def build(self, args, root=True):
33-
if args is None or len(args) == 0:
34+
def build(self, args):
35+
if not args:
3436
return None
3537
elif isinstance(args, list):
36-
if root:
37-
sequence = [self.build(value, root=False) for value in args]
38-
return iaa.Sequential(sequence)
38+
# Recursively build transforms from the list
39+
transforms = [self.build(value) for value in args if self.build(value)]
40+
return A.Compose(transforms, keypoint_params=A.KeypointParams(format="xy"))
41+
elif isinstance(args, dict):
42+
# Get the transform type and its arguments
43+
transform_type = args.get("type")
44+
transform_args = args.get("args", {})
45+
# Map the transform type to the corresponding function
46+
transform_func = self._get_transform_function(
47+
transform_type, transform_args
48+
)
49+
if transform_func:
50+
return transform_func
3951
else:
40-
return getattr(iaa, args[0])(
41-
*[self.to_tuple_if_list(a) for a in args[1:]]
52+
raise NotImplementedError(
53+
f"Transform {transform_type} not implemented."
4254
)
43-
elif isinstance(args, dict):
44-
cls = getattr(iaa, args["type"])
45-
return cls(**{k: self.to_tuple_if_list(v) for k, v in args["args"].items()})
4655
else:
47-
raise RuntimeError("unknown augmenter arg: " + str(args))
56+
raise RuntimeError(f"Unknown augmenter arg: {args}")
57+
58+
def _get_transform_function(self, transform_type, transform_args):
59+
# Define mapping from transform types to functions
60+
transform_mapping = {
61+
"Fliplr": self._build_horizontal_flip,
62+
"Affine": self._build_affine,
63+
"Resize": self._build_resize,
64+
}
65+
func = transform_mapping.get(transform_type)
66+
if func:
67+
return func(transform_args)
68+
else:
69+
return None
4870

49-
def to_tuple_if_list(self, obj):
50-
if isinstance(obj, list):
51-
return tuple(obj)
52-
return obj
71+
def _build_horizontal_flip(self, transform_args):
72+
p = transform_args.get("p", 0.5)
73+
return A.HorizontalFlip(p=p)
74+
75+
def _build_affine(self, transform_args):
76+
rotate = transform_args.get("rotate")
77+
shear = transform_args.get("shear")
78+
translate_percent = transform_args.get("translate_percent")
79+
affine_args = {"fit_output": True}
80+
if rotate is not None:
81+
affine_args["rotate"] = (
82+
tuple(rotate) if isinstance(rotate, list) else rotate
83+
)
84+
if shear is not None:
85+
affine_args["shear"] = shear
86+
if translate_percent is not None:
87+
affine_args["translate_percent"] = translate_percent
88+
return A.Affine(**affine_args)
89+
90+
def _build_resize(self, transform_args):
91+
size = transform_args.get("size", [1.0, 1.0])
92+
if isinstance(size, list) and len(size) == 2:
93+
scale_factor = size[0]
94+
height = int(scale_factor * 100)
95+
width = int(scale_factor * 100)
96+
return A.Resize(height=height, width=width)
97+
elif isinstance(size, (int, float)):
98+
scale_factor = float(size)
99+
height = int(scale_factor * 100)
100+
width = int(scale_factor * 100)
101+
return A.Resize(height=height, width=width)
102+
else:
103+
raise ValueError("Invalid size parameter for Resize")
53104

54105

55106
class IaaAugment:
@@ -58,35 +109,28 @@ def __init__(self, augmenter_args=None, **kwargs):
58109
augmenter_args = [
59110
{"type": "Fliplr", "args": {"p": 0.5}},
60111
{"type": "Affine", "args": {"rotate": [-10, 10]}},
61-
{"type": "Resize", "args": {"size": [0.5, 3]}},
112+
{"type": "Resize", "args": {"size": [0.5, 3.0]}},
62113
]
63114
self.augmenter = AugmenterBuilder().build(augmenter_args)
64115

65116
def __call__(self, data):
66117
image = data["image"]
67-
shape = image.shape
68-
118+
polys = data["polys"]
119+
# Flatten polys to a list of keypoints
120+
keypoints = [tuple(point) for poly in polys for point in poly]
69121
if self.augmenter:
70-
aug = self.augmenter.to_deterministic()
71-
data["image"] = aug.augment_image(image)
72-
data = self.may_augment_annotation(aug, data, shape)
73-
return data
74-
75-
def may_augment_annotation(self, aug, data, shape):
76-
if aug is None:
77-
return data
78-
79-
line_polys = []
80-
for poly in data["polys"]:
81-
new_poly = self.may_augment_poly(aug, shape, poly)
82-
line_polys.append(new_poly)
83-
data["polys"] = np.array(line_polys)
122+
transformed = self.augmenter(image=image, keypoints=keypoints)
123+
data["image"] = transformed["image"]
124+
# Reconstruct polys from transformed keypoints
125+
transformed_keypoints = transformed["keypoints"]
126+
new_polys = []
127+
idx = 0
128+
for poly in polys:
129+
num_points = len(poly)
130+
new_poly = np.array(
131+
transformed_keypoints[idx : idx + num_points], dtype=np.float32
132+
)
133+
new_polys.append(new_poly)
134+
idx += num_points
135+
data["polys"] = new_polys
84136
return data
85-
86-
def may_augment_poly(self, aug, img_shape, poly):
87-
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
88-
keypoints = aug.augment_keypoints(
89-
[imgaug.KeypointsOnImage(keypoints, shape=img_shape)]
90-
)[0].keypoints
91-
poly = [(p.x, p.y) for p in keypoints]
92-
return poly

pyproject.toml

+5-3
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,10 @@ classifiers = [
4141
dependencies = [
4242
"shapely",
4343
"scikit-image",
44-
"imgaug",
4544
"pyclipper",
4645
"lmdb",
4746
"tqdm",
48-
"numpy<2.0",
47+
"numpy",
4948
"rapidfuzz",
5049
"opencv-python",
5150
"opencv-contrib-python",
@@ -56,7 +55,10 @@ dependencies = [
5655
"beautifulsoup4",
5756
"fonttools>=4.24.0",
5857
"fire>=0.3.0",
59-
"requests"
58+
"requests",
59+
"albumentations==1.4.10",
60+
# to be compatible with albumentations
61+
"albucore==0.0.13"
6062
]
6163

6264
[project.urls]

requirements.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
shapely
22
scikit-image
3-
imgaug
43
pyclipper
54
lmdb
65
tqdm
7-
numpy<2.0
6+
numpy
87
rapidfuzz
98
opencv-python
109
opencv-contrib-python

tests/test_iaa_augment.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import os
2+
import sys
3+
import pytest
4+
import numpy as np
5+
import random
6+
7+
current_dir = os.path.dirname(os.path.abspath(__file__))
8+
sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
9+
10+
from ppocr.data.imaug.iaa_augment import IaaAugment
11+
12+
# Set a fixed random seed for reproducibility
13+
np.random.seed(42)
14+
random.seed(42)
15+
16+
17+
# Fixtures for common test inputs
18+
@pytest.fixture
19+
def sample_image():
20+
# Create a dummy image of size 100x100 with 3 channels
21+
return np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
22+
23+
24+
@pytest.fixture
25+
def sample_polys():
26+
# Create some dummy polygons
27+
polys = [
28+
np.array([[10, 10], [20, 10], [20, 20], [10, 20]], dtype=np.float32),
29+
np.array([[30, 30], [40, 30], [40, 40], [30, 40]], dtype=np.float32),
30+
]
31+
return polys
32+
33+
34+
# Helper function to create data dictionary
35+
def create_data(sample_image, sample_polys):
36+
return {
37+
"image": sample_image.copy(),
38+
"polys": [poly.copy() for poly in sample_polys],
39+
}
40+
41+
42+
# Test default augmenter (with default augmenter_args)
43+
def test_iaa_augment_default(sample_image, sample_polys):
44+
data = create_data(sample_image, sample_polys)
45+
augmenter = IaaAugment()
46+
transformed_data = augmenter(data)
47+
48+
assert isinstance(
49+
transformed_data["image"], np.ndarray
50+
), "Image should be a numpy array"
51+
assert isinstance(transformed_data["polys"], list), "Polys should be a list"
52+
assert transformed_data["image"].ndim == 3, "Image should be 3-dimensional"
53+
54+
# Check that the polys have been transformed
55+
polys_changed = any(
56+
not np.allclose(orig_poly, trans_poly)
57+
for orig_poly, trans_poly in zip(sample_polys, transformed_data["polys"])
58+
)
59+
assert polys_changed, "Polygons should have been transformed"
60+
61+
62+
# Test augmenter with empty augmenter_args (no augmentation)
63+
def test_iaa_augment_none(sample_image, sample_polys):
64+
data = create_data(sample_image, sample_polys)
65+
augmenter = IaaAugment(augmenter_args=[])
66+
transformed_data = augmenter(data)
67+
68+
assert np.array_equal(
69+
data["image"], transformed_data["image"]
70+
), "Image should be unchanged"
71+
for orig_poly, transformed_poly in zip(data["polys"], transformed_data["polys"]):
72+
assert np.array_equal(
73+
orig_poly, transformed_poly
74+
), "Polygons should be unchanged"
75+
76+
77+
# Parameterize tests to cover multiple augmenter_args scenarios
78+
@pytest.mark.parametrize(
79+
"augmenter_args, expected_shape",
80+
[
81+
([], (100, 100, 3)),
82+
([{"type": "Resize", "args": {"size": [0.5, 0.5]}}], (50, 50, 3)),
83+
([{"type": "Resize", "args": {"size": [2.0, 2.0]}}], (200, 200, 3)),
84+
],
85+
)
86+
def test_iaa_augment_resize(sample_image, sample_polys, augmenter_args, expected_shape):
87+
data = create_data(sample_image, sample_polys)
88+
augmenter = IaaAugment(augmenter_args=augmenter_args)
89+
transformed_data = augmenter(data)
90+
91+
assert (
92+
transformed_data["image"].shape == expected_shape
93+
), f"Expected image shape {expected_shape}, got {transformed_data['image'].shape}"
94+
95+
96+
# Test with custom augmenter_args
97+
def test_iaa_augment_custom(sample_image, sample_polys):
98+
data = create_data(sample_image, sample_polys)
99+
augmenter_args = [
100+
{"type": "Affine", "args": {"rotate": [45, 45]}}, # Fixed rotation angle
101+
{"type": "Resize", "args": {"size": [0.5, 0.5]}},
102+
]
103+
augmenter = IaaAugment(augmenter_args=augmenter_args)
104+
transformed_data = augmenter(data)
105+
106+
expected_height = int(sample_image.shape[0] * 0.5)
107+
expected_width = int(sample_image.shape[1] * 0.5)
108+
109+
assert (
110+
transformed_data["image"].shape[0] == expected_height
111+
), "Image height should be scaled by 0.5"
112+
assert (
113+
transformed_data["image"].shape[1] == expected_width
114+
), "Image width should be scaled by 0.5"
115+
116+
# Check that the polys have been transformed
117+
polys_changed = any(
118+
not np.allclose(orig_poly, trans_poly)
119+
for orig_poly, trans_poly in zip(sample_polys, transformed_data["polys"])
120+
)
121+
assert polys_changed, "Polygons should have been transformed"
122+
123+
124+
# Test unknown transform type raises NotImplementedError
125+
def test_iaa_augment_unknown_transform():
126+
augmenter_args = [{"type": "UnknownTransform", "args": {}}]
127+
with pytest.raises(NotImplementedError):
128+
IaaAugment(augmenter_args=augmenter_args)
129+
130+
131+
# Test invalid size parameter raises ValueError
132+
def test_iaa_augment_invalid_resize_size():
133+
augmenter_args = [{"type": "Resize", "args": {"size": "invalid_size"}}]
134+
with pytest.raises(ValueError):
135+
IaaAugment(augmenter_args=augmenter_args)
136+
137+
138+
# Test that polys are transformed appropriately
139+
def test_iaa_augment_polys_transformation(sample_image, sample_polys):
140+
data = create_data(sample_image, sample_polys)
141+
augmenter_args = [
142+
{"type": "Affine", "args": {"rotate": [90, 90]}}, # Fixed rotation angle
143+
]
144+
augmenter = IaaAugment(augmenter_args=augmenter_args)
145+
transformed_data = augmenter(data)
146+
147+
# Check that the polygons have changed
148+
polys_changed = any(
149+
not np.allclose(orig_poly, trans_poly)
150+
for orig_poly, trans_poly in zip(sample_polys, transformed_data["polys"])
151+
)
152+
assert polys_changed, "Polygons should have been transformed"
153+
154+
155+
# Test with multiple transforms in augmenter_args
156+
def test_iaa_augment_multiple_transforms(sample_image, sample_polys):
157+
augmenter_args = [
158+
{"type": "Fliplr", "args": {"p": 1.0}}, # Always flip
159+
{"type": "Affine", "args": {"shear": 10}},
160+
]
161+
data = create_data(sample_image, sample_polys)
162+
augmenter = IaaAugment(augmenter_args=augmenter_args)
163+
transformed_data = augmenter(data)
164+
165+
# Check that the image has been transformed
166+
images_different = not np.array_equal(transformed_data["image"], sample_image)
167+
assert images_different, "Image should be transformed"
168+
169+
# Check that the polys have been transformed
170+
polys_changed = any(
171+
not np.allclose(orig_poly, trans_poly)
172+
for orig_poly, trans_poly in zip(sample_polys, transformed_data["polys"])
173+
)
174+
assert polys_changed, "Polygons should have been transformed"

0 commit comments

Comments
 (0)