diff --git a/pfbudget/core/command.py b/pfbudget/core/command.py index 61226ac..b26f119 100644 --- a/pfbudget/core/command.py +++ b/pfbudget/core/command.py @@ -6,7 +6,17 @@ from typing import Type from pfbudget.common.types import ExportFormat from pfbudget.db.client import Client -from pfbudget.db.model import Serializable +from pfbudget.db.model import ( + Bank, + Category, + CategoryGroup, + Serializable, + Tag, + Transaction, +) + +# required for the backup import +import pfbudget.db.model class Command(ABC): @@ -68,3 +78,51 @@ class ImportCommand(Command): class ImportFailedError(Exception): pass + + +class BackupCommand(Command): + def __init__(self, client: Client, fn: Path, format: ExportFormat) -> None: + self.__client = client + self.fn = fn + self.format = format + + def execute(self) -> None: + banks = self.__client.select(Bank) + groups = self.__client.select(CategoryGroup) + categories = self.__client.select(Category) + tags = self.__client.select(Tag) + transactions = self.__client.select(Transaction) + + values = [*banks, *groups, *categories, *tags, *transactions] + + match self.format: + case ExportFormat.JSON: + with open(self.fn, "w", newline="") as f: + json.dump([e.serialize() for e in values], f, indent=4) + case ExportFormat.pickle: + raise AttributeError("pickle export not working at the moment!") + + +class ImportBackupCommand(Command): + def __init__(self, client: Client, fn: Path, format: ExportFormat) -> None: + self.__client = client + self.fn = fn + self.format = format + + def execute(self) -> None: + match self.format: + case ExportFormat.JSON: + with open(self.fn, "r") as f: + try: + values = json.load(f) + values = [ + getattr(pfbudget.db.model, v["class_"]).deserialize(v) + for v in values + ] + except json.JSONDecodeError as e: + raise ImportFailedError(e) + + case ExportFormat.pickle: + raise AttributeError("pickle import not working at the moment!") + + self.__client.insert(values) diff --git a/pfbudget/db/model.py b/pfbudget/db/model.py index b556ea0..0c12deb 100644 --- a/pfbudget/db/model.py +++ b/pfbudget/db/model.py @@ -1,6 +1,6 @@ from __future__ import annotations from collections.abc import Mapping, MutableMapping, Sequence -from dataclasses import dataclass, fields +from dataclasses import dataclass import datetime as dt import decimal import enum @@ -46,11 +46,11 @@ class Base(MappedAsDataclass, DeclarativeBase): @dataclass class Serializable: def serialize(self) -> Mapping[str, Any]: - return {field.name: getattr(self, field.name) for field in fields(self)} + return dict(class_=type(self).__name__) @classmethod def deserialize(cls, map: Mapping[str, Any]) -> Self: - return cls(**map) + raise NotImplementedError class AccountType(enum.Enum): @@ -80,7 +80,7 @@ class Bank(Base, Serializable): "invert": self.nordigen.invert, } - return dict( + return super().serialize() | dict( name=self.name, BIC=self.BIC, type=self.type.name, @@ -137,7 +137,7 @@ class Transaction(Base, Serializable): "selector": self.category.selector.name, } - return dict( + return super().serialize() | dict( id=self.id, date=self.date.isoformat(), description=self.description, @@ -145,7 +145,7 @@ class Transaction(Base, Serializable): split=self.split, category=category if category else None, tags=[{"tag": tag.tag} for tag in self.tags], - note=self.note, + note={"note": self.note.note} if self.note else None, type=self.type, ) @@ -175,6 +175,10 @@ class Transaction(Base, Serializable): if map["tags"]: tags = set(TransactionTag(t["tag"]) for t in map["tags"]) + note = None + if map["note"]: + note = Note(map["note"]["note"]) + result = cls( dt.date.fromisoformat(map["date"]), map["description"], @@ -182,7 +186,7 @@ class Transaction(Base, Serializable): map["split"], category, tags, - map["note"], + note, ) if map["id"]: @@ -248,6 +252,17 @@ class CategoryGroup(Base, Serializable): name: Mapped[str] = mapped_column(primary_key=True) + categories: Mapped[list[Category]] = relationship( + default_factory=list, lazy="joined" + ) + + def serialize(self) -> Mapping[str, Any]: + return super().serialize() | dict(name=self.name) + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + return cls(map["name"]) + class Category(Base, Serializable, repr=False): __tablename__ = "categories" @@ -290,7 +305,7 @@ class Category(Base, Serializable, repr=False): "amount": self.schedule.amount, } - return dict( + return super().serialize() | dict( name=self.name, group=self.group, rules=rules, @@ -398,7 +413,7 @@ class Tag(Base, Serializable): } ) - return dict(name=self.name, rules=rules) + return super().serialize() | dict(name=self.name, rules=rules) @classmethod def deserialize(cls, map: Mapping[str, Any]) -> Self: diff --git a/tests/mocks/transactions.py b/tests/mocks/transactions.py index dbf4c8e..5564c30 100644 --- a/tests/mocks/transactions.py +++ b/tests/mocks/transactions.py @@ -43,7 +43,7 @@ money = [ ] __original = Transaction(date(2023, 1, 1), "", Decimal("-10"), split=True) -__original.id = 1 +__original.id = 9000 split = [ __original, diff --git a/tests/test_backup.py b/tests/test_backup.py index b29dc2d..fe824b0 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -6,7 +6,13 @@ from mocks import banks, categories, transactions from mocks.client import MockClient from pfbudget.common.types import ExportFormat -from pfbudget.core.command import ExportCommand, ImportCommand, ImportFailedError +from pfbudget.core.command import ( + BackupCommand, + ExportCommand, + ImportBackupCommand, + ImportCommand, + ImportFailedError, +) from pfbudget.db.client import Client from pfbudget.db.model import ( Bank, @@ -87,7 +93,6 @@ class TestBackup: with pytest.raises(AttributeError): command.execute() - @pytest.mark.parametrize("input, what", not_serializable) def test_try_backup_not_serializable( self, tmp_path: Path, input: Sequence[Any], what: Type[Any] @@ -112,3 +117,23 @@ class TestBackup: imported = other.select(what) assert not imported + + def test_full_backup(self, tmp_path: Path): + file = tmp_path / "test.json" + + client = MockClient() + client.insert([e for t in params for e in t[0]]) + originals = client.select(Transaction) + + assert originals + + command = BackupCommand(client, file, ExportFormat.JSON) + command.execute() + + other = MockClient() + command = ImportBackupCommand(other, file, ExportFormat.JSON) + command.execute() + + imported = other.select(Transaction) + + assert originals == imported