ImportCommand and Serializable types

The new command ImportCommand takes a Serializable type, from which it
can call the deserialize method to generate a DB ORM type. The
Serializable interface also declares the serialize method.

(De)serialization moved to the ORM types, due to the inability to
properly use overloading.
Possible improvement for the future is to merge serialization
information on JSONDecoder/Encoder classes.

Adds a MockClient with the in-memory SQLite DB which can be used by
tests.
Most types export/import functionally tested using two DBs and comparing
entries.
This commit is contained in:
Luís Murta 2023-05-11 00:26:18 +01:00
parent 21099909c9
commit 638b833c74
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
9 changed files with 345 additions and 177 deletions

View File

@ -2,11 +2,11 @@ from abc import ABC, abstractmethod
import json import json
from pathlib import Path from pathlib import Path
import pickle import pickle
from typing import Any, Type from typing import Type
from pfbudget.common.types import ExportFormat from pfbudget.common.types import ExportFormat
from pfbudget.db.client import Client from pfbudget.db.client import Client
from pfbudget.utils.serializer import serialize from pfbudget.db.model import Serializable
class Command(ABC): class Command(ABC):
@ -19,7 +19,9 @@ class Command(ABC):
class ExportCommand(Command): class ExportCommand(Command):
def __init__(self, client: Client, what: Type[Any], fn: Path, format: ExportFormat): def __init__(
self, client: Client, what: Type[Serializable], fn: Path, format: ExportFormat
):
self.__client = client self.__client = client
self.what = what self.what = what
self.fn = fn self.fn = fn
@ -30,7 +32,29 @@ class ExportCommand(Command):
match self.format: match self.format:
case ExportFormat.JSON: case ExportFormat.JSON:
with open(self.fn, "w", newline="") as f: with open(self.fn, "w", newline="") as f:
json.dump([serialize(e) for e in values], f, indent=4) json.dump([e.serialize() for e in values], f, indent=4)
case ExportFormat.pickle: case ExportFormat.pickle:
with open(self.fn, "wb") as f: with open(self.fn, "wb") as f:
pickle.dump(values, f) pickle.dump(values, f)
class ImportCommand(Command):
def __init__(
self, client: Client, what: Type[Serializable], fn: Path, format: ExportFormat
):
self.__client = client
self.what = what
self.fn = fn
self.format = format
def execute(self) -> None:
match self.format:
case ExportFormat.JSON:
with open(self.fn, "r") as f:
values = json.load(f)
values = [self.what.deserialize(v) for v in values]
case ExportFormat.pickle:
with open(self.fn, "rb") as f:
values = pickle.load(f)
self.__client.insert(values)

View File

@ -1,9 +1,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping, MutableMapping, Sequence
from dataclasses import dataclass, fields
import datetime as dt import datetime as dt
import decimal import decimal
import enum import enum
import re import re
from typing import Annotated, Any, Callable, Optional from typing import Annotated, Any, Callable, Optional, Self, cast
from sqlalchemy import ( from sqlalchemy import (
BigInteger, BigInteger,
@ -41,6 +43,16 @@ class Base(MappedAsDataclass, DeclarativeBase):
} }
@dataclass
class Serializable:
def serialize(self) -> Mapping[str, Any]:
return {field.name: getattr(self, field.name) for field in fields(self)}
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls(**map)
class AccountType(enum.Enum): class AccountType(enum.Enum):
checking = enum.auto() checking = enum.auto()
savings = enum.auto() savings = enum.auto()
@ -50,13 +62,7 @@ class AccountType(enum.Enum):
MASTERCARD = enum.auto() MASTERCARD = enum.auto()
class Export: class Bank(Base, Serializable):
@property
def format(self) -> dict[str, Any]:
raise NotImplementedError
class Bank(Base, Export):
__tablename__ = "banks" __tablename__ = "banks"
name: Mapped[str] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(primary_key=True)
@ -65,15 +71,29 @@ class Bank(Base, Export):
nordigen: Mapped[Optional[Nordigen]] = relationship(default=None, lazy="joined") nordigen: Mapped[Optional[Nordigen]] = relationship(default=None, lazy="joined")
@property def serialize(self) -> Mapping[str, Any]:
def format(self) -> dict[str, Any]: nordigen = None
if self.nordigen:
nordigen = {
"bank_id": self.nordigen.bank_id,
"requisition_id": self.nordigen.requisition_id,
"invert": self.nordigen.invert,
}
return dict( return dict(
name=self.name, name=self.name,
BIC=self.BIC, BIC=self.BIC,
type=self.type, type=self.type.name,
nordigen=self.nordigen.format if self.nordigen else None, nordigen=nordigen,
) )
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
bank = cls(map["name"], map["BIC"], map["type"])
if map["nordigen"]:
bank.nordigen = Nordigen(**map["nordigen"])
return bank
bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))] bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))]
@ -88,7 +108,7 @@ idpk = Annotated[
money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))] money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))]
class Transaction(Base, Export): class Transaction(Base, Serializable):
__tablename__ = "transactions" __tablename__ = "transactions"
id: Mapped[idpk] = mapped_column(init=False) id: Mapped[idpk] = mapped_column(init=False)
@ -109,18 +129,59 @@ class Transaction(Base, Export):
type: Mapped[str] = mapped_column(init=False) type: Mapped[str] = mapped_column(init=False)
__mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "transaction"} __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "transaction"}
@property def serialize(self) -> Mapping[str, Any]:
def format(self) -> dict[str, Any]: category = None
if self.category:
category = {
"name": self.category.name,
"selector": self.category.selector.name,
}
return dict( return dict(
id=self.id, date=self.date.isoformat(),
date=self.date,
description=self.description, description=self.description,
amount=self.amount, amount=str(self.amount),
split=self.split, split=self.split,
category=category if category else None,
tags=[{"tag": tag.tag} for tag in self.tags],
note=self.note,
type=self.type, type=self.type,
category=self.category.format if self.category else None, )
# TODO note
tags=[tag.format for tag in self.tags] if self.tags else None, @classmethod
def deserialize(
cls, map: Mapping[str, Any]
) -> Transaction | BankTransaction | MoneyTransaction | SplitTransaction:
match map["type"]:
case "bank":
return BankTransaction.deserialize(map)
case "money":
return MoneyTransaction.deserialize(map)
case "split":
raise NotImplementedError
case _:
return cls._deserialize(map)
@classmethod
def _deserialize(cls, map: Mapping[str, Any]) -> Self:
category = None
if map["category"]:
category = TransactionCategory(map["category"]["name"])
if map["category"]["selector"]:
category.selector = map["category"]["selector"]
tags: set[TransactionTag] = set()
if map["tags"]:
tags = set(t["tag"] for t in map["tags"])
return cls(
dt.date.fromisoformat(map["date"]),
map["description"],
map["amount"],
map["split"],
category,
tags,
map["note"],
) )
def __lt__(self, other: Transaction): def __lt__(self, other: Transaction):
@ -132,41 +193,47 @@ idfk = Annotated[
] ]
class BankTransaction(Transaction): class BankTransaction(Transaction, Serializable):
bank: Mapped[Optional[bankfk]] = mapped_column(default=None) bank: Mapped[Optional[bankfk]] = mapped_column(default=None)
__mapper_args__ = {"polymorphic_identity": "bank", "polymorphic_load": "inline"} __mapper_args__ = {"polymorphic_identity": "bank", "polymorphic_load": "inline"}
@property def serialize(self) -> Mapping[str, Any]:
def format(self) -> dict[str, Any]: map = cast(MutableMapping[str, Any], super().serialize())
return super().format | dict(bank=self.bank) map["bank"] = self.bank
return map
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
transaction = cls._deserialize(map)
transaction.bank = map["bank"]
return transaction
class MoneyTransaction(Transaction): class MoneyTransaction(Transaction, Serializable):
__mapper_args__ = {"polymorphic_identity": "money"} __mapper_args__ = {"polymorphic_identity": "money"}
def serialize(self) -> Mapping[str, Any]:
return super().serialize()
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls._deserialize(map)
class SplitTransaction(Transaction): class SplitTransaction(Transaction):
original: Mapped[Optional[idfk]] = mapped_column(default=None) original: Mapped[Optional[idfk]] = mapped_column(default=None)
__mapper_args__ = {"polymorphic_identity": "split", "polymorphic_load": "inline"} __mapper_args__ = {"polymorphic_identity": "split", "polymorphic_load": "inline"}
@property
def format(self) -> dict[str, Any]:
return super().format | dict(original=self.original)
class CategoryGroup(Base, Serializable):
class CategoryGroup(Base, Export):
__tablename__ = "category_groups" __tablename__ = "category_groups"
name: Mapped[str] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(primary_key=True)
@property
def format(self) -> dict[str, Any]:
return dict(name=self.name)
class Category(Base, Serializable, repr=False):
class Category(Base, Export):
__tablename__ = "categories" __tablename__ = "categories"
name: Mapped[str] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(primary_key=True)
@ -175,10 +242,54 @@ class Category(Base, Export):
) )
rules: Mapped[list[CategoryRule]] = relationship( rules: Mapped[list[CategoryRule]] = relationship(
cascade="all, delete-orphan", passive_deletes=True, default_factory=list cascade="all, delete-orphan",
passive_deletes=True,
default_factory=list,
lazy="joined",
) )
schedule: Mapped[Optional[CategorySchedule]] = relationship( schedule: Mapped[Optional[CategorySchedule]] = relationship(
cascade="all, delete-orphan", passive_deletes=True, default=None cascade="all, delete-orphan", passive_deletes=True, default=None, lazy="joined"
)
def serialize(self) -> Mapping[str, Any]:
rules: Sequence[Mapping[str, Any]] = []
for rule in self.rules:
rules.append(
{
"name": rule.name,
"start": rule.start,
"end": rule.end,
"description": rule.description,
"regex": rule.regex,
"bank": rule.bank,
"min": str(rule.min) if rule.min is not None else None,
"max": str(rule.max) if rule.max is not None else None,
}
)
schedule = None
if self.schedule:
schedule = {
"name": self.schedule.name,
"period": self.schedule.period.name if self.schedule.period else None,
"period_multiplier": self.schedule.period_multiplier,
"amount": self.schedule.amount,
}
return dict(
name=self.name,
group=self.group,
rules=rules,
schedule=schedule,
)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls(
map["name"],
map["group"],
list(CategoryRule(**r) for r in map["rules"]),
CategorySchedule(**map["schedule"]) if map["schedule"] else None,
) )
def __repr__(self) -> str: def __repr__(self) -> str:
@ -187,15 +298,6 @@ class Category(Base, Export):
f" schedule={self.schedule})" f" schedule={self.schedule})"
) )
@property
def format(self) -> dict[str, Any]:
return dict(
name=self.name,
group=self.group if self.group else None,
rules=[rule.format for rule in self.rules],
schedule=self.schedule.format if self.schedule else None,
)
catfk = Annotated[ catfk = Annotated[
str, str,
@ -212,7 +314,7 @@ class CategorySelector(enum.Enum):
manual = enum.auto() manual = enum.auto()
class TransactionCategory(Base, Export): class TransactionCategory(Base):
__tablename__ = "transactions_categorized" __tablename__ = "transactions_categorized"
id: Mapped[idfk] = mapped_column(primary_key=True, init=False) id: Mapped[idfk] = mapped_column(primary_key=True, init=False)
@ -224,12 +326,6 @@ class TransactionCategory(Base, Export):
back_populates="category", init=False, compare=False back_populates="category", init=False, compare=False
) )
@property
def format(self):
return dict(
name=self.name, selector=self.selector.name
)
class Note(Base): class Note(Base):
__tablename__ = "notes" __tablename__ = "notes"
@ -238,7 +334,7 @@ class Note(Base):
note: Mapped[str] note: Mapped[str]
class Nordigen(Base, Export): class Nordigen(Base):
__tablename__ = "banks_nordigen" __tablename__ = "banks_nordigen"
name: Mapped[bankfk] = mapped_column(primary_key=True, init=False) name: Mapped[bankfk] = mapped_column(primary_key=True, init=False)
@ -246,36 +342,51 @@ class Nordigen(Base, Export):
requisition_id: Mapped[Optional[str]] requisition_id: Mapped[Optional[str]]
invert: Mapped[Optional[bool]] = mapped_column(default=None) invert: Mapped[Optional[bool]] = mapped_column(default=None)
@property
def format(self) -> dict[str, Any]:
return dict(
name=self.name,
bank_id=self.bank_id,
requisition_id=self.requisition_id,
invert=self.invert,
)
class Tag(Base, Serializable):
class Tag(Base):
__tablename__ = "tags" __tablename__ = "tags"
name: Mapped[str] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(primary_key=True)
rules: Mapped[list[TagRule]] = relationship( rules: Mapped[list[TagRule]] = relationship(
cascade="all, delete-orphan", passive_deletes=True, default_factory=list cascade="all, delete-orphan",
passive_deletes=True,
default_factory=list,
lazy="joined",
)
def serialize(self) -> Mapping[str, Any]:
rules: Sequence[Mapping[str, Any]] = []
for rule in self.rules:
rules.append(
{
"name": rule.tag,
"start": rule.start,
"end": rule.end,
"description": rule.description,
"regex": rule.regex,
"bank": rule.bank,
"min": str(rule.min) if rule.min is not None else None,
"max": str(rule.max) if rule.max is not None else None,
}
)
return dict(name=self.name, rules=rules)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls(
map["name"],
list(TagRule(**r) for r in map["rules"]),
) )
class TransactionTag(Base, Export, unsafe_hash=True): class TransactionTag(Base, unsafe_hash=True):
__tablename__ = "transactions_tagged" __tablename__ = "transactions_tagged"
id: Mapped[idfk] = mapped_column(primary_key=True, init=False) id: Mapped[idfk] = mapped_column(primary_key=True, init=False)
tag: Mapped[str] = mapped_column(ForeignKey(Tag.name), primary_key=True) tag: Mapped[str] = mapped_column(ForeignKey(Tag.name), primary_key=True)
@property
def format(self):
return dict(tag=self.tag)
class SchedulePeriod(enum.Enum): class SchedulePeriod(enum.Enum):
daily = enum.auto() daily = enum.auto()
@ -284,7 +395,7 @@ class SchedulePeriod(enum.Enum):
yearly = enum.auto() yearly = enum.auto()
class CategorySchedule(Base, Export): class CategorySchedule(Base):
__tablename__ = "category_schedules" __tablename__ = "category_schedules"
name: Mapped[catfk] = mapped_column(primary_key=True) name: Mapped[catfk] = mapped_column(primary_key=True)
@ -292,15 +403,6 @@ class CategorySchedule(Base, Export):
period_multiplier: Mapped[Optional[int]] period_multiplier: Mapped[Optional[int]]
amount: Mapped[Optional[int]] amount: Mapped[Optional[int]]
@property
def format(self) -> dict[str, Any]:
return dict(
name=self.name,
period=self.period,
period_multiplier=self.period_multiplier,
amount=self.amount,
)
class Link(Base): class Link(Base):
__tablename__ = "links" __tablename__ = "links"
@ -309,7 +411,7 @@ class Link(Base):
link: Mapped[idfk] = mapped_column(primary_key=True) link: Mapped[idfk] = mapped_column(primary_key=True)
class Rule(Base, Export, init=False): class Rule(Base, init=False):
__tablename__ = "rules" __tablename__ = "rules"
id: Mapped[idpk] = mapped_column(init=False) id: Mapped[idpk] = mapped_column(init=False)
@ -355,19 +457,6 @@ class Rule(Base, Export, init=False):
return False return False
@property
def format(self) -> dict[str, Any]:
return dict(
start=self.start,
end=self.end,
description=self.description,
regex=self.regex,
bank=self.bank,
min=self.min,
max=self.max,
type=self.type,
)
@staticmethod @staticmethod
def exists(r: Optional[Any], op: Callable[[Any], bool]) -> bool: def exists(r: Optional[Any], op: Callable[[Any], bool]) -> bool:
return op(r) if r is not None else True return op(r) if r is not None else True
@ -388,10 +477,6 @@ class CategoryRule(Rule):
"polymorphic_identity": "category_rule", "polymorphic_identity": "category_rule",
} }
@property
def format(self) -> dict[str, Any]:
return super().format | dict(name=self.name)
def __init__(self, name: str, **kwargs: Any) -> None: def __init__(self, name: str, **kwargs: Any) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.name = name self.name = name
@ -412,10 +497,6 @@ class TagRule(Rule):
"polymorphic_identity": "tag_rule", "polymorphic_identity": "tag_rule",
} }
@property
def format(self) -> dict[str, Any]:
return super().format | dict(tag=self.tag)
def __init__(self, name: str, **kwargs: Any) -> None: def __init__(self, name: str, **kwargs: Any) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.tag = name self.tag = name

View File

@ -1,54 +0,0 @@
from collections.abc import Mapping
from dataclasses import fields
from functools import singledispatch
from typing import Any
from pfbudget.db.model import Transaction, TransactionCategory, TransactionTag, Note
class NotSerializableError(Exception):
pass
@singledispatch
def serialize(obj: Any) -> Mapping[str, Any]:
return {field.name: getattr(obj, field.name) for field in fields(obj)}
@serialize.register
def _(obj: Transaction) -> Mapping[str, Any]:
category = None
if obj.category:
category = {
"name": obj.category.name,
"selector": str(obj.category.selector)
if obj.category.selector
else None,
}
return dict(
id=obj.id,
date=obj.date.isoformat(),
description=obj.description,
amount=str(obj.amount),
split=obj.split,
category=category if category else None,
tags=[{"tag": tag.tag} for tag in obj.tags],
note=obj.note,
type=obj.type,
)
@serialize.register
def _(_: TransactionCategory) -> Mapping[str, Any]:
raise NotSerializableError("TransactionCategory")
@serialize.register
def _(_: TransactionTag) -> Mapping[str, Any]:
raise NotSerializableError("TransactionTag")
@serialize.register
def _(_: Note) -> Mapping[str, Any]:
raise NotSerializableError("Note")

View File

@ -1,11 +1,20 @@
from decimal import Decimal from decimal import Decimal
from pfbudget.db.model import Category, CategoryRule, Tag, TagRule from pfbudget.db.model import Category, CategoryGroup, CategoryRule, Tag, TagRule
category_null = Category("null") category_null = Category("null")
categorygroup1 = CategoryGroup("group#1")
category1 = Category( category1 = Category(
"cat#1", "cat#1",
"group#1",
rules=[CategoryRule("cat#1", description="desc#1", max=Decimal(0))],
)
category2 = Category(
"cat#2",
"group#1",
rules=[CategoryRule("cat#1", description="desc#1", max=Decimal(0))], rules=[CategoryRule("cat#1", description="desc#1", max=Decimal(0))],
) )

11
tests/mocks/client.py Normal file
View File

@ -0,0 +1,11 @@
from pfbudget.db.client import Client
from pfbudget.db.model import Base
class MockClient(Client):
def __init__(self):
url = "sqlite://"
super().__init__(
url, execution_options={"schema_translate_map": {"pfbudget": None}}
)
Base.metadata.create_all(self.engine)

View File

@ -2,7 +2,10 @@ from datetime import date
from decimal import Decimal from decimal import Decimal
from pfbudget.db.model import ( from pfbudget.db.model import (
BankTransaction,
CategorySelector, CategorySelector,
MoneyTransaction,
SplitTransaction,
Transaction, Transaction,
TransactionCategory, TransactionCategory,
) )
@ -26,3 +29,22 @@ simple_transformed = [
category=TransactionCategory("category#2", CategorySelector.algorithm), category=TransactionCategory("category#2", CategorySelector.algorithm),
), ),
] ]
bank = [
BankTransaction(date(2023, 1, 1), "", Decimal("-10"), bank="bank#1"),
BankTransaction(date(2023, 1, 1), "", Decimal("-10"), bank="bank#2"),
]
money = [
MoneyTransaction(date(2023, 1, 1), "", Decimal("-10")),
MoneyTransaction(date(2023, 1, 1), "", Decimal("-10")),
]
__original = Transaction(date(2023, 1, 1), "", Decimal("-10"), split=True)
__original.id = 1
split = [
__original,
SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id),
SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id),
]

60
tests/test_backup.py Normal file
View File

@ -0,0 +1,60 @@
from pathlib import Path
from typing import Any, Sequence, Type
import pytest
from mocks import banks, categories, transactions
from mocks.client import MockClient
from pfbudget.common.types import ExportFormat
from pfbudget.core.command import ExportCommand, ImportCommand
from pfbudget.db.client import Client
from pfbudget.db.model import Bank, Category, CategoryGroup, Tag, Transaction
@pytest.fixture
def client() -> Client:
return MockClient()
params = [
(transactions.simple, Transaction),
(transactions.simple_transformed, Transaction),
(transactions.bank, Transaction),
(transactions.money, Transaction),
# (transactions.split, Transaction), NotImplemented
([banks.checking, banks.cc], Bank),
([categories.category_null, categories.category1, categories.category2], Category),
(
[
categories.categorygroup1,
categories.category_null,
categories.category1,
categories.category2,
],
CategoryGroup,
),
([categories.tag_1], Tag),
]
class TestBackup:
@pytest.mark.parametrize("input, what", params)
def test_import(self, tmp_path: Path, input: Sequence[Any], what: Type[Any]):
file = tmp_path / "test.json"
client = MockClient()
client.insert(input)
originals = client.select(what)
assert originals
command = ExportCommand(client, what, file, ExportFormat.JSON)
command.execute()
other = MockClient()
command = ImportCommand(other, what, file, ExportFormat.JSON)
command.execute()
imported = other.select(what)
assert originals == imported

View File

@ -2,15 +2,15 @@ from collections.abc import Sequence
import json import json
from pathlib import Path from pathlib import Path
import pickle import pickle
from typing import Any import pytest
from typing import Any, cast
import mocks.transactions import mocks.transactions
from pfbudget.common.types import ExportFormat from pfbudget.common.types import ExportFormat
from pfbudget.core.command import ExportCommand from pfbudget.core.command import ExportCommand, ImportCommand
from pfbudget.db.client import Client from pfbudget.db.client import Client
from pfbudget.db.model import Transaction from pfbudget.db.model import Transaction
from pfbudget.utils.serializer import serialize
class FakeClient(Client): class FakeClient(Client):
@ -40,8 +40,13 @@ class FakeClient(Client):
self._transactions = value self._transactions = value
@pytest.fixture
def client() -> Client:
return FakeClient()
class TestCommand: class TestCommand:
def test_export_json(self, tmp_path: Path): def test_export_json(self, tmp_path: Path, client: Client):
client = FakeClient() client = FakeClient()
file = tmp_path / "test.json" file = tmp_path / "test.json"
command = ExportCommand(client, Transaction, file, ExportFormat.JSON) command = ExportCommand(client, Transaction, file, ExportFormat.JSON)
@ -49,19 +54,18 @@ class TestCommand:
with open(file, newline="") as f: with open(file, newline="") as f:
result = json.load(f) result = json.load(f)
assert result == [serialize(t) for t in mocks.transactions.simple] assert result == [t.serialize() for t in mocks.transactions.simple]
client.transactions = mocks.transactions.simple_transformed cast(FakeClient, client).transactions = mocks.transactions.simple_transformed
command.execute() command.execute()
with open(file, newline="") as f: with open(file, newline="") as f:
result = json.load(f) result = json.load(f)
assert result == [ assert result == [
serialize(t) for t in mocks.transactions.simple_transformed t.serialize() for t in mocks.transactions.simple_transformed
] ]
def test_export_pickle(self, tmp_path: Path): def test_export_pickle(self, tmp_path: Path, client: Client):
client = FakeClient()
file = tmp_path / "test.pickle" file = tmp_path / "test.pickle"
command = ExportCommand(client, Transaction, file, ExportFormat.pickle) command = ExportCommand(client, Transaction, file, ExportFormat.pickle)
command.execute() command.execute()
@ -70,9 +74,22 @@ class TestCommand:
result = pickle.load(f) result = pickle.load(f)
assert result == mocks.transactions.simple assert result == mocks.transactions.simple
client.transactions = mocks.transactions.simple_transformed cast(FakeClient, client).transactions = mocks.transactions.simple_transformed
command.execute() command.execute()
with open(file, "rb") as f: with open(file, "rb") as f:
result = pickle.load(f) result = pickle.load(f)
assert result == mocks.transactions.simple_transformed assert result == mocks.transactions.simple_transformed
def test_import(self, tmp_path: Path, client: Client):
file = tmp_path / "test"
for format in list(ExportFormat):
command = ExportCommand(client, Transaction, file, format)
command.execute()
command = ImportCommand(client, Transaction, file, format)
command.execute()
transactions = cast(FakeClient, client).transactions
assert len(transactions) > 0
assert transactions == client.select(Transaction)

View File

@ -2,11 +2,12 @@ from datetime import date
from decimal import Decimal from decimal import Decimal
import pytest import pytest
from mocks.client import MockClient
from pfbudget.db.client import Client from pfbudget.db.client import Client
from pfbudget.db.model import ( from pfbudget.db.model import (
AccountType, AccountType,
Bank, Bank,
Base,
Nordigen, Nordigen,
CategorySelector, CategorySelector,
Transaction, Transaction,
@ -16,10 +17,7 @@ from pfbudget.db.model import (
@pytest.fixture @pytest.fixture
def client() -> Client: def client() -> Client:
url = "sqlite://" return MockClient()
client = Client(url, execution_options={"schema_translate_map": {"pfbudget": None}})
Base.metadata.create_all(client.engine)
return client
@pytest.fixture @pytest.fixture