diff --git a/pfbudget/core/command.py b/pfbudget/core/command.py index ae9e1c6..f901aef 100644 --- a/pfbudget/core/command.py +++ b/pfbudget/core/command.py @@ -2,11 +2,11 @@ from abc import ABC, abstractmethod import json from pathlib import Path import pickle -from typing import Any, Type +from typing import Type from pfbudget.common.types import ExportFormat from pfbudget.db.client import Client -from pfbudget.utils.serializer import serialize +from pfbudget.db.model import Serializable class Command(ABC): @@ -19,7 +19,9 @@ class Command(ABC): class ExportCommand(Command): - def __init__(self, client: Client, what: Type[Any], fn: Path, format: ExportFormat): + def __init__( + self, client: Client, what: Type[Serializable], fn: Path, format: ExportFormat + ): self.__client = client self.what = what self.fn = fn @@ -30,7 +32,29 @@ class ExportCommand(Command): match self.format: case ExportFormat.JSON: with open(self.fn, "w", newline="") as f: - json.dump([serialize(e) for e in values], f, indent=4) + json.dump([e.serialize() for e in values], f, indent=4) case ExportFormat.pickle: with open(self.fn, "wb") as f: pickle.dump(values, f) + + +class ImportCommand(Command): + def __init__( + self, client: Client, what: Type[Serializable], fn: Path, format: ExportFormat + ): + self.__client = client + self.what = what + self.fn = fn + self.format = format + + def execute(self) -> None: + match self.format: + case ExportFormat.JSON: + with open(self.fn, "r") as f: + values = json.load(f) + values = [self.what.deserialize(v) for v in values] + case ExportFormat.pickle: + with open(self.fn, "rb") as f: + values = pickle.load(f) + + self.__client.insert(values) diff --git a/pfbudget/db/model.py b/pfbudget/db/model.py index aa1c9fa..6dbef14 100644 --- a/pfbudget/db/model.py +++ b/pfbudget/db/model.py @@ -1,9 +1,11 @@ from __future__ import annotations +from collections.abc import Mapping, MutableMapping, Sequence +from dataclasses import dataclass, fields import datetime as dt import decimal import enum import re -from typing import Annotated, Any, Callable, Optional +from typing import Annotated, Any, Callable, Optional, Self, cast from sqlalchemy import ( BigInteger, @@ -41,6 +43,16 @@ class Base(MappedAsDataclass, DeclarativeBase): } +@dataclass +class Serializable: + def serialize(self) -> Mapping[str, Any]: + return {field.name: getattr(self, field.name) for field in fields(self)} + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + return cls(**map) + + class AccountType(enum.Enum): checking = enum.auto() savings = enum.auto() @@ -50,13 +62,7 @@ class AccountType(enum.Enum): MASTERCARD = enum.auto() -class Export: - @property - def format(self) -> dict[str, Any]: - raise NotImplementedError - - -class Bank(Base, Export): +class Bank(Base, Serializable): __tablename__ = "banks" name: Mapped[str] = mapped_column(primary_key=True) @@ -65,15 +71,29 @@ class Bank(Base, Export): nordigen: Mapped[Optional[Nordigen]] = relationship(default=None, lazy="joined") - @property - def format(self) -> dict[str, Any]: + def serialize(self) -> Mapping[str, Any]: + nordigen = None + if self.nordigen: + nordigen = { + "bank_id": self.nordigen.bank_id, + "requisition_id": self.nordigen.requisition_id, + "invert": self.nordigen.invert, + } + return dict( name=self.name, BIC=self.BIC, - type=self.type, - nordigen=self.nordigen.format if self.nordigen else None, + type=self.type.name, + nordigen=nordigen, ) + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + bank = cls(map["name"], map["BIC"], map["type"]) + if map["nordigen"]: + bank.nordigen = Nordigen(**map["nordigen"]) + return bank + bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))] @@ -88,7 +108,7 @@ idpk = Annotated[ money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))] -class Transaction(Base, Export): +class Transaction(Base, Serializable): __tablename__ = "transactions" id: Mapped[idpk] = mapped_column(init=False) @@ -109,18 +129,59 @@ class Transaction(Base, Export): type: Mapped[str] = mapped_column(init=False) __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "transaction"} - @property - def format(self) -> dict[str, Any]: + def serialize(self) -> Mapping[str, Any]: + category = None + if self.category: + category = { + "name": self.category.name, + "selector": self.category.selector.name, + } + return dict( - id=self.id, - date=self.date, + date=self.date.isoformat(), description=self.description, - amount=self.amount, + amount=str(self.amount), split=self.split, + category=category if category else None, + tags=[{"tag": tag.tag} for tag in self.tags], + note=self.note, type=self.type, - category=self.category.format if self.category else None, - # TODO note - tags=[tag.format for tag in self.tags] if self.tags else None, + ) + + @classmethod + def deserialize( + cls, map: Mapping[str, Any] + ) -> Transaction | BankTransaction | MoneyTransaction | SplitTransaction: + match map["type"]: + case "bank": + return BankTransaction.deserialize(map) + case "money": + return MoneyTransaction.deserialize(map) + case "split": + raise NotImplementedError + case _: + return cls._deserialize(map) + + @classmethod + def _deserialize(cls, map: Mapping[str, Any]) -> Self: + category = None + if map["category"]: + category = TransactionCategory(map["category"]["name"]) + if map["category"]["selector"]: + category.selector = map["category"]["selector"] + + tags: set[TransactionTag] = set() + if map["tags"]: + tags = set(t["tag"] for t in map["tags"]) + + return cls( + dt.date.fromisoformat(map["date"]), + map["description"], + map["amount"], + map["split"], + category, + tags, + map["note"], ) def __lt__(self, other: Transaction): @@ -132,41 +193,47 @@ idfk = Annotated[ ] -class BankTransaction(Transaction): +class BankTransaction(Transaction, Serializable): bank: Mapped[Optional[bankfk]] = mapped_column(default=None) __mapper_args__ = {"polymorphic_identity": "bank", "polymorphic_load": "inline"} - @property - def format(self) -> dict[str, Any]: - return super().format | dict(bank=self.bank) + def serialize(self) -> Mapping[str, Any]: + map = cast(MutableMapping[str, Any], super().serialize()) + map["bank"] = self.bank + return map + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + transaction = cls._deserialize(map) + transaction.bank = map["bank"] + return transaction -class MoneyTransaction(Transaction): +class MoneyTransaction(Transaction, Serializable): __mapper_args__ = {"polymorphic_identity": "money"} + def serialize(self) -> Mapping[str, Any]: + return super().serialize() + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + return cls._deserialize(map) + class SplitTransaction(Transaction): original: Mapped[Optional[idfk]] = mapped_column(default=None) __mapper_args__ = {"polymorphic_identity": "split", "polymorphic_load": "inline"} - @property - def format(self) -> dict[str, Any]: - return super().format | dict(original=self.original) - -class CategoryGroup(Base, Export): +class CategoryGroup(Base, Serializable): __tablename__ = "category_groups" name: Mapped[str] = mapped_column(primary_key=True) - @property - def format(self) -> dict[str, Any]: - return dict(name=self.name) - -class Category(Base, Export): +class Category(Base, Serializable, repr=False): __tablename__ = "categories" name: Mapped[str] = mapped_column(primary_key=True) @@ -175,27 +242,62 @@ class Category(Base, Export): ) rules: Mapped[list[CategoryRule]] = relationship( - cascade="all, delete-orphan", passive_deletes=True, default_factory=list + cascade="all, delete-orphan", + passive_deletes=True, + default_factory=list, + lazy="joined", ) schedule: Mapped[Optional[CategorySchedule]] = relationship( - cascade="all, delete-orphan", passive_deletes=True, default=None + cascade="all, delete-orphan", passive_deletes=True, default=None, lazy="joined" ) + def serialize(self) -> Mapping[str, Any]: + rules: Sequence[Mapping[str, Any]] = [] + for rule in self.rules: + rules.append( + { + "name": rule.name, + "start": rule.start, + "end": rule.end, + "description": rule.description, + "regex": rule.regex, + "bank": rule.bank, + "min": str(rule.min) if rule.min is not None else None, + "max": str(rule.max) if rule.max is not None else None, + } + ) + + schedule = None + if self.schedule: + schedule = { + "name": self.schedule.name, + "period": self.schedule.period.name if self.schedule.period else None, + "period_multiplier": self.schedule.period_multiplier, + "amount": self.schedule.amount, + } + + return dict( + name=self.name, + group=self.group, + rules=rules, + schedule=schedule, + ) + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + return cls( + map["name"], + map["group"], + list(CategoryRule(**r) for r in map["rules"]), + CategorySchedule(**map["schedule"]) if map["schedule"] else None, + ) + def __repr__(self) -> str: return ( f"Category(name={self.name}, group={self.group}, #rules={len(self.rules)}," f" schedule={self.schedule})" ) - @property - def format(self) -> dict[str, Any]: - return dict( - name=self.name, - group=self.group if self.group else None, - rules=[rule.format for rule in self.rules], - schedule=self.schedule.format if self.schedule else None, - ) - catfk = Annotated[ str, @@ -212,7 +314,7 @@ class CategorySelector(enum.Enum): manual = enum.auto() -class TransactionCategory(Base, Export): +class TransactionCategory(Base): __tablename__ = "transactions_categorized" id: Mapped[idfk] = mapped_column(primary_key=True, init=False) @@ -224,12 +326,6 @@ class TransactionCategory(Base, Export): back_populates="category", init=False, compare=False ) - @property - def format(self): - return dict( - name=self.name, selector=self.selector.name - ) - class Note(Base): __tablename__ = "notes" @@ -238,7 +334,7 @@ class Note(Base): note: Mapped[str] -class Nordigen(Base, Export): +class Nordigen(Base): __tablename__ = "banks_nordigen" name: Mapped[bankfk] = mapped_column(primary_key=True, init=False) @@ -246,36 +342,51 @@ class Nordigen(Base, Export): requisition_id: Mapped[Optional[str]] invert: Mapped[Optional[bool]] = mapped_column(default=None) - @property - def format(self) -> dict[str, Any]: - return dict( - name=self.name, - bank_id=self.bank_id, - requisition_id=self.requisition_id, - invert=self.invert, - ) - -class Tag(Base): +class Tag(Base, Serializable): __tablename__ = "tags" name: Mapped[str] = mapped_column(primary_key=True) rules: Mapped[list[TagRule]] = relationship( - cascade="all, delete-orphan", passive_deletes=True, default_factory=list + cascade="all, delete-orphan", + passive_deletes=True, + default_factory=list, + lazy="joined", ) + def serialize(self) -> Mapping[str, Any]: + rules: Sequence[Mapping[str, Any]] = [] + for rule in self.rules: + rules.append( + { + "name": rule.tag, + "start": rule.start, + "end": rule.end, + "description": rule.description, + "regex": rule.regex, + "bank": rule.bank, + "min": str(rule.min) if rule.min is not None else None, + "max": str(rule.max) if rule.max is not None else None, + } + ) -class TransactionTag(Base, Export, unsafe_hash=True): + return dict(name=self.name, rules=rules) + + @classmethod + def deserialize(cls, map: Mapping[str, Any]) -> Self: + return cls( + map["name"], + list(TagRule(**r) for r in map["rules"]), + ) + + +class TransactionTag(Base, unsafe_hash=True): __tablename__ = "transactions_tagged" id: Mapped[idfk] = mapped_column(primary_key=True, init=False) tag: Mapped[str] = mapped_column(ForeignKey(Tag.name), primary_key=True) - @property - def format(self): - return dict(tag=self.tag) - class SchedulePeriod(enum.Enum): daily = enum.auto() @@ -284,7 +395,7 @@ class SchedulePeriod(enum.Enum): yearly = enum.auto() -class CategorySchedule(Base, Export): +class CategorySchedule(Base): __tablename__ = "category_schedules" name: Mapped[catfk] = mapped_column(primary_key=True) @@ -292,15 +403,6 @@ class CategorySchedule(Base, Export): period_multiplier: Mapped[Optional[int]] amount: Mapped[Optional[int]] - @property - def format(self) -> dict[str, Any]: - return dict( - name=self.name, - period=self.period, - period_multiplier=self.period_multiplier, - amount=self.amount, - ) - class Link(Base): __tablename__ = "links" @@ -309,7 +411,7 @@ class Link(Base): link: Mapped[idfk] = mapped_column(primary_key=True) -class Rule(Base, Export, init=False): +class Rule(Base, init=False): __tablename__ = "rules" id: Mapped[idpk] = mapped_column(init=False) @@ -355,19 +457,6 @@ class Rule(Base, Export, init=False): return False - @property - def format(self) -> dict[str, Any]: - return dict( - start=self.start, - end=self.end, - description=self.description, - regex=self.regex, - bank=self.bank, - min=self.min, - max=self.max, - type=self.type, - ) - @staticmethod def exists(r: Optional[Any], op: Callable[[Any], bool]) -> bool: return op(r) if r is not None else True @@ -388,10 +477,6 @@ class CategoryRule(Rule): "polymorphic_identity": "category_rule", } - @property - def format(self) -> dict[str, Any]: - return super().format | dict(name=self.name) - def __init__(self, name: str, **kwargs: Any) -> None: super().__init__(**kwargs) self.name = name @@ -412,10 +497,6 @@ class TagRule(Rule): "polymorphic_identity": "tag_rule", } - @property - def format(self) -> dict[str, Any]: - return super().format | dict(tag=self.tag) - def __init__(self, name: str, **kwargs: Any) -> None: super().__init__(**kwargs) self.tag = name diff --git a/pfbudget/utils/serializer.py b/pfbudget/utils/serializer.py deleted file mode 100644 index 35d3657..0000000 --- a/pfbudget/utils/serializer.py +++ /dev/null @@ -1,54 +0,0 @@ -from collections.abc import Mapping -from dataclasses import fields -from functools import singledispatch -from typing import Any - -from pfbudget.db.model import Transaction, TransactionCategory, TransactionTag, Note - - -class NotSerializableError(Exception): - pass - - -@singledispatch -def serialize(obj: Any) -> Mapping[str, Any]: - return {field.name: getattr(obj, field.name) for field in fields(obj)} - - -@serialize.register -def _(obj: Transaction) -> Mapping[str, Any]: - category = None - if obj.category: - category = { - "name": obj.category.name, - "selector": str(obj.category.selector) - if obj.category.selector - else None, - } - - return dict( - id=obj.id, - date=obj.date.isoformat(), - description=obj.description, - amount=str(obj.amount), - split=obj.split, - category=category if category else None, - tags=[{"tag": tag.tag} for tag in obj.tags], - note=obj.note, - type=obj.type, - ) - - -@serialize.register -def _(_: TransactionCategory) -> Mapping[str, Any]: - raise NotSerializableError("TransactionCategory") - - -@serialize.register -def _(_: TransactionTag) -> Mapping[str, Any]: - raise NotSerializableError("TransactionTag") - - -@serialize.register -def _(_: Note) -> Mapping[str, Any]: - raise NotSerializableError("Note") diff --git a/tests/mocks/categories.py b/tests/mocks/categories.py index 5a40cb7..9f5540d 100644 --- a/tests/mocks/categories.py +++ b/tests/mocks/categories.py @@ -1,11 +1,20 @@ from decimal import Decimal -from pfbudget.db.model import Category, CategoryRule, Tag, TagRule +from pfbudget.db.model import Category, CategoryGroup, CategoryRule, Tag, TagRule category_null = Category("null") +categorygroup1 = CategoryGroup("group#1") + category1 = Category( "cat#1", + "group#1", + rules=[CategoryRule("cat#1", description="desc#1", max=Decimal(0))], +) + +category2 = Category( + "cat#2", + "group#1", rules=[CategoryRule("cat#1", description="desc#1", max=Decimal(0))], ) diff --git a/tests/mocks/client.py b/tests/mocks/client.py new file mode 100644 index 0000000..f8118db --- /dev/null +++ b/tests/mocks/client.py @@ -0,0 +1,11 @@ +from pfbudget.db.client import Client +from pfbudget.db.model import Base + + +class MockClient(Client): + def __init__(self): + url = "sqlite://" + super().__init__( + url, execution_options={"schema_translate_map": {"pfbudget": None}} + ) + Base.metadata.create_all(self.engine) diff --git a/tests/mocks/transactions.py b/tests/mocks/transactions.py index a540fbf..6bdc311 100644 --- a/tests/mocks/transactions.py +++ b/tests/mocks/transactions.py @@ -2,7 +2,10 @@ from datetime import date from decimal import Decimal from pfbudget.db.model import ( + BankTransaction, CategorySelector, + MoneyTransaction, + SplitTransaction, Transaction, TransactionCategory, ) @@ -26,3 +29,22 @@ simple_transformed = [ category=TransactionCategory("category#2", CategorySelector.algorithm), ), ] + +bank = [ + BankTransaction(date(2023, 1, 1), "", Decimal("-10"), bank="bank#1"), + BankTransaction(date(2023, 1, 1), "", Decimal("-10"), bank="bank#2"), +] + +money = [ + MoneyTransaction(date(2023, 1, 1), "", Decimal("-10")), + MoneyTransaction(date(2023, 1, 1), "", Decimal("-10")), +] + +__original = Transaction(date(2023, 1, 1), "", Decimal("-10"), split=True) +__original.id = 1 + +split = [ + __original, + SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id), + SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id), +] diff --git a/tests/test_backup.py b/tests/test_backup.py new file mode 100644 index 0000000..037d772 --- /dev/null +++ b/tests/test_backup.py @@ -0,0 +1,60 @@ +from pathlib import Path +from typing import Any, Sequence, Type +import pytest + +from mocks import banks, categories, transactions +from mocks.client import MockClient + +from pfbudget.common.types import ExportFormat +from pfbudget.core.command import ExportCommand, ImportCommand +from pfbudget.db.client import Client +from pfbudget.db.model import Bank, Category, CategoryGroup, Tag, Transaction + + +@pytest.fixture +def client() -> Client: + return MockClient() + + +params = [ + (transactions.simple, Transaction), + (transactions.simple_transformed, Transaction), + (transactions.bank, Transaction), + (transactions.money, Transaction), + # (transactions.split, Transaction), NotImplemented + ([banks.checking, banks.cc], Bank), + ([categories.category_null, categories.category1, categories.category2], Category), + ( + [ + categories.categorygroup1, + categories.category_null, + categories.category1, + categories.category2, + ], + CategoryGroup, + ), + ([categories.tag_1], Tag), +] + + +class TestBackup: + @pytest.mark.parametrize("input, what", params) + def test_import(self, tmp_path: Path, input: Sequence[Any], what: Type[Any]): + file = tmp_path / "test.json" + + client = MockClient() + client.insert(input) + originals = client.select(what) + + assert originals + + command = ExportCommand(client, what, file, ExportFormat.JSON) + command.execute() + + other = MockClient() + command = ImportCommand(other, what, file, ExportFormat.JSON) + command.execute() + + imported = other.select(what) + + assert originals == imported diff --git a/tests/test_command.py b/tests/test_command.py index f55318c..cdc6353 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -2,15 +2,15 @@ from collections.abc import Sequence import json from pathlib import Path import pickle -from typing import Any +import pytest +from typing import Any, cast import mocks.transactions from pfbudget.common.types import ExportFormat -from pfbudget.core.command import ExportCommand +from pfbudget.core.command import ExportCommand, ImportCommand from pfbudget.db.client import Client from pfbudget.db.model import Transaction -from pfbudget.utils.serializer import serialize class FakeClient(Client): @@ -40,8 +40,13 @@ class FakeClient(Client): self._transactions = value +@pytest.fixture +def client() -> Client: + return FakeClient() + + class TestCommand: - def test_export_json(self, tmp_path: Path): + def test_export_json(self, tmp_path: Path, client: Client): client = FakeClient() file = tmp_path / "test.json" command = ExportCommand(client, Transaction, file, ExportFormat.JSON) @@ -49,19 +54,18 @@ class TestCommand: with open(file, newline="") as f: result = json.load(f) - assert result == [serialize(t) for t in mocks.transactions.simple] + assert result == [t.serialize() for t in mocks.transactions.simple] - client.transactions = mocks.transactions.simple_transformed + cast(FakeClient, client).transactions = mocks.transactions.simple_transformed command.execute() with open(file, newline="") as f: result = json.load(f) assert result == [ - serialize(t) for t in mocks.transactions.simple_transformed + t.serialize() for t in mocks.transactions.simple_transformed ] - def test_export_pickle(self, tmp_path: Path): - client = FakeClient() + def test_export_pickle(self, tmp_path: Path, client: Client): file = tmp_path / "test.pickle" command = ExportCommand(client, Transaction, file, ExportFormat.pickle) command.execute() @@ -70,9 +74,22 @@ class TestCommand: result = pickle.load(f) assert result == mocks.transactions.simple - client.transactions = mocks.transactions.simple_transformed + cast(FakeClient, client).transactions = mocks.transactions.simple_transformed command.execute() with open(file, "rb") as f: result = pickle.load(f) assert result == mocks.transactions.simple_transformed + + def test_import(self, tmp_path: Path, client: Client): + file = tmp_path / "test" + for format in list(ExportFormat): + command = ExportCommand(client, Transaction, file, format) + command.execute() + + command = ImportCommand(client, Transaction, file, format) + command.execute() + + transactions = cast(FakeClient, client).transactions + assert len(transactions) > 0 + assert transactions == client.select(Transaction) diff --git a/tests/test_database.py b/tests/test_database.py index d1187a0..848ed74 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -2,11 +2,12 @@ from datetime import date from decimal import Decimal import pytest +from mocks.client import MockClient + from pfbudget.db.client import Client from pfbudget.db.model import ( AccountType, Bank, - Base, Nordigen, CategorySelector, Transaction, @@ -16,10 +17,7 @@ from pfbudget.db.model import ( @pytest.fixture def client() -> Client: - url = "sqlite://" - client = Client(url, execution_options={"schema_translate_map": {"pfbudget": None}}) - Base.metadata.create_all(client.engine) - return client + return MockClient() @pytest.fixture