Skip to content

Commit cac6de9

Browse files
filipecosta90gkorland
authored andcommitted
[fix] fixes APPLY / SORTBY / GROUPBY / REDUCE order on FT.AGGREGATE s… (#41)
* [fix] fixes APPLY / SORTBY / GROUPBY / REDUCE order on FT.AGGREGATE support. * [add] setted sorby max default to 0. [add] set apply alias as not mandatory * [add] added supported for filter expressions on aggregations. include more examples on test_builder.py * [fix] corrected FT.AGGREGATE filter expressions to relate to the current state of the pipeline
1 parent 5c9456b commit cac6de9

File tree

3 files changed

+120
-38
lines changed

3 files changed

+120
-38
lines changed

.circleci/config.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,13 @@ jobs:
7777
name: run tests
7878
command: |
7979
. venv/bin/activate
80-
REDIS_PORT=6379 python test/test.py
80+
REDIS_PORT=6379 python test/test.py
81+
82+
- run:
83+
name: run query builder tests
84+
command: |
85+
. venv/bin/activate
86+
python test/test.py
8187
8288
# no need for store_artifacts on nightly builds
8389

redisearch/aggregation.py

Lines changed: 76 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,58 @@ def __init__(self, fields, reducers):
9999
self.limit = Limit()
100100

101101
def build_args(self):
102-
ret = [str(len(self.fields))]
102+
ret = ['GROUPBY', str(len(self.fields))]
103103
ret.extend(self.fields)
104104
for reducer in self.reducers:
105105
ret += ['REDUCE', reducer.NAME, str(len(reducer.args))]
106106
ret.extend(reducer.args)
107-
if reducer._alias:
107+
if reducer._alias is not None:
108108
ret += ['AS', reducer._alias]
109109
return ret
110110

111+
class Projection(object):
112+
"""
113+
This object automatically created in the `AggregateRequest.apply()`
114+
"""
115+
116+
def __init__(self, projector, alias=None ):
117+
118+
self.alias = alias
119+
self.projector = projector
120+
121+
def build_args(self):
122+
ret = ['APPLY', self.projector]
123+
if self.alias is not None:
124+
ret += ['AS', self.alias]
125+
126+
return ret
127+
128+
class SortBy(object):
129+
"""
130+
This object automatically created in the `AggregateRequest.sort_by()`
131+
"""
132+
133+
def __init__(self, fields, max=0):
134+
self.fields = fields
135+
self.max = max
136+
137+
138+
139+
def build_args(self):
140+
fields_args = []
141+
for f in self.fields:
142+
if isinstance(f, SortDirection):
143+
fields_args += [f.field, f.DIRSTRING]
144+
else:
145+
fields_args += [f]
146+
147+
ret = ['SORTBY', str(len(fields_args))]
148+
ret.extend(fields_args)
149+
if self.max > 0:
150+
ret += ['MAX', str(self.max)]
151+
152+
return ret
153+
111154

112155
class AggregateRequest(object):
113156
"""
@@ -127,11 +170,9 @@ def __init__(self, query='*'):
127170
return the object itself, making them useful for chaining.
128171
"""
129172
self._query = query
130-
self._groups = []
131-
self._projections = []
173+
self._aggregateplan = []
132174
self._loadfields = []
133175
self._limit = Limit()
134-
self._sortby = []
135176
self._max = 0
136177
self._with_schema = False
137178
self._verbatim = False
@@ -162,7 +203,7 @@ def group_by(self, fields, *reducers):
162203
`aggregation` module.
163204
"""
164205
group = Group(fields, reducers)
165-
self._groups.append(group)
206+
self._aggregateplan.extend(group.build_args())
166207

167208
return self
168209

@@ -177,7 +218,8 @@ def apply(self, **kwexpr):
177218
expression itself, for example `apply(square_root="sqrt(@foo)")`
178219
"""
179220
for alias, expr in kwexpr.items():
180-
self._projections.append([alias, expr])
221+
projection = Projection(expr, alias )
222+
self._aggregateplan.extend(projection.build_args())
181223

182224
return self
183225

@@ -224,10 +266,7 @@ def limit(self, offset, num):
224266
225267
"""
226268
limit = Limit(offset, num)
227-
if self._groups:
228-
self._groups[-1].limit = limit
229-
else:
230-
self._limit = limit
269+
self._limit = limit
231270
return self
232271

233272
def sort_by(self, *fields, **kwargs):
@@ -258,16 +297,34 @@ def sort_by(self, *fields, **kwargs):
258297
.sort_by(Desc('@paid'), max=10)
259298
```
260299
"""
261-
self._max = kwargs.get('max', 0)
262300
if isinstance(fields, (string_types, SortDirection)):
263301
fields = [fields]
264-
for f in fields:
265-
if isinstance(f, SortDirection):
266-
self._sortby += [f.field, f.DIRSTRING]
267-
else:
268-
self._sortby.append(f)
302+
303+
max = kwargs.get('max', 0)
304+
sortby = SortBy(fields, max)
305+
306+
self._aggregateplan.extend(sortby.build_args())
307+
return self
308+
309+
def filter(self, expressions):
310+
"""
311+
Specify filter for post-query results using predicates relating to values in the result set.
312+
313+
### Parameters
314+
315+
- **fields**: Fields to group by. This can either be a single string,
316+
or a list of strings.
317+
"""
318+
if isinstance(expressions, (string_types)):
319+
expressions = [expressions]
320+
321+
for expression in expressions:
322+
self._aggregateplan.extend(['FILTER', expression])
323+
269324
return self
270325

326+
327+
271328
def with_schema(self):
272329
"""
273330
If set, the `schema` property will contain a list of `[field, type]`
@@ -312,18 +369,8 @@ def build_args(self):
312369
ret.append('LOAD')
313370
ret.append(str(len(self._loadfields)))
314371
ret.extend(self._loadfields)
315-
for group in self._groups:
316-
ret += ['GROUPBY'] + group.build_args() + group.limit.build_args()
317-
for alias, projector in self._projections:
318-
ret += ['APPLY', projector]
319-
if alias:
320-
ret += ['AS', alias]
321-
322-
if self._sortby:
323-
ret += ['SORTBY', str(len(self._sortby))]
324-
ret += self._sortby
325-
if self._max:
326-
ret += ['MAX', str(self._max)]
372+
373+
ret.extend(self._aggregateplan)
327374

328375
ret += self._limit.build_args()
329376

test/test_builder.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from unittest import TestCase
1+
import unittest
22
import redisearch.aggregation as a
33
import redisearch.querystring as q
44
import redisearch.reducers as r
55

6-
class QueryBuilderTest(TestCase):
6+
class QueryBuilderTest(unittest.TestCase):
77
def testBetween(self):
88
b = q.between(1, 10)
99
self.assertEqual('[1 10]', str(b))
@@ -42,16 +42,16 @@ def testGroup(self):
4242
# Single field, single reducer
4343
g = a.Group('foo', r.count())
4444
ret = g.build_args()
45-
self.assertEqual(['1', 'foo', 'REDUCE', 'COUNT', '0'], ret)
45+
self.assertEqual(['GROUPBY', '1', 'foo', 'REDUCE', 'COUNT', '0'], ret)
4646

4747
# Multiple fields, single reducer
4848
g = a.Group(['foo', 'bar'], r.count())
49-
self.assertEqual(['2', 'foo', 'bar', 'REDUCE', 'COUNT', '0'],
49+
self.assertEqual(['GROUPBY', '2', 'foo', 'bar', 'REDUCE', 'COUNT', '0'],
5050
g.build_args())
5151

5252
# Multiple fields, multiple reducers
5353
g = a.Group(['foo', 'bar'], [r.count(), r.count_distinct('@fld1')])
54-
self.assertEqual(['2', 'foo', 'bar', 'REDUCE', 'COUNT', '0', 'REDUCE', 'COUNT_DISTINCT', '1', '@fld1'],
54+
self.assertEqual(['GROUPBY', '2', 'foo', 'bar', 'REDUCE', 'COUNT', '0', 'REDUCE', 'COUNT_DISTINCT', '1', '@fld1'],
5555
g.build_args())
5656

5757
def testAggRequest(self):
@@ -62,13 +62,38 @@ def testAggRequest(self):
6262
req = a.AggregateRequest().group_by('@foo', r.count())
6363
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'], req.build_args())
6464

65+
# Test with group_by and alias on reducer
66+
req = a.AggregateRequest().group_by('@foo', r.count().alias('foo_count'))
67+
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'AS', 'foo_count'], req.build_args())
68+
6569
# Test with limit
66-
req = a.AggregateRequest().\
67-
group_by('@foo', r.count()).\
70+
req = a.AggregateRequest(). \
71+
group_by('@foo', r.count()). \
6872
sort_by('@foo')
6973
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'SORTBY', '1',
7074
'@foo'], req.build_args())
7175

76+
# Test with apply
77+
req = a.AggregateRequest(). \
78+
apply(foo="@bar / 2"). \
79+
group_by('@foo', r.count())
80+
81+
self.assertEqual(['*', 'APPLY', '@bar / 2', 'AS', 'foo', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'],
82+
req.build_args())
83+
84+
# Test with filter
85+
req = a.AggregateRequest().group_by('@foo', r.count()).filter( "@foo=='bar'")
86+
self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'FILTER', "@foo=='bar'" ], req.build_args())
87+
88+
# Test with filter on different state of the pipeline
89+
req = a.AggregateRequest().filter("@foo=='bar'").group_by('@foo', r.count())
90+
self.assertEqual(['*', 'FILTER', "@foo=='bar'", 'GROUPBY', '1', '@foo','REDUCE', 'COUNT', '0' ], req.build_args())
91+
92+
# Test with filter on different state of the pipeline
93+
req = a.AggregateRequest().filter(["@foo=='bar'","@foo2=='bar2'"]).group_by('@foo', r.count())
94+
self.assertEqual(['*', 'FILTER', "@foo=='bar'", 'FILTER', "@foo2=='bar2'", 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'],
95+
req.build_args())
96+
7297
# Test with sort_by
7398
req = a.AggregateRequest().group_by('@foo', r.count()).sort_by('@date')
7499
# print req.build_args()
@@ -105,4 +130,8 @@ def test_reducers(self):
105130
self.assertEqual(('f1', 'BY', 'f2', 'ASC'), r.first_value('f1', a.Asc('f2')).args)
106131
self.assertEqual(('f1', 'BY', 'f1', 'ASC'), r.first_value('f1', a.Asc).args)
107132

108-
self.assertEqual(('f1', '50'), r.random_sample('f1', 50).args)
133+
self.assertEqual(('f1', '50'), r.random_sample('f1', 50).args)
134+
135+
if __name__ == '__main__':
136+
137+
unittest.main()

0 commit comments

Comments
 (0)