From 638b833c74dd4d070b55dc7aa5b176d4f719c128 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Murta?= Date: Thu, 11 May 2023 00:26:18 +0100 Subject: [PATCH] ImportCommand and Serializable types The new command ImportCommand takes a Serializable type, from which it can call the deserialize method to generate a DB ORM type. The Serializable interface also declares the serialize method. (De)serialization moved to the ORM types, due to the inability to properly use overloading. Possible improvement for the future is to merge serialization information on JSONDecoder/Encoder classes. Adds a MockClient with the in-memory SQLite DB which can be used by tests. Most types export/import functionally tested using two DBs and comparing entries. --- pfbudget/core/command.py | 32 +++- pfbudget/db/model.py | 287 ++++++++++++++++++++++------------- pfbudget/utils/serializer.py | 54 ------- tests/mocks/categories.py | 11 +- tests/mocks/client.py | 11 ++ tests/mocks/transactions.py | 22 +++ tests/test_backup.py | 60 ++++++++ tests/test_command.py | 37 +++-- tests/test_database.py | 8 +- 9 files changed, 345 insertions(+), 177 deletions(-) delete mode 100644 pfbudget/utils/serializer.py create mode 100644 tests/mocks/client.py create mode 100644 tests/test_backup.py diff --git a/pfbudget/core/command.py b/pfbudget/core/command.py index ae9e1c6..f901aef 100644 --- a/pfbudget/core/command.py +++ b/pfbudget/core/command.py @@ -2,11 +2,11 @@ from abc import ABC, abstractmethod import json from pathlib import Path import pickle -from typing import Any, Type +from typing import Type from pfbudget.common.types import ExportFormat from pfbudget.db.client import Client -from pfbudget.utils.serializer import serialize +from pfbudget.db.model import Serializable class Command(ABC): @@ -19,7 +19,9 @@ class Command(ABC): class ExportCommand(Command): - def __init__(self, client: Client, what: Type[Any], fn: Path, format: ExportFormat): + def __init__( + self, client: Client, what: Type[Serializable], fn: Path, format: ExportFormat + ): self.__client = client self.what = what self.fn = fn @@ -30,7 +32,29 @@ class ExportCommand(Command): match self.format: case ExportFormat.JSON: with open(self.fn, "w", newline="") as f: - json.dump([serialize(e) for e in values], f, indent=4) + json.dump([e.serialize() for e in values], f, indent=4) case ExportFormat.pickle: with open(self.fn, "wb") as f: pickle.dump(values, f) + + +class ImportCommand(Command): + def __init__( + self, client: Client, what: Type[Serializable], fn: Path, format: ExportFormat + ): + self.__client = client + self.what = what + self.fn = fn + self.format = format + + def execute(self) -> None: + match self.format: + case ExportFormat.JSON: + with open(self.fn, "r") as f: + values = json.load(f) + values = [self.what.deserialize(v) for v in values] + case ExportFormat.pickle: + with open(self.fn, "rb") as f: + values = pickle.load(f) + + self.__client.insert(values) diff --git a/pfbudget/db/model.py b/pfbudget/db/model.py index aa1c9fa..6dbef14 100644 --- a/pfbudget/db/model.py +++ b/pfbudget/db/model.py @@ -1,9 +1,11 @@ from __future__ import annotations +from collections.abc import Mapping, MutableMapping, Sequence +from dataclasses import dataclass, fields import datetime as dt import decimal import enum import re -from typing import Annotated, Any, Callable, Optional +from typing import Annotated, Any, Callable, Optional, Self, cast from sqlalchemy import ( BigInteger, @@ -41,6 +43,16 @@ class Base(MappedAsDataclass, DeclarativeBase): } +@dataclass +class Serializable: + def serialize(self) -> Mapping[str, Any]: + return {field.name: getattr(self, field.name) for field in fields(self)} + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + return cls(**map) + + class AccountType(enum.Enum): checking = enum.auto() savings = enum.auto() @@ -50,13 +62,7 @@ class AccountType(enum.Enum): MASTERCARD = enum.auto() -class Export: - @property - def format(self) -> dict[str, Any]: - raise NotImplementedError - - -class Bank(Base, Export): +class Bank(Base, Serializable): __tablename__ = "banks" name: Mapped[str] = mapped_column(primary_key=True) @@ -65,15 +71,29 @@ class Bank(Base, Export): nordigen: Mapped[Optional[Nordigen]] = relationship(default=None, lazy="joined") - @property - def format(self) -> dict[str, Any]: + def serialize(self) -> Mapping[str, Any]: + nordigen = None + if self.nordigen: + nordigen = { + "bank_id": self.nordigen.bank_id, + "requisition_id": self.nordigen.requisition_id, + "invert": self.nordigen.invert, + } + return dict( name=self.name, BIC=self.BIC, - type=self.type, - nordigen=self.nordigen.format if self.nordigen else None, + type=self.type.name, + nordigen=nordigen, ) + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + bank = cls(map["name"], map["BIC"], map["type"]) + if map["nordigen"]: + bank.nordigen = Nordigen(**map["nordigen"]) + return bank + bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))] @@ -88,7 +108,7 @@ idpk = Annotated[ money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))] -class Transaction(Base, Export): +class Transaction(Base, Serializable): __tablename__ = "transactions" id: Mapped[idpk] = mapped_column(init=False) @@ -109,18 +129,59 @@ class Transaction(Base, Export): type: Mapped[str] = mapped_column(init=False) __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "transaction"} - @property - def format(self) -> dict[str, Any]: + def serialize(self) -> Mapping[str, Any]: + category = None + if self.category: + category = { + "name": self.category.name, + "selector": self.category.selector.name, + } + return dict( - id=self.id, - date=self.date, + date=self.date.isoformat(), description=self.description, - amount=self.amount, + amount=str(self.amount), split=self.split, + category=category if category else None, + tags=[{"tag": tag.tag} for tag in self.tags], + note=self.note, type=self.type, - category=self.category.format if self.category else None, - # TODO note - tags=[tag.format for tag in self.tags] if self.tags else None, + ) + + @classmethod + def deserialize( + cls, map: Mapping[str, Any] + ) -> Transaction | BankTransaction | MoneyTransaction | SplitTransaction: + match map["type"]: + case "bank": + return BankTransaction.deserialize(map) + case "money": + return MoneyTransaction.deserialize(map) + case "split": + raise NotImplementedError + case _: + return cls._deserialize(map) + + @classmethod + def _deserialize(cls, map: Mapping[str, Any]) -> Self: + category = None + if map["category"]: + category = TransactionCategory(map["category"]["name"]) + if map["category"]["selector"]: + category.selector = map["category"]["selector"] + + tags: set[TransactionTag] = set() + if map["tags"]: + tags = set(t["tag"] for t in map["tags"]) + + return cls( + dt.date.fromisoformat(map["date"]), + map["description"], + map["amount"], + map["split"], + category, + tags, + map["note"], ) def __lt__(self, other: Transaction): @@ -132,41 +193,47 @@ idfk = Annotated[ ] -class BankTransaction(Transaction): +class BankTransaction(Transaction, Serializable): bank: Mapped[Optional[bankfk]] = mapped_column(default=None) __mapper_args__ = {"polymorphic_identity": "bank", "polymorphic_load": "inline"} - @property - def format(self) -> dict[str, Any]: - return super().format | dict(bank=self.bank) + def serialize(self) -> Mapping[str, Any]: + map = cast(MutableMapping[str, Any], super().serialize()) + map["bank"] = self.bank + return map + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + transaction = cls._deserialize(map) + transaction.bank = map["bank"] + return transaction -class MoneyTransaction(Transaction): +class MoneyTransaction(Transaction, Serializable): __mapper_args__ = {"polymorphic_identity": "money"} + def serialize(self) -> Mapping[str, Any]: + return super().serialize() + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + return cls._deserialize(map) + class SplitTransaction(Transaction): original: Mapped[Optional[idfk]] = mapped_column(default=None) __mapper_args__ = {"polymorphic_identity": "split", "polymorphic_load": "inline"} - @property - def format(self) -> dict[str, Any]: - return super().format | dict(original=self.original) - -class CategoryGroup(Base, Export): +class CategoryGroup(Base, Serializable): __tablename__ = "category_groups" name: Mapped[str] = mapped_column(primary_key=True) - @property - def format(self) -> dict[str, Any]: - return dict(name=self.name) - -class Category(Base, Export): +class Category(Base, Serializable, repr=False): __tablename__ = "categories" name: Mapped[str] = mapped_column(primary_key=True) @@ -175,27 +242,62 @@ class Category(Base, Export): ) rules: Mapped[list[CategoryRule]] = relationship( - cascade="all, delete-orphan", passive_deletes=True, default_factory=list + cascade="all, delete-orphan", + passive_deletes=True, + default_factory=list, + lazy="joined", ) schedule: Mapped[Optional[CategorySchedule]] = relationship( - cascade="all, delete-orphan", passive_deletes=True, default=None + cascade="all, delete-orphan", passive_deletes=True, default=None, lazy="joined" ) + def serialize(self) -> Mapping[str, Any]: + rules: Sequence[Mapping[str, Any]] = [] + for rule in self.rules: + rules.append( + { + "name": rule.name, + "start": rule.start, + "end": rule.end, + "description": rule.description, + "regex": rule.regex, + "bank": rule.bank, + "min": str(rule.min) if rule.min is not None else None, + "max": str(rule.max) if rule.max is not None else None, + } + ) + + schedule = None + if self.schedule: + schedule = { + "name": self.schedule.name, + "period": self.schedule.period.name if self.schedule.period else None, + "period_multiplier": self.schedule.period_multiplier, + "amount": self.schedule.amount, + } + + return dict( + name=self.name, + group=self.group, + rules=rules, + schedule=schedule, + ) + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + return cls( + map["name"], + map["group"], + list(CategoryRule(**r) for r in map["rules"]), + CategorySchedule(**map["schedule"]) if map["schedule"] else None, + ) + def __repr__(self) -> str: return ( f"Category(name={self.name}, group={self.group}, #rules={len(self.rules)}," f" schedule={self.schedule})" ) - @property - def format(self) -> dict[str, Any]: - return dict( - name=self.name, - group=self.group if self.group else None, - rules=[rule.format for rule in self.rules], - schedule=self.schedule.format if self.schedule else None, - ) - catfk = Annotated[ str, @@ -212,7 +314,7 @@ class CategorySelector(enum.Enum): manual = enum.auto() -class TransactionCategory(Base, Export): +class TransactionCategory(Base): __tablename__ = "transactions_categorized" id: Mapped[idfk] = mapped_column(primary_key=True, init=False) @@ -224,12 +326,6 @@ class TransactionCategory(Base, Export): back_populates="category", init=False, compare=False ) - @property - def format(self): - return dict( - name=self.name, selector=self.selector.name - ) - class Note(Base): __tablename__ = "notes" @@ -238,7 +334,7 @@ class Note(Base): note: Mapped[str] -class Nordigen(Base, Export): +class Nordigen(Base): __tablename__ = "banks_nordigen" name: Mapped[bankfk] = mapped_column(primary_key=True, init=False) @@ -246,36 +342,51 @@ class Nordigen(Base, Export): requisition_id: Mapped[Optional[str]] invert: Mapped[Optional[bool]] = mapped_column(default=None) - @property - def format(self) -> dict[str, Any]: - return dict( - name=self.name, - bank_id=self.bank_id, - requisition_id=self.requisition_id, - invert=self.invert, - ) - -class Tag(Base): +class Tag(Base, Serializable): __tablename__ = "tags" name: Mapped[str] = mapped_column(primary_key=True) rules: Mapped[list[TagRule]] = relationship( - cascade="all, delete-orphan", passive_deletes=True, default_factory=list + cascade="all, delete-orphan", + passive_deletes=True, + default_factory=list, + lazy="joined", ) + def serialize(self) -> Mapping[str, Any]: + rules: Sequence[Mapping[str, Any]] = [] + for rule in self.rules: + rules.append( + { + "name": rule.tag, + "start": rule.start, + "end": rule.end, + "description": rule.description, + "regex": rule.regex, + "bank": rule.bank, + "min": str(rule.min) if rule.min is not None else None, + "max": str(rule.max) if rule.max is not None else None, + } + ) -class TransactionTag(Base, Export, unsafe_hash=True): + return dict(name=self.name, rules=rules) + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + return cls( + map["name"], + list(TagRule(**r) for r in map["rules"]), + ) + + +class TransactionTag(Base, unsafe_hash=True): __tablename__ = "transactions_tagged" id: Mapped[idfk] = mapped_column(primary_key=True, init=False) tag: Mapped[str] = mapped_column(ForeignKey(Tag.name), primary_key=True) - @property - def format(self): - return dict(tag=self.tag) - class SchedulePeriod(enum.Enum): daily = enum.auto() @@ -284,7 +395,7 @@ class SchedulePeriod(enum.Enum): yearly = enum.auto() -class CategorySchedule(Base, Export): +class CategorySchedule(Base): __tablename__ = "category_schedules" name: Mapped[catfk] = mapped_column(primary_key=True) @@ -292,15 +403,6 @@ class CategorySchedule(Base, Export): period_multiplier: Mapped[Optional[int]] amount: Mapped[Optional[int]] - @property - def format(self) -> dict[str, Any]: - return dict( - name=self.name, - period=self.period, - period_multiplier=self.period_multiplier, - amount=self.amount, - ) - class Link(Base): __tablename__ = "links" @@ -309,7 +411,7 @@ class Link(Base): link: Mapped[idfk] = mapped_column(primary_key=True) -class Rule(Base, Export, init=False): +class Rule(Base, init=False): __tablename__ = "rules" id: Mapped[idpk] = mapped_column(init=False) @@ -355,19 +457,6 @@ class Rule(Base, Export, init=False): return False - @property - def format(self) -> dict[str, Any]: - return dict( - start=self.start, - end=self.end, - description=self.description, - regex=self.regex, - bank=self.bank, - min=self.min, - max=self.max, - type=self.type, - ) - @staticmethod def exists(r: Optional[Any], op: Callable[[Any], bool]) -> bool: return op(r) if r is not None else True @@ -388,10 +477,6 @@ class CategoryRule(Rule): "polymorphic_identity": "category_rule", } - @property - def format(self) -> dict[str, Any]: - return super().format | dict(name=self.name) - def __init__(self, name: str, **kwargs: Any) -> None: super().__init__(**kwargs) self.name = name @@ -412,10 +497,6 @@ class TagRule(Rule): "polymorphic_identity": "tag_rule", } - @property - def format(self) -> dict[str, Any]: - return super().format | dict(tag=self.tag) - def __init__(self, name: str, **kwargs: Any) -> None: super().__init__(**kwargs) self.tag = name diff --git a/pfbudget/utils/serializer.py b/pfbudget/utils/serializer.py deleted file mode 100644 index 35d3657..0000000 --- a/pfbudget/utils/serializer.py +++ /dev/null @@ -1,54 +0,0 @@ -from collections.abc import Mapping -from dataclasses import fields -from functools import singledispatch -from typing import Any - -from pfbudget.db.model import Transaction, TransactionCategory, TransactionTag, Note - - -class NotSerializableError(Exception): - pass - - -@singledispatch -def serialize(obj: Any) -> Mapping[str, Any]: - return {field.name: getattr(obj, field.name) for field in fields(obj)} - - -@serialize.register -def _(obj: Transaction) -> Mapping[str, Any]: - category = None - if obj.category: - category = { - "name": obj.category.name, - "selector": str(obj.category.selector) - if obj.category.selector - else None, - } - - return dict( - id=obj.id, - date=obj.date.isoformat(), - description=obj.description, - amount=str(obj.amount), - split=obj.split, - category=category if category else None, - tags=[{"tag": tag.tag} for tag in obj.tags], - note=obj.note, - type=obj.type, - ) - - -@serialize.register -def _(_: TransactionCategory) -> Mapping[str, Any]: - raise NotSerializableError("TransactionCategory") - - -@serialize.register -def _(_: TransactionTag) -> Mapping[str, Any]: - raise NotSerializableError("TransactionTag") - - -@serialize.register -def _(_: Note) -> Mapping[str, Any]: - raise NotSerializableError("Note") diff --git a/tests/mocks/categories.py b/tests/mocks/categories.py index 5a40cb7..9f5540d 100644 --- a/tests/mocks/categories.py +++ b/tests/mocks/categories.py @@ -1,11 +1,20 @@ from decimal import Decimal -from pfbudget.db.model import Category, CategoryRule, Tag, TagRule +from pfbudget.db.model import Category, CategoryGroup, CategoryRule, Tag, TagRule category_null = Category("null") +categorygroup1 = CategoryGroup("group#1") + category1 = Category( "cat#1", + "group#1", + rules=[CategoryRule("cat#1", description="desc#1", max=Decimal(0))], +) + +category2 = Category( + "cat#2", + "group#1", rules=[CategoryRule("cat#1", description="desc#1", max=Decimal(0))], ) diff --git a/tests/mocks/client.py b/tests/mocks/client.py new file mode 100644 index 0000000..f8118db --- /dev/null +++ b/tests/mocks/client.py @@ -0,0 +1,11 @@ +from pfbudget.db.client import Client +from pfbudget.db.model import Base + + +class MockClient(Client): + def __init__(self): + url = "sqlite://" + super().__init__( + url, execution_options={"schema_translate_map": {"pfbudget": None}} + ) + Base.metadata.create_all(self.engine) diff --git a/tests/mocks/transactions.py b/tests/mocks/transactions.py index a540fbf..6bdc311 100644 --- a/tests/mocks/transactions.py +++ b/tests/mocks/transactions.py @@ -2,7 +2,10 @@ from datetime import date from decimal import Decimal from pfbudget.db.model import ( + BankTransaction, CategorySelector, + MoneyTransaction, + SplitTransaction, Transaction, TransactionCategory, ) @@ -26,3 +29,22 @@ simple_transformed = [ category=TransactionCategory("category#2", CategorySelector.algorithm), ), ] + +bank = [ + BankTransaction(date(2023, 1, 1), "", Decimal("-10"), bank="bank#1"), + BankTransaction(date(2023, 1, 1), "", Decimal("-10"), bank="bank#2"), +] + +money = [ + MoneyTransaction(date(2023, 1, 1), "", Decimal("-10")), + MoneyTransaction(date(2023, 1, 1), "", Decimal("-10")), +] + +__original = Transaction(date(2023, 1, 1), "", Decimal("-10"), split=True) +__original.id = 1 + +split = [ + __original, + SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id), + SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id), +] diff --git a/tests/test_backup.py b/tests/test_backup.py new file mode 100644 index 0000000..037d772 --- /dev/null +++ b/tests/test_backup.py @@ -0,0 +1,60 @@ +from pathlib import Path +from typing import Any, Sequence, Type +import pytest + +from mocks import banks, categories, transactions +from mocks.client import MockClient + +from pfbudget.common.types import ExportFormat +from pfbudget.core.command import ExportCommand, ImportCommand +from pfbudget.db.client import Client +from pfbudget.db.model import Bank, Category, CategoryGroup, Tag, Transaction + + +@pytest.fixture +def client() -> Client: + return MockClient() + + +params = [ + (transactions.simple, Transaction), + (transactions.simple_transformed, Transaction), + (transactions.bank, Transaction), + (transactions.money, Transaction), + # (transactions.split, Transaction), NotImplemented + ([banks.checking, banks.cc], Bank), + ([categories.category_null, categories.category1, categories.category2], Category), + ( + [ + categories.categorygroup1, + categories.category_null, + categories.category1, + categories.category2, + ], + CategoryGroup, + ), + ([categories.tag_1], Tag), +] + + +class TestBackup: + @pytest.mark.parametrize("input, what", params) + def test_import(self, tmp_path: Path, input: Sequence[Any], what: Type[Any]): + file = tmp_path / "test.json" + + client = MockClient() + client.insert(input) + originals = client.select(what) + + assert originals + + command = ExportCommand(client, what, file, ExportFormat.JSON) + command.execute() + + other = MockClient() + command = ImportCommand(other, what, file, ExportFormat.JSON) + command.execute() + + imported = other.select(what) + + assert originals == imported diff --git a/tests/test_command.py b/tests/test_command.py index f55318c..cdc6353 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -2,15 +2,15 @@ from collections.abc import Sequence import json from pathlib import Path import pickle -from typing import Any +import pytest +from typing import Any, cast import mocks.transactions from pfbudget.common.types import ExportFormat -from pfbudget.core.command import ExportCommand +from pfbudget.core.command import ExportCommand, ImportCommand from pfbudget.db.client import Client from pfbudget.db.model import Transaction -from pfbudget.utils.serializer import serialize class FakeClient(Client): @@ -40,8 +40,13 @@ class FakeClient(Client): self._transactions = value +@pytest.fixture +def client() -> Client: + return FakeClient() + + class TestCommand: - def test_export_json(self, tmp_path: Path): + def test_export_json(self, tmp_path: Path, client: Client): client = FakeClient() file = tmp_path / "test.json" command = ExportCommand(client, Transaction, file, ExportFormat.JSON) @@ -49,19 +54,18 @@ class TestCommand: with open(file, newline="") as f: result = json.load(f) - assert result == [serialize(t) for t in mocks.transactions.simple] + assert result == [t.serialize() for t in mocks.transactions.simple] - client.transactions = mocks.transactions.simple_transformed + cast(FakeClient, client).transactions = mocks.transactions.simple_transformed command.execute() with open(file, newline="") as f: result = json.load(f) assert result == [ - serialize(t) for t in mocks.transactions.simple_transformed + t.serialize() for t in mocks.transactions.simple_transformed ] - def test_export_pickle(self, tmp_path: Path): - client = FakeClient() + def test_export_pickle(self, tmp_path: Path, client: Client): file = tmp_path / "test.pickle" command = ExportCommand(client, Transaction, file, ExportFormat.pickle) command.execute() @@ -70,9 +74,22 @@ class TestCommand: result = pickle.load(f) assert result == mocks.transactions.simple - client.transactions = mocks.transactions.simple_transformed + cast(FakeClient, client).transactions = mocks.transactions.simple_transformed command.execute() with open(file, "rb") as f: result = pickle.load(f) assert result == mocks.transactions.simple_transformed + + def test_import(self, tmp_path: Path, client: Client): + file = tmp_path / "test" + for format in list(ExportFormat): + command = ExportCommand(client, Transaction, file, format) + command.execute() + + command = ImportCommand(client, Transaction, file, format) + command.execute() + + transactions = cast(FakeClient, client).transactions + assert len(transactions) > 0 + assert transactions == client.select(Transaction) diff --git a/tests/test_database.py b/tests/test_database.py index d1187a0..848ed74 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -2,11 +2,12 @@ from datetime import date from decimal import Decimal import pytest +from mocks.client import MockClient + from pfbudget.db.client import Client from pfbudget.db.model import ( AccountType, Bank, - Base, Nordigen, CategorySelector, Transaction, @@ -16,10 +17,7 @@ from pfbudget.db.model import ( @pytest.fixture def client() -> Client: - url = "sqlite://" - client = Client(url, execution_options={"schema_translate_map": {"pfbudget": None}}) - Base.metadata.create_all(client.engine) - return client + return MockClient() @pytest.fixture