ImportCommand and Serializable types

The new command ImportCommand takes a Serializable type, from which it
can call the deserialize method to generate a DB ORM type. The
Serializable interface also declares the serialize method.

(De)serialization moved to the ORM types, due to the inability to
properly use overloading.
Possible improvement for the future is to merge serialization
information on JSONDecoder/Encoder classes.

Adds a MockClient with the in-memory SQLite DB which can be used by
tests.
Most types export/import functionally tested using two DBs and comparing
entries.
This commit is contained in:
Luís Murta 2023-05-11 00:26:18 +01:00
parent 21099909c9
commit 638b833c74
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
9 changed files with 345 additions and 177 deletions

View File

@ -2,11 +2,11 @@ from abc import ABC, abstractmethod
import json
from pathlib import Path
import pickle
from typing import Any, Type
from typing import Type
from pfbudget.common.types import ExportFormat
from pfbudget.db.client import Client
from pfbudget.utils.serializer import serialize
from pfbudget.db.model import Serializable
class Command(ABC):
@ -19,7 +19,9 @@ class Command(ABC):
class ExportCommand(Command):
def __init__(self, client: Client, what: Type[Any], fn: Path, format: ExportFormat):
def __init__(
self, client: Client, what: Type[Serializable], fn: Path, format: ExportFormat
):
self.__client = client
self.what = what
self.fn = fn
@ -30,7 +32,29 @@ class ExportCommand(Command):
match self.format:
case ExportFormat.JSON:
with open(self.fn, "w", newline="") as f:
json.dump([serialize(e) for e in values], f, indent=4)
json.dump([e.serialize() for e in values], f, indent=4)
case ExportFormat.pickle:
with open(self.fn, "wb") as f:
pickle.dump(values, f)
class ImportCommand(Command):
def __init__(
self, client: Client, what: Type[Serializable], fn: Path, format: ExportFormat
):
self.__client = client
self.what = what
self.fn = fn
self.format = format
def execute(self) -> None:
match self.format:
case ExportFormat.JSON:
with open(self.fn, "r") as f:
values = json.load(f)
values = [self.what.deserialize(v) for v in values]
case ExportFormat.pickle:
with open(self.fn, "rb") as f:
values = pickle.load(f)
self.__client.insert(values)

View File

@ -1,9 +1,11 @@
from __future__ import annotations
from collections.abc import Mapping, MutableMapping, Sequence
from dataclasses import dataclass, fields
import datetime as dt
import decimal
import enum
import re
from typing import Annotated, Any, Callable, Optional
from typing import Annotated, Any, Callable, Optional, Self, cast
from sqlalchemy import (
BigInteger,
@ -41,6 +43,16 @@ class Base(MappedAsDataclass, DeclarativeBase):
}
@dataclass
class Serializable:
def serialize(self) -> Mapping[str, Any]:
return {field.name: getattr(self, field.name) for field in fields(self)}
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls(**map)
class AccountType(enum.Enum):
checking = enum.auto()
savings = enum.auto()
@ -50,13 +62,7 @@ class AccountType(enum.Enum):
MASTERCARD = enum.auto()
class Export:
@property
def format(self) -> dict[str, Any]:
raise NotImplementedError
class Bank(Base, Export):
class Bank(Base, Serializable):
__tablename__ = "banks"
name: Mapped[str] = mapped_column(primary_key=True)
@ -65,15 +71,29 @@ class Bank(Base, Export):
nordigen: Mapped[Optional[Nordigen]] = relationship(default=None, lazy="joined")
@property
def format(self) -> dict[str, Any]:
def serialize(self) -> Mapping[str, Any]:
nordigen = None
if self.nordigen:
nordigen = {
"bank_id": self.nordigen.bank_id,
"requisition_id": self.nordigen.requisition_id,
"invert": self.nordigen.invert,
}
return dict(
name=self.name,
BIC=self.BIC,
type=self.type,
nordigen=self.nordigen.format if self.nordigen else None,
type=self.type.name,
nordigen=nordigen,
)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
bank = cls(map["name"], map["BIC"], map["type"])
if map["nordigen"]:
bank.nordigen = Nordigen(**map["nordigen"])
return bank
bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))]
@ -88,7 +108,7 @@ idpk = Annotated[
money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))]
class Transaction(Base, Export):
class Transaction(Base, Serializable):
__tablename__ = "transactions"
id: Mapped[idpk] = mapped_column(init=False)
@ -109,18 +129,59 @@ class Transaction(Base, Export):
type: Mapped[str] = mapped_column(init=False)
__mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "transaction"}
@property
def format(self) -> dict[str, Any]:
def serialize(self) -> Mapping[str, Any]:
category = None
if self.category:
category = {
"name": self.category.name,
"selector": self.category.selector.name,
}
return dict(
id=self.id,
date=self.date,
date=self.date.isoformat(),
description=self.description,
amount=self.amount,
amount=str(self.amount),
split=self.split,
category=category if category else None,
tags=[{"tag": tag.tag} for tag in self.tags],
note=self.note,
type=self.type,
category=self.category.format if self.category else None,
# TODO note
tags=[tag.format for tag in self.tags] if self.tags else None,
)
@classmethod
def deserialize(
cls, map: Mapping[str, Any]
) -> Transaction | BankTransaction | MoneyTransaction | SplitTransaction:
match map["type"]:
case "bank":
return BankTransaction.deserialize(map)
case "money":
return MoneyTransaction.deserialize(map)
case "split":
raise NotImplementedError
case _:
return cls._deserialize(map)
@classmethod
def _deserialize(cls, map: Mapping[str, Any]) -> Self:
category = None
if map["category"]:
category = TransactionCategory(map["category"]["name"])
if map["category"]["selector"]:
category.selector = map["category"]["selector"]
tags: set[TransactionTag] = set()
if map["tags"]:
tags = set(t["tag"] for t in map["tags"])
return cls(
dt.date.fromisoformat(map["date"]),
map["description"],
map["amount"],
map["split"],
category,
tags,
map["note"],
)
def __lt__(self, other: Transaction):
@ -132,41 +193,47 @@ idfk = Annotated[
]
class BankTransaction(Transaction):
class BankTransaction(Transaction, Serializable):
bank: Mapped[Optional[bankfk]] = mapped_column(default=None)
__mapper_args__ = {"polymorphic_identity": "bank", "polymorphic_load": "inline"}
@property
def format(self) -> dict[str, Any]:
return super().format | dict(bank=self.bank)
def serialize(self) -> Mapping[str, Any]:
map = cast(MutableMapping[str, Any], super().serialize())
map["bank"] = self.bank
return map
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
transaction = cls._deserialize(map)
transaction.bank = map["bank"]
return transaction
class MoneyTransaction(Transaction):
class MoneyTransaction(Transaction, Serializable):
__mapper_args__ = {"polymorphic_identity": "money"}
def serialize(self) -> Mapping[str, Any]:
return super().serialize()
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls._deserialize(map)
class SplitTransaction(Transaction):
original: Mapped[Optional[idfk]] = mapped_column(default=None)
__mapper_args__ = {"polymorphic_identity": "split", "polymorphic_load": "inline"}
@property
def format(self) -> dict[str, Any]:
return super().format | dict(original=self.original)
class CategoryGroup(Base, Export):
class CategoryGroup(Base, Serializable):
__tablename__ = "category_groups"
name: Mapped[str] = mapped_column(primary_key=True)
@property
def format(self) -> dict[str, Any]:
return dict(name=self.name)
class Category(Base, Export):
class Category(Base, Serializable, repr=False):
__tablename__ = "categories"
name: Mapped[str] = mapped_column(primary_key=True)
@ -175,10 +242,54 @@ class Category(Base, Export):
)
rules: Mapped[list[CategoryRule]] = relationship(
cascade="all, delete-orphan", passive_deletes=True, default_factory=list
cascade="all, delete-orphan",
passive_deletes=True,
default_factory=list,
lazy="joined",
)
schedule: Mapped[Optional[CategorySchedule]] = relationship(
cascade="all, delete-orphan", passive_deletes=True, default=None
cascade="all, delete-orphan", passive_deletes=True, default=None, lazy="joined"
)
def serialize(self) -> Mapping[str, Any]:
rules: Sequence[Mapping[str, Any]] = []
for rule in self.rules:
rules.append(
{
"name": rule.name,
"start": rule.start,
"end": rule.end,
"description": rule.description,
"regex": rule.regex,
"bank": rule.bank,
"min": str(rule.min) if rule.min is not None else None,
"max": str(rule.max) if rule.max is not None else None,
}
)
schedule = None
if self.schedule:
schedule = {
"name": self.schedule.name,
"period": self.schedule.period.name if self.schedule.period else None,
"period_multiplier": self.schedule.period_multiplier,
"amount": self.schedule.amount,
}
return dict(
name=self.name,
group=self.group,
rules=rules,
schedule=schedule,
)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls(
map["name"],
map["group"],
list(CategoryRule(**r) for r in map["rules"]),
CategorySchedule(**map["schedule"]) if map["schedule"] else None,
)
def __repr__(self) -> str:
@ -187,15 +298,6 @@ class Category(Base, Export):
f" schedule={self.schedule})"
)
@property
def format(self) -> dict[str, Any]:
return dict(
name=self.name,
group=self.group if self.group else None,
rules=[rule.format for rule in self.rules],
schedule=self.schedule.format if self.schedule else None,
)
catfk = Annotated[
str,
@ -212,7 +314,7 @@ class CategorySelector(enum.Enum):
manual = enum.auto()
class TransactionCategory(Base, Export):
class TransactionCategory(Base):
__tablename__ = "transactions_categorized"
id: Mapped[idfk] = mapped_column(primary_key=True, init=False)
@ -224,12 +326,6 @@ class TransactionCategory(Base, Export):
back_populates="category", init=False, compare=False
)
@property
def format(self):
return dict(
name=self.name, selector=self.selector.name
)
class Note(Base):
__tablename__ = "notes"
@ -238,7 +334,7 @@ class Note(Base):
note: Mapped[str]
class Nordigen(Base, Export):
class Nordigen(Base):
__tablename__ = "banks_nordigen"
name: Mapped[bankfk] = mapped_column(primary_key=True, init=False)
@ -246,36 +342,51 @@ class Nordigen(Base, Export):
requisition_id: Mapped[Optional[str]]
invert: Mapped[Optional[bool]] = mapped_column(default=None)
@property
def format(self) -> dict[str, Any]:
return dict(
name=self.name,
bank_id=self.bank_id,
requisition_id=self.requisition_id,
invert=self.invert,
)
class Tag(Base):
class Tag(Base, Serializable):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(primary_key=True)
rules: Mapped[list[TagRule]] = relationship(
cascade="all, delete-orphan", passive_deletes=True, default_factory=list
cascade="all, delete-orphan",
passive_deletes=True,
default_factory=list,
lazy="joined",
)
def serialize(self) -> Mapping[str, Any]:
rules: Sequence[Mapping[str, Any]] = []
for rule in self.rules:
rules.append(
{
"name": rule.tag,
"start": rule.start,
"end": rule.end,
"description": rule.description,
"regex": rule.regex,
"bank": rule.bank,
"min": str(rule.min) if rule.min is not None else None,
"max": str(rule.max) if rule.max is not None else None,
}
)
return dict(name=self.name, rules=rules)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls(
map["name"],
list(TagRule(**r) for r in map["rules"]),
)
class TransactionTag(Base, Export, unsafe_hash=True):
class TransactionTag(Base, unsafe_hash=True):
__tablename__ = "transactions_tagged"
id: Mapped[idfk] = mapped_column(primary_key=True, init=False)
tag: Mapped[str] = mapped_column(ForeignKey(Tag.name), primary_key=True)
@property
def format(self):
return dict(tag=self.tag)
class SchedulePeriod(enum.Enum):
daily = enum.auto()
@ -284,7 +395,7 @@ class SchedulePeriod(enum.Enum):
yearly = enum.auto()
class CategorySchedule(Base, Export):
class CategorySchedule(Base):
__tablename__ = "category_schedules"
name: Mapped[catfk] = mapped_column(primary_key=True)
@ -292,15 +403,6 @@ class CategorySchedule(Base, Export):
period_multiplier: Mapped[Optional[int]]
amount: Mapped[Optional[int]]
@property
def format(self) -> dict[str, Any]:
return dict(
name=self.name,
period=self.period,
period_multiplier=self.period_multiplier,
amount=self.amount,
)
class Link(Base):
__tablename__ = "links"
@ -309,7 +411,7 @@ class Link(Base):
link: Mapped[idfk] = mapped_column(primary_key=True)
class Rule(Base, Export, init=False):
class Rule(Base, init=False):
__tablename__ = "rules"
id: Mapped[idpk] = mapped_column(init=False)
@ -355,19 +457,6 @@ class Rule(Base, Export, init=False):
return False
@property
def format(self) -> dict[str, Any]:
return dict(
start=self.start,
end=self.end,
description=self.description,
regex=self.regex,
bank=self.bank,
min=self.min,
max=self.max,
type=self.type,
)
@staticmethod
def exists(r: Optional[Any], op: Callable[[Any], bool]) -> bool:
return op(r) if r is not None else True
@ -388,10 +477,6 @@ class CategoryRule(Rule):
"polymorphic_identity": "category_rule",
}
@property
def format(self) -> dict[str, Any]:
return super().format | dict(name=self.name)
def __init__(self, name: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.name = name
@ -412,10 +497,6 @@ class TagRule(Rule):
"polymorphic_identity": "tag_rule",
}
@property
def format(self) -> dict[str, Any]:
return super().format | dict(tag=self.tag)
def __init__(self, name: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.tag = name

View File

@ -1,54 +0,0 @@
from collections.abc import Mapping
from dataclasses import fields
from functools import singledispatch
from typing import Any
from pfbudget.db.model import Transaction, TransactionCategory, TransactionTag, Note
class NotSerializableError(Exception):
pass
@singledispatch
def serialize(obj: Any) -> Mapping[str, Any]:
return {field.name: getattr(obj, field.name) for field in fields(obj)}
@serialize.register
def _(obj: Transaction) -> Mapping[str, Any]:
category = None
if obj.category:
category = {
"name": obj.category.name,
"selector": str(obj.category.selector)
if obj.category.selector
else None,
}
return dict(
id=obj.id,
date=obj.date.isoformat(),
description=obj.description,
amount=str(obj.amount),
split=obj.split,
category=category if category else None,
tags=[{"tag": tag.tag} for tag in obj.tags],
note=obj.note,
type=obj.type,
)
@serialize.register
def _(_: TransactionCategory) -> Mapping[str, Any]:
raise NotSerializableError("TransactionCategory")
@serialize.register
def _(_: TransactionTag) -> Mapping[str, Any]:
raise NotSerializableError("TransactionTag")
@serialize.register
def _(_: Note) -> Mapping[str, Any]:
raise NotSerializableError("Note")

View File

@ -1,11 +1,20 @@
from decimal import Decimal
from pfbudget.db.model import Category, CategoryRule, Tag, TagRule
from pfbudget.db.model import Category, CategoryGroup, CategoryRule, Tag, TagRule
category_null = Category("null")
categorygroup1 = CategoryGroup("group#1")
category1 = Category(
"cat#1",
"group#1",
rules=[CategoryRule("cat#1", description="desc#1", max=Decimal(0))],
)
category2 = Category(
"cat#2",
"group#1",
rules=[CategoryRule("cat#1", description="desc#1", max=Decimal(0))],
)

11
tests/mocks/client.py Normal file
View File

@ -0,0 +1,11 @@
from pfbudget.db.client import Client
from pfbudget.db.model import Base
class MockClient(Client):
def __init__(self):
url = "sqlite://"
super().__init__(
url, execution_options={"schema_translate_map": {"pfbudget": None}}
)
Base.metadata.create_all(self.engine)

View File

@ -2,7 +2,10 @@ from datetime import date
from decimal import Decimal
from pfbudget.db.model import (
BankTransaction,
CategorySelector,
MoneyTransaction,
SplitTransaction,
Transaction,
TransactionCategory,
)
@ -26,3 +29,22 @@ simple_transformed = [
category=TransactionCategory("category#2", CategorySelector.algorithm),
),
]
bank = [
BankTransaction(date(2023, 1, 1), "", Decimal("-10"), bank="bank#1"),
BankTransaction(date(2023, 1, 1), "", Decimal("-10"), bank="bank#2"),
]
money = [
MoneyTransaction(date(2023, 1, 1), "", Decimal("-10")),
MoneyTransaction(date(2023, 1, 1), "", Decimal("-10")),
]
__original = Transaction(date(2023, 1, 1), "", Decimal("-10"), split=True)
__original.id = 1
split = [
__original,
SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id),
SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id),
]

60
tests/test_backup.py Normal file
View File

@ -0,0 +1,60 @@
from pathlib import Path
from typing import Any, Sequence, Type
import pytest
from mocks import banks, categories, transactions
from mocks.client import MockClient
from pfbudget.common.types import ExportFormat
from pfbudget.core.command import ExportCommand, ImportCommand
from pfbudget.db.client import Client
from pfbudget.db.model import Bank, Category, CategoryGroup, Tag, Transaction
@pytest.fixture
def client() -> Client:
return MockClient()
params = [
(transactions.simple, Transaction),
(transactions.simple_transformed, Transaction),
(transactions.bank, Transaction),
(transactions.money, Transaction),
# (transactions.split, Transaction), NotImplemented
([banks.checking, banks.cc], Bank),
([categories.category_null, categories.category1, categories.category2], Category),
(
[
categories.categorygroup1,
categories.category_null,
categories.category1,
categories.category2,
],
CategoryGroup,
),
([categories.tag_1], Tag),
]
class TestBackup:
@pytest.mark.parametrize("input, what", params)
def test_import(self, tmp_path: Path, input: Sequence[Any], what: Type[Any]):
file = tmp_path / "test.json"
client = MockClient()
client.insert(input)
originals = client.select(what)
assert originals
command = ExportCommand(client, what, file, ExportFormat.JSON)
command.execute()
other = MockClient()
command = ImportCommand(other, what, file, ExportFormat.JSON)
command.execute()
imported = other.select(what)
assert originals == imported

View File

@ -2,15 +2,15 @@ from collections.abc import Sequence
import json
from pathlib import Path
import pickle
from typing import Any
import pytest
from typing import Any, cast
import mocks.transactions
from pfbudget.common.types import ExportFormat
from pfbudget.core.command import ExportCommand
from pfbudget.core.command import ExportCommand, ImportCommand
from pfbudget.db.client import Client
from pfbudget.db.model import Transaction
from pfbudget.utils.serializer import serialize
class FakeClient(Client):
@ -40,8 +40,13 @@ class FakeClient(Client):
self._transactions = value
@pytest.fixture
def client() -> Client:
return FakeClient()
class TestCommand:
def test_export_json(self, tmp_path: Path):
def test_export_json(self, tmp_path: Path, client: Client):
client = FakeClient()
file = tmp_path / "test.json"
command = ExportCommand(client, Transaction, file, ExportFormat.JSON)
@ -49,19 +54,18 @@ class TestCommand:
with open(file, newline="") as f:
result = json.load(f)
assert result == [serialize(t) for t in mocks.transactions.simple]
assert result == [t.serialize() for t in mocks.transactions.simple]
client.transactions = mocks.transactions.simple_transformed
cast(FakeClient, client).transactions = mocks.transactions.simple_transformed
command.execute()
with open(file, newline="") as f:
result = json.load(f)
assert result == [
serialize(t) for t in mocks.transactions.simple_transformed
t.serialize() for t in mocks.transactions.simple_transformed
]
def test_export_pickle(self, tmp_path: Path):
client = FakeClient()
def test_export_pickle(self, tmp_path: Path, client: Client):
file = tmp_path / "test.pickle"
command = ExportCommand(client, Transaction, file, ExportFormat.pickle)
command.execute()
@ -70,9 +74,22 @@ class TestCommand:
result = pickle.load(f)
assert result == mocks.transactions.simple
client.transactions = mocks.transactions.simple_transformed
cast(FakeClient, client).transactions = mocks.transactions.simple_transformed
command.execute()
with open(file, "rb") as f:
result = pickle.load(f)
assert result == mocks.transactions.simple_transformed
def test_import(self, tmp_path: Path, client: Client):
file = tmp_path / "test"
for format in list(ExportFormat):
command = ExportCommand(client, Transaction, file, format)
command.execute()
command = ImportCommand(client, Transaction, file, format)
command.execute()
transactions = cast(FakeClient, client).transactions
assert len(transactions) > 0
assert transactions == client.select(Transaction)

View File

@ -2,11 +2,12 @@ from datetime import date
from decimal import Decimal
import pytest
from mocks.client import MockClient
from pfbudget.db.client import Client
from pfbudget.db.model import (
AccountType,
Bank,
Base,
Nordigen,
CategorySelector,
Transaction,
@ -16,10 +17,7 @@ from pfbudget.db.model import (
@pytest.fixture
def client() -> Client:
url = "sqlite://"
client = Client(url, execution_options={"schema_translate_map": {"pfbudget": None}})
Base.metadata.create_all(client.engine)
return client
return MockClient()
@pytest.fixture