Skip to content

Commit dcb37ab

Browse files
Merge pull request #135 from fabiansinz/master
Extended order_by and bugfixes
2 parents 90291ef + dad10f8 commit dcb37ab

File tree

6 files changed

+207
-75
lines changed

6 files changed

+207
-75
lines changed

datajoint/fetch.py

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from collections import OrderedDict
2+
from functools import wraps
3+
import itertools
4+
import re
25
from .blob import unpack
36
import numpy as np
47
from datajoint import DataJointError
58
from . import key as PRIMARY_KEY
9+
from collections import abc
10+
611

712
def prepare_attributes(relation, item):
813
if isinstance(item, str) or item is PRIMARY_KEY:
@@ -20,46 +25,57 @@ def prepare_attributes(relation, item):
2025
raise DataJointError("Index must be a slice, a tuple, a list, a string.")
2126
return item, attributes
2227

23-
class FetchQuery:
28+
def copy_first(f):
29+
@wraps(f)
30+
def ret(*args, **kwargs):
31+
args = list(args)
32+
args[0] = args[0].__class__(args[0]) # call copy constructor
33+
return f(*args, **kwargs)
2434

25-
def __init__(self, relation):
26-
"""
35+
return ret
2736

28-
"""
29-
self.behavior = dict(
30-
offset=0, limit=None, order_by=None, descending=False, as_dict=False, map=None
31-
)
32-
self._relation = relation
37+
class Fetch:
38+
def __init__(self, relation):
39+
if isinstance(relation, Fetch): # copy constructor
40+
self.behavior = dict(relation.behavior)
41+
self._relation = relation._relation
42+
else:
43+
self.behavior = dict(
44+
offset=0, limit=None, order_by=None, as_dict=False
45+
)
46+
self._relation = relation
3347

3448

49+
@copy_first
3550
def from_to(self, fro, to):
3651
self.behavior['offset'] = fro
3752
self.behavior['limit'] = to - fro
3853
return self
3954

40-
def order_by(self, order_by):
41-
self.behavior['order_by'] = order_by
55+
@copy_first
56+
def order_by(self, *args):
57+
if len(args) > 0:
58+
self.behavior['order_by'] = self.behavior['order_by'] if self.behavior['order_by'] is not None else []
59+
namepat = re.compile(r"\s*(?P<name>\w+).*")
60+
for a in args: # remove duplicates
61+
name = namepat.match(a).group('name')
62+
pat = re.compile(r"%s(\s*$|\s+(\S*\s*)*$)" % (name,))
63+
self.behavior['order_by'] = [e for e in self.behavior['order_by'] if not pat.match(e)]
64+
self.behavior['order_by'].extend(args)
4265
return self
4366

67+
@copy_first
4468
def as_dict(self):
4569
self.behavior['as_dict'] = True
46-
47-
def ascending(self):
48-
self.behavior['descending'] = False
4970
return self
5071

51-
def descending(self):
52-
self.behavior['descending'] = True
53-
return self
54-
55-
def apply(self, f):
56-
self.behavior['map'] = f
57-
return self
5872

59-
def limit_by(self, limit):
73+
@copy_first
74+
def limit_to(self, limit):
6075
self.behavior['limit'] = limit
6176
return self
6277

78+
@copy_first
6379
def set_behavior(self, **kwargs):
6480
self.behavior.update(kwargs)
6581
return self
@@ -78,9 +94,7 @@ def __call__(self, **kwargs):
7894
"""
7995
behavior = dict(self.behavior, **kwargs)
8096

81-
cur = self._relation.cursor(offset=behavior['offset'], limit=behavior['limit'],
82-
order_by=behavior['order_by'], descending=behavior['descending'],
83-
as_dict=behavior['as_dict'])
97+
cur = self._relation.cursor(**behavior)
8498

8599
heading = self._relation.heading
86100
if behavior['as_dict']:
@@ -92,22 +106,15 @@ def __call__(self, **kwargs):
92106
for blob_name in heading.blobs:
93107
ret[blob_name] = list(map(unpack, ret[blob_name]))
94108

95-
if behavior['map'] is not None:
96-
f = behavior['map']
97-
for i in range(len(ret)):
98-
ret[i] = f(ret[i])
99-
100109
return ret
101110

102111
def __iter__(self):
103112
"""
104113
Iterator that returns the contents of the database.
105114
"""
106-
behavior = self.behavior
115+
behavior = dict(self.behavior)
107116

108-
cur = self._relation.cursor(offset=behavior['offset'], limit=behavior['limit'],
109-
order_by=behavior['order_by'], descending=behavior['descending'],
110-
as_dict=behavior['as_dict'])
117+
cur = self._relation.cursor(**behavior)
111118

112119
heading = self._relation.heading
113120
do_unpack = tuple(h in heading.blobs for h in heading.names)
@@ -126,10 +133,10 @@ def keys(self, **kwargs):
126133
"""
127134
Iterator that returns primary keys.
128135
"""
136+
b = dict(self.behavior, **kwargs)
129137
if 'as_dict' not in kwargs:
130-
kwargs['as_dict'] = True
131-
yield from self._relation.project().fetch.set_behavior(**kwargs)
132-
138+
b['as_dict'] = True
139+
yield from self._relation.project().fetch.set_behavior(**b)
133140

134141
def __getitem__(self, item):
135142
"""
@@ -146,7 +153,7 @@ def __getitem__(self, item):
146153
single_output = isinstance(item, str) or item is PRIMARY_KEY or isinstance(item, int)
147154
item, attributes = prepare_attributes(self._relation, item)
148155

149-
result = self._relation.project(*attributes).fetch()
156+
result = self._relation.project(*attributes).fetch(**self.behavior)
150157
return_values = [
151158
np.ndarray(result.shape,
152159
np.dtype({name: result.dtype.fields[name] for name in self._relation.primary_key}),
@@ -158,8 +165,7 @@ def __getitem__(self, item):
158165
return return_values[0] if single_output else return_values
159166

160167

161-
class Fetch1Query:
162-
168+
class Fetch1:
163169
def __init__(self, relation):
164170
self._relation = relation
165171

@@ -202,4 +208,4 @@ def __getitem__(self, item):
202208
else result[attribute][0]
203209
for attribute in item
204210
)
205-
return return_values[0] if single_output else return_values
211+
return return_values[0] if single_output else return_values

datajoint/relational_operand.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from . import DataJointError
1111
import logging
1212

13-
from .fetch import FetchQuery, Fetch1Query
13+
from .fetch import Fetch, Fetch1
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -171,7 +171,7 @@ def __call__(self, *args, **kwargs):
171171
"""
172172
return self.fetch(*args, **kwargs)
173173

174-
def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict=False):
174+
def cursor(self, offset=0, limit=None, order_by=None, as_dict=False):
175175
"""
176176
Return query cursor.
177177
See Relation.fetch() for input description.
@@ -182,8 +182,7 @@ def cursor(self, offset=0, limit=None, order_by=None, descending=False, as_dict=
182182
sql = self.make_select()
183183
if order_by is not None:
184184
sql += ' ORDER BY ' + ', '.join(order_by)
185-
if descending:
186-
sql += ' DESC'
185+
187186
if limit is not None:
188187
sql += ' LIMIT %d' % limit
189188
if offset:
@@ -206,12 +205,13 @@ def __repr__(self):
206205
repr_string += ' (%d tuples)\n' % len(self)
207206
return repr_string
208207

208+
@property
209209
def fetch1(self):
210-
return Fetch1Query(self)
210+
return Fetch1(self)
211211

212212
@property
213213
def fetch(self):
214-
return FetchQuery(self)
214+
return Fetch(self)
215215

216216
@property
217217
def where_clause(self):
@@ -253,8 +253,6 @@ def make_condition(arg):
253253
return ' WHERE ' + ' AND '.join(condition_string)
254254

255255

256-
257-
258256
class Not:
259257
"""
260258
inverse restriction
@@ -319,9 +317,9 @@ def __init__(self, arg, group=None, *attributes, **renamed_attributes):
319317
self._arg = Subquery(arg)
320318
else:
321319
self._group = None
322-
if arg.heading.computed or\
320+
if arg.heading.computed or \
323321
(isinstance(arg.restrictions, RelationalOperand) and \
324-
all(attr in self._attributes for attr in arg.restrictions.heading.names)) :
322+
all(attr in self._attributes for attr in arg.restrictions.heading.names)):
325323
# can simplify the expression because all restrictions attrs are projected out anyway!
326324
self._arg = arg
327325
self._restrictions = self._arg.restrictions

doc/source/_static/.dummy

Whitespace-only changes.

tests/schema.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,26 @@ class Subject(dj.Manual):
3939
def prepare(self):
4040
self.insert(self.contents, ignore_errors=True)
4141

42+
@schema
43+
class Language(dj.Lookup):
44+
45+
definition = """
46+
# languages spoken by some of the developers
47+
48+
entry_id : int
49+
---
50+
name : varchar(40) # name of the developer
51+
language : varchar(40) # language
52+
"""
53+
54+
contents = [
55+
(0, 'Fabian', 'English'),
56+
(1, 'Edgar', 'English'),
57+
(2, 'Dimitri', 'English'),
58+
(3, 'Dimitri', 'Ukrainian'),
59+
(4, 'Fabian', 'German'),
60+
(5, 'Edgar', 'Japanese'),
61+
]
4262

4363
@schema
4464
class Experiment(dj.Imported):

tests/test_fetch.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from operator import itemgetter, attrgetter
2+
import itertools
3+
from nose.tools import assert_true
4+
from numpy.testing import assert_array_equal, assert_equal
5+
import numpy as np
6+
7+
from . import schema
8+
import datajoint as dj
9+
10+
11+
class TestFetch:
12+
def __init__(self):
13+
self.subject = schema.Subject()
14+
self.lang = schema.Language()
15+
16+
def test_getitem(self):
17+
"""Testing Fetch.__getitem__"""
18+
19+
np.testing.assert_array_equal(sorted(self.subject.project().fetch(), key=itemgetter(0)),
20+
sorted(self.subject.fetch[dj.key], key=itemgetter(0)),
21+
'Primary key is not returned correctly')
22+
23+
tmp = self.subject.fetch(order_by=['subject_id'])
24+
25+
for column, field in zip(self.subject.fetch[:], [e[0] for e in tmp.dtype.descr]):
26+
np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly')
27+
28+
subject_notes, key, real_id = self.subject.fetch['subject_notes', dj.key, 'real_id']
29+
#
30+
np.testing.assert_array_equal(sorted(subject_notes), sorted(tmp['subject_notes']))
31+
np.testing.assert_array_equal(sorted(real_id), sorted(tmp['real_id']))
32+
np.testing.assert_array_equal(sorted(key, key=itemgetter(0)),
33+
sorted(self.subject.project().fetch(), key=itemgetter(0)))
34+
35+
for column, field in zip(self.subject.fetch['subject_id'::2], [e[0] for e in tmp.dtype.descr][::2]):
36+
np.testing.assert_array_equal(sorted(tmp[field]), sorted(column), 'slice : does not work correctly')
37+
38+
def test_order_by(self):
39+
"""Tests order_by sorting order"""
40+
langs = schema.Language.contents
41+
42+
for ord_name, ord_lang in itertools.product(*2 * [['ASC', 'DESC']]):
43+
cur = self.lang.fetch.order_by('name ' + ord_name, 'language ' + ord_lang)()
44+
langs.sort(key=itemgetter(2), reverse=ord_lang == 'DESC')
45+
langs.sort(key=itemgetter(1), reverse=ord_name == 'DESC')
46+
for c, l in zip(cur, langs):
47+
assert_true(np.all(cc == ll for cc, ll in zip(c, l)), 'Sorting order is different')
48+
49+
def test_order_by_default(self):
50+
"""Tests order_by sorting order with defaults"""
51+
langs = schema.Language.contents
52+
53+
cur = self.lang.fetch.order_by('language', 'name DESC')()
54+
langs.sort(key=itemgetter(1), reverse=True)
55+
langs.sort(key=itemgetter(2), reverse=False)
56+
57+
for c, l in zip(cur, langs):
58+
assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different')
59+
60+
def test_order_by_direct(self):
61+
"""Tests order_by sorting order passing it to __call__"""
62+
langs = schema.Language.contents
63+
64+
cur = self.lang.fetch(order_by=['language', 'name DESC'])
65+
langs.sort(key=itemgetter(1), reverse=True)
66+
langs.sort(key=itemgetter(2), reverse=False)
67+
for c, l in zip(cur, langs):
68+
assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different')
69+
70+
def test_limit_to(self):
71+
"""Test the limit_to function """
72+
langs = schema.Language.contents
73+
74+
cur = self.lang.fetch.limit_to(4)(order_by=['language', 'name DESC'])
75+
langs.sort(key=itemgetter(1), reverse=True)
76+
langs.sort(key=itemgetter(2), reverse=False)
77+
assert_equal(len(cur), 4, 'Length is not correct')
78+
for c, l in list(zip(cur, langs))[:4]:
79+
assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different')
80+
81+
def test_from_to(self):
82+
"""Test the from_to function """
83+
langs = schema.Language.contents
84+
85+
cur = self.lang.fetch.from_to(2, 6)(order_by=['language', 'name DESC'])
86+
langs.sort(key=itemgetter(1), reverse=True)
87+
langs.sort(key=itemgetter(2), reverse=False)
88+
assert_equal(len(cur), 4, 'Length is not correct')
89+
for c, l in list(zip(cur, langs[2:6])):
90+
assert_true(np.all([cc == ll for cc, ll in zip(c, l)]), 'Sorting order is different')
91+
92+
def test_iter(self):
93+
"""Test iterator"""
94+
langs = schema.Language.contents
95+
96+
cur = self.lang.fetch.order_by('language', 'name DESC')
97+
langs.sort(key=itemgetter(1), reverse=True)
98+
langs.sort(key=itemgetter(2), reverse=False)
99+
for (_, name, lang), (_, tname, tlang) in list(zip(cur, langs)):
100+
assert_true(name == tname and lang == tlang, 'Values are not the same')
101+
102+
def test_keys(self):
103+
"""test key iterator"""
104+
langs = schema.Language.contents
105+
langs.sort(key=itemgetter(1), reverse=True)
106+
langs.sort(key=itemgetter(2), reverse=False)
107+
108+
cur = self.lang.fetch.order_by('language', 'name DESC')['entry_id']
109+
cur2 = [e['entry_id'] for e in self.lang.fetch.order_by('language', 'name DESC').keys()]
110+
111+
keys, _, _ = list(zip(*langs))
112+
for k, c, c2 in zip(keys, cur, cur2):
113+
assert_true(k == c == c2, 'Values are not the same')
114+
115+
def test_fetch1(self):
116+
key = {'entry_id': 0}
117+
true = schema.Language.contents[0]
118+
119+
dat = (self.lang & key).fetch1()
120+
for k, (ke, c) in zip(true, dat.items()):
121+
assert_true(k == c == (self.lang & key).fetch1[ke], 'Values are not the same')
122+
123+
def test_copy(self):
124+
"""Test whether modifications copy the object"""
125+
f = self.lang.fetch
126+
f2 = f.order_by('name')
127+
assert_true(f.behavior['order_by'] is None and len(f2.behavior['order_by']) == 1, 'Object was not copied')
128+
129+
def test_overwrite(self):
130+
"""Test whether order_by overwrites duplicates"""
131+
f = self.lang.fetch.order_by('name DeSc ')
132+
f2 = f.order_by('name')
133+
assert_true(f2.behavior['order_by'] == ['name'], 'order_by attribute was not overwritten')

0 commit comments

Comments
 (0)