Skip to content

Commit 4734335

Browse files
committed
feat(Database): implement commit, rollback
1 parent e932081 commit 4734335

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

src/json_as_db/Database.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import copy
23
import shortuuid
34
import aiofiles
45

@@ -36,8 +37,9 @@ def _override_only_unset(__dict: dict, __target: dict):
3637
new_target = dict()
3738
for field in unset_fields:
3839
new_target[field] = __target[field]
39-
__dict.update(new_target)
40-
return __dict
40+
new_dict = copy.deepcopy(__dict)
41+
new_dict.update(new_target)
42+
return new_dict
4143

4244

4345
class Database(dict):
@@ -51,6 +53,7 @@ class Database(dict):
5153
'created_at',
5254
'updated_at',
5355
]
56+
_memory = dict()
5457

5558
def __init__(self, *arg, **kwargs):
5659
self.__dict__ = dict(*arg, **kwargs)
@@ -63,6 +66,7 @@ def __init__(self, *arg, **kwargs):
6366
self.__records__: dict(),
6467
}
6568
self.__dict__ = _override_only_unset(self.__dict__, defaults)
69+
self.commit()
6670

6771
def __getitem__(self, key: str) -> Any:
6872
try:
@@ -75,6 +79,7 @@ def __setitem__(self, key, value) -> None:
7579

7680
def __delitem__(self, key) -> None:
7781
try:
82+
self._update_timestamp()
7883
return self.records.__delitem__(key)
7984
except KeyError:
8085
return None
@@ -109,12 +114,18 @@ def metadata(self) -> dict:
109114
meta[column] = self.__dict__.get(column)
110115
return meta
111116

117+
def _update_timestamp(self) -> None:
118+
self.__dict__.update({
119+
'updated_at': datetime.now().isoformat()
120+
})
121+
112122
def get(self, key: Union[str, List[str]], default=None) -> Union[Any, List[Any]]:
113123
_type, _keys = _from_maybe_list(key)
114124
values = [self.records.get(k, default) for k in _keys]
115125
return _return_maybe(_type, values)
116126

117127
def update(self, mapping: Union[dict, tuple] = (), **kwargs) -> None:
128+
self._update_timestamp()
118129
return self.records.update(mapping, **kwargs)
119130

120131
def modify(
@@ -132,6 +143,7 @@ def modify(
132143
target = dict()
133144
target[_id] = _value
134145
self.records.update(target)
146+
self._update_timestamp()
135147

136148
def add(self, item: Union[Any, List[Any]]) -> Union[str, List[str]]:
137149
_type, _items = _from_maybe_list(item)
@@ -142,18 +154,21 @@ def add(self, item: Union[Any, List[Any]]) -> Union[str, List[str]]:
142154
self.records[uid] = i
143155
ids.append(uid)
144156

157+
self._update_timestamp()
145158
return _return_maybe(_type, ids)
146159

147160
def remove(self, key: Union[str, List[str]]) -> Union[str, List[str]]:
148161
_type, _keys = _from_maybe_list(key)
149162
popped = [self.records.pop(key) for key in _keys]
163+
self._update_timestamp()
150164
return _return_maybe(_type, popped)
151165

152166
def all(self) -> List[Any]:
153167
return self.records.values()
154168

155169
def clear(self) -> None:
156170
self.records.clear()
171+
self._update_timestamp()
157172

158173
def find(self, func: Callable[..., bool]) -> List[str]:
159174
ids = []
@@ -187,10 +202,10 @@ def drop(self) -> int:
187202
return del_count
188203

189204
def commit(self) -> None:
190-
pass
205+
self._memory = copy.deepcopy(self.__dict__)
191206

192207
def rollback(self) -> None:
193-
pass
208+
self.__dict__ = copy.deepcopy(self._memory)
194209

195210
async def save(
196211
self,

tests/database/test_database.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,14 +214,24 @@ def test_db_drop(db: Database):
214214
assert dropped_count == 0
215215

216216

217-
def test_db_commit(db: Database):
218-
db.commit()
219-
pytest.skip()
217+
def test_db_commit_and_rollback(db: Database):
218+
prev = db.all()
219+
assert len(prev) == 2
220220

221+
db.commit()
222+
updated_old = db.metadata.get('updated_at')
223+
new_item = {'something': 'new'}
224+
new_id = db.add(new_item)
225+
updated = db.metadata.get('updated_at')
226+
assert db.count() == 3
227+
assert db.get(new_id) == new_item
228+
assert updated_old != updated
221229

222-
def test_db_rollback(db: Database):
223230
db.rollback()
224-
pytest.skip()
231+
updated_rollback = db.metadata.get('updated_at')
232+
assert db.count() == 2
233+
assert db.get(new_id) == None
234+
assert updated_old == updated_rollback
225235

226236

227237
@pytest.mark.asyncio

0 commit comments

Comments
 (0)