Skip to content

Commit fd153a5

Browse files
committed
Fix some edge cases and resolve a floating-point precision problem.
Add tests.
1 parent 2991877 commit fd153a5

File tree

2 files changed

+87
-22
lines changed

2 files changed

+87
-22
lines changed

sdks/python/apache_beam/transforms/periodicsequence.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from typing import Any
2222
from typing import Optional
2323
from typing import Sequence
24-
from typing import Union
2524

2625
import apache_beam as beam
2726
from apache_beam.io.restriction_trackers import OffsetRange
@@ -41,13 +40,21 @@ class ImpulseSeqGenRestrictionProvider(core.RestrictionProvider):
4140
def initial_restriction(self, element):
4241
start, end, interval = element
4342
if isinstance(start, Timestamp):
44-
start = start.micros / 1000000
43+
start_micros = start.micros
44+
else:
45+
start_micros = round(start * 1000000)
46+
4547
if isinstance(end, Timestamp):
46-
end = end.micros / 1000000
48+
end_micros = end.micros
49+
else:
50+
end_micros = round(end * 1000000)
51+
52+
interval_micros = round(interval * 1000000)
4753

48-
assert start <= end
54+
assert start_micros <= end_micros
4955
assert interval > 0
50-
total_outputs = math.ceil((end - start) / interval)
56+
delta_micros: int = end_micros - start_micros
57+
total_outputs = math.ceil(delta_micros / interval_micros)
5158
return OffsetRange(0, total_outputs)
5259

5360
def create_tracker(self, restriction):
@@ -232,19 +239,19 @@ def _validate_and_adjust_duration(self):
232239

233240
if isinstance(self.stop_ts, Timestamp):
234241
if self.stop_ts == MAX_TIMESTAMP:
235-
# adjust stop timestamp to match the data duration
236-
end = start + data_duration
237-
if self.interval > 1e-6:
238-
end += 1e-6
239-
self.stop_ts = Timestamp.of(end)
242+
# When the stop timestamp is unbounded (MAX_TIMESTAMP), set it to the
243+
# data's actual end time plus an extra fire interval, because the
244+
# impulse duration's upper bound is exclusive.
245+
end = start + data_duration + self.interval
246+
self.stop_ts = Timestamp(micros=end * 1000000)
240247
else:
241248
end = self.stop_ts.micros / 1000000
242249
else:
243250
end = self.stop_ts
244251

245252
# The total time for the impulse signal which occurs in [start, end).
246253
impulse_duration = end - start
247-
if data_duration + self.interval < impulse_duration:
254+
if round(data_duration + self.interval, 6) < round(impulse_duration, 6):
248255
# We don't have enough data for the impulse.
249256
# If we can fit at least one more data point in the impulse duration,
250257
# then we will be in the repeat mode.
@@ -264,8 +271,8 @@ def _validate_and_adjust_duration(self):
264271

265272
def __init__(
266273
self,
267-
start_timestamp: Union[Timestamp, float] = Timestamp.now(),
268-
stop_timestamp: Union[Timestamp, float] = MAX_TIMESTAMP,
274+
start_timestamp: Timestamp = Timestamp.now(),
275+
stop_timestamp: Timestamp = MAX_TIMESTAMP,
269276
fire_interval: float = 360.0,
270277
apply_windowing: bool = False,
271278
data: Optional[Sequence[Any]] = None):

sdks/python/apache_beam/transforms/periodicsequence_test.py

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@
1919

2020
# pytype: skip-file
2121

22+
import logging
2223
import inspect
24+
import random
2325
import time
2426
import unittest
2527

28+
from parameterized import parameterized
29+
2630
import apache_beam as beam
2731
from apache_beam.io.restriction_trackers import OffsetRange
2832
from apache_beam.testing.test_pipeline import TestPipeline
@@ -157,6 +161,53 @@ def test_processing_time(self):
157161
expected = [0, 2, 4]
158162
assert_that(ret, equal_to(expected, lambda x, y: abs(x - y) < threshold))
159163

164+
@parameterized.expand([0.5, 1, 2, 10])
165+
def test_stop_over_by_epsilon(self, interval):
166+
with TestPipeline() as p:
167+
ret = (
168+
p | PeriodicImpulse(
169+
start_timestamp=Timestamp(seconds=1),
170+
stop_timestamp=Timestamp(seconds=1, micros=1),
171+
data=[1, 2],
172+
fire_interval=interval)
173+
| beam.WindowInto(FixedWindows(interval))
174+
| beam.WithKeys(0)
175+
| beam.GroupByKey())
176+
expected = [
177+
(0, [1]),
178+
]
179+
assert_that(ret, equal_to(expected))
180+
181+
@parameterized.expand([1, 2])
182+
def test_stop_over_by_interval(self, interval):
183+
with TestPipeline() as p:
184+
ret = (
185+
p | PeriodicImpulse(
186+
start_timestamp=Timestamp(seconds=1),
187+
stop_timestamp=Timestamp(seconds=1 + interval),
188+
data=[1, 2],
189+
fire_interval=interval)
190+
| beam.WindowInto(FixedWindows(interval))
191+
| beam.WithKeys(0)
192+
| beam.GroupByKey())
193+
expected = [(0, [1])]
194+
assert_that(ret, equal_to(expected))
195+
196+
@parameterized.expand([1, 2])
197+
def test_stop_over_by_interval_and_epsilon(self, interval):
198+
with TestPipeline() as p:
199+
ret = (
200+
p | PeriodicImpulse(
201+
start_timestamp=Timestamp(seconds=1),
202+
stop_timestamp=Timestamp(seconds=1 + interval, micros=1),
203+
data=[1, 2],
204+
fire_interval=interval)
205+
| beam.WindowInto(FixedWindows(interval))
206+
| beam.WithKeys(0)
207+
| beam.GroupByKey())
208+
expected = [(0, [1]), (0, [2])]
209+
assert_that(ret, equal_to(expected))
210+
160211
def test_interval(self):
161212
with TestPipeline() as p:
162213
ret = (
@@ -208,15 +259,22 @@ def test_not_enough_timestamped_value(self):
208259
data=data,
209260
fire_interval=0.5))
210261

211-
def test_small_interval(self):
212-
data = [(Timestamp(1), 1), (Timestamp(2), 2), (Timestamp(3), 3),
213-
(Timestamp(6), 6), (Timestamp(4), 4), (Timestamp(5), 5),
214-
(Timestamp(7), 7), (Timestamp(8), 8), (Timestamp(9), 9),
215-
(Timestamp(10), 10)]
216-
expected = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
217-
with TestPipeline() as p:
218-
ret = (p | PeriodicImpulse(data=data, fire_interval=0.0001))
219-
assert_that(ret, equal_to(expected))
262+
def test_fuzzy_interval(self):
263+
seed = int(time.time() * 1000)
264+
times = 30
265+
logging.warning("random seed=%d", seed)
266+
random.seed(seed)
267+
for _ in range(times):
268+
n = int(random.randint(1, 100))
269+
data = list(range(n))
270+
m = random.randint(1, 1000)
271+
interval = m / 1e6
272+
now = Timestamp.now()
273+
with TestPipeline() as p:
274+
ret = (
275+
p | PeriodicImpulse(
276+
start_timestamp=now, data=data, fire_interval=interval))
277+
assert_that(ret, equal_to(data))
220278

221279

222280
if __name__ == '__main__':

0 commit comments

Comments
 (0)