3
3
import asyncio
4
4
from base64 import b64decode , b64encode
5
5
from collections import defaultdict
6
- from collections .abc import Awaitable
6
+ from collections .abc import AsyncGenerator , Awaitable
7
+ from contextlib import asynccontextmanager
7
8
from copy import deepcopy
8
9
from datetime import timedelta
9
10
from functools import cached_property
12
13
import logging
13
14
from pathlib import Path
14
15
import tarfile
16
+ from tarfile import TarFile
15
17
from tempfile import TemporaryDirectory
16
18
import time
17
19
from typing import Any , Self
56
58
from ..utils import remove_folder
57
59
from ..utils .dt import parse_datetime , utcnow
58
60
from ..utils .json import json_bytes
61
+ from ..utils .sentinel import DEFAULT
59
62
from .const import BUF_SIZE , LOCATION_CLOUD_BACKUP , BackupType
60
63
from .utils import key_to_iv , password_to_key
61
64
from .validate import SCHEMA_BACKUP
@@ -86,7 +89,6 @@ def __init__(
86
89
self ._data : dict [str , Any ] = data or {ATTR_SLUG : slug }
87
90
self ._tmp = None
88
91
self ._outer_secure_tarfile : SecureTarFile | None = None
89
- self ._outer_secure_tarfile_tarfile : tarfile .TarFile | None = None
90
92
self ._key : bytes | None = None
91
93
self ._aes : Cipher | None = None
92
94
self ._locations : dict [str | None , Path ] = {location : tar_file }
@@ -375,59 +377,68 @@ def _load_file():
375
377
376
378
return True
377
379
378
- async def __aenter__ (self ):
379
- """Async context to open a backup."""
380
+ @asynccontextmanager
381
+ async def create (self ) -> AsyncGenerator [None ]:
382
+ """Create new backup file."""
383
+ if self .tarfile .is_file ():
384
+ raise BackupError (
385
+ f"Cannot make new backup at { self .tarfile .as_posix ()} , file already exists!" ,
386
+ _LOGGER .error ,
387
+ )
380
388
381
- # create a backup
382
- if not self .tarfile .is_file ():
383
- self ._outer_secure_tarfile = SecureTarFile (
384
- self .tarfile ,
385
- "w" ,
386
- gzip = False ,
387
- bufsize = BUF_SIZE ,
389
+ self ._outer_secure_tarfile = SecureTarFile (
390
+ self .tarfile ,
391
+ "w" ,
392
+ gzip = False ,
393
+ bufsize = BUF_SIZE ,
394
+ )
395
+ try :
396
+ with self ._outer_secure_tarfile as outer_tarfile :
397
+ yield
398
+ await self ._create_cleanup (outer_tarfile )
399
+ finally :
400
+ self ._outer_secure_tarfile = None
401
+
402
+ @asynccontextmanager
403
+ async def open (self , location : str | None | type [DEFAULT ]) -> AsyncGenerator [None ]:
404
+ """Open backup for restore."""
405
+ if location != DEFAULT and location not in self .all_locations :
406
+ raise BackupError (
407
+ f"Backup { self .slug } does not exist in location { location } " ,
408
+ _LOGGER .error ,
409
+ )
410
+
411
+ backup_tarfile = (
412
+ self .tarfile if location == DEFAULT else self .all_locations [location ]
413
+ )
414
+ if not backup_tarfile .is_file ():
415
+ raise BackupError (
416
+ f"Cannot open backup at { backup_tarfile .as_posix ()} , file does not exist!" ,
417
+ _LOGGER .error ,
388
418
)
389
- self ._outer_secure_tarfile_tarfile = self ._outer_secure_tarfile .__enter__ ()
390
- return
391
419
392
420
# extract an existing backup
393
- self ._tmp = TemporaryDirectory (dir = str (self . tarfile .parent ))
421
+ self ._tmp = TemporaryDirectory (dir = str (backup_tarfile .parent ))
394
422
395
423
def _extract_backup ():
396
424
"""Extract a backup."""
397
- with tarfile .open (self . tarfile , "r:" ) as tar :
425
+ with tarfile .open (backup_tarfile , "r:" ) as tar :
398
426
tar .extractall (
399
427
path = self ._tmp .name ,
400
428
members = secure_path (tar ),
401
429
filter = "fully_trusted" ,
402
430
)
403
431
404
- await self .sys_run_in_executor (_extract_backup )
405
-
406
- async def __aexit__ (self , exception_type , exception_value , traceback ):
407
- """Async context to close a backup."""
408
- # exists backup or exception on build
409
- try :
410
- await self ._aexit (exception_type , exception_value , traceback )
411
- finally :
412
- if self ._tmp :
413
- self ._tmp .cleanup ()
414
- if self ._outer_secure_tarfile :
415
- self ._outer_secure_tarfile .__exit__ (
416
- exception_type , exception_value , traceback
417
- )
418
- self ._outer_secure_tarfile = None
419
- self ._outer_secure_tarfile_tarfile = None
432
+ with self ._tmp :
433
+ await self .sys_run_in_executor (_extract_backup )
434
+ yield
420
435
421
- async def _aexit (self , exception_type , exception_value , traceback ) :
436
+ async def _create_cleanup (self , outer_tarfile : TarFile ) -> None :
422
437
"""Cleanup after backup creation.
423
438
424
- This is a separate method to allow it to be called from __aexit__ to ensure
439
+ Separate method to be called from create to ensure
425
440
that cleanup is always performed, even if an exception is raised.
426
441
"""
427
- # If we're not creating a new backup, or if an exception was raised, we're done
428
- if not self ._outer_secure_tarfile or exception_type is not None :
429
- return
430
-
431
442
# validate data
432
443
try :
433
444
self ._data = SCHEMA_BACKUP (self ._data )
@@ -445,7 +456,7 @@ def _add_backup_json():
445
456
tar_info = tarfile .TarInfo (name = "./backup.json" )
446
457
tar_info .size = len (raw_bytes )
447
458
tar_info .mtime = int (time .time ())
448
- self . _outer_secure_tarfile_tarfile .addfile (tar_info , fileobj = fileobj )
459
+ outer_tarfile .addfile (tar_info , fileobj = fileobj )
449
460
450
461
try :
451
462
await self .sys_run_in_executor (_add_backup_json )
0 commit comments