Adds ImportFailedError for non-serializable types

This commit is contained in:
Luís Murta 2023-05-11 23:49:37 +01:00
parent 638b833c74
commit 729e15d4e8
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
4 changed files with 88 additions and 8 deletions

View File

@ -51,10 +51,18 @@ class ImportCommand(Command):
match self.format:
case ExportFormat.JSON:
with open(self.fn, "r") as f:
try:
values = json.load(f)
values = [self.what.deserialize(v) for v in values]
except json.JSONDecodeError as e:
raise ImportFailedError(e)
case ExportFormat.pickle:
with open(self.fn, "rb") as f:
values = pickle.load(f)
self.__client.insert(values)
class ImportFailedError(Exception):
pass

View File

@ -158,7 +158,7 @@ class Transaction(Base, Serializable):
case "money":
return MoneyTransaction.deserialize(map)
case "split":
raise NotImplementedError
return SplitTransaction.deserialize(map)
case _:
return cls._deserialize(map)
@ -193,7 +193,7 @@ idfk = Annotated[
]
class BankTransaction(Transaction, Serializable):
class BankTransaction(Transaction):
bank: Mapped[Optional[bankfk]] = mapped_column(default=None)
__mapper_args__ = {"polymorphic_identity": "bank", "polymorphic_load": "inline"}
@ -210,7 +210,7 @@ class BankTransaction(Transaction, Serializable):
return transaction
class MoneyTransaction(Transaction, Serializable):
class MoneyTransaction(Transaction):
__mapper_args__ = {"polymorphic_identity": "money"}
def serialize(self) -> Mapping[str, Any]:
@ -226,6 +226,13 @@ class SplitTransaction(Transaction):
__mapper_args__ = {"polymorphic_identity": "split", "polymorphic_load": "inline"}
def serialize(self) -> Mapping[str, Any]:
raise AttributeError
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
raise AttributeError
class CategoryGroup(Base, Serializable):
__tablename__ = "category_groups"

View File

@ -5,9 +5,11 @@ from pfbudget.db.model import (
BankTransaction,
CategorySelector,
MoneyTransaction,
Note,
SplitTransaction,
Transaction,
TransactionCategory,
TransactionTag,
)
simple = [
@ -48,3 +50,21 @@ split = [
SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id),
SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id),
]
tagged = [
Transaction(
date(2023, 1, 1),
"",
Decimal("-10"),
tags={TransactionTag("tag#1"), TransactionTag("tag#1")},
)
]
noted = [
Transaction(
date(2023, 1, 1),
"",
Decimal("-10"),
note=Note("note#1"),
)
]

View File

@ -6,9 +6,21 @@ from mocks import banks, categories, transactions
from mocks.client import MockClient
from pfbudget.common.types import ExportFormat
from pfbudget.core.command import ExportCommand, ImportCommand
from pfbudget.core.command import ExportCommand, ImportCommand, ImportFailedError
from pfbudget.db.client import Client
from pfbudget.db.model import Bank, Category, CategoryGroup, Tag, Transaction
from pfbudget.db.model import (
Bank,
BankTransaction,
Category,
CategoryGroup,
MoneyTransaction,
Note,
SplitTransaction,
Tag,
Transaction,
TransactionCategory,
TransactionTag,
)
@pytest.fixture
@ -20,8 +32,9 @@ params = [
(transactions.simple, Transaction),
(transactions.simple_transformed, Transaction),
(transactions.bank, Transaction),
(transactions.bank, BankTransaction),
(transactions.money, Transaction),
# (transactions.split, Transaction), NotImplemented
(transactions.money, MoneyTransaction),
([banks.checking, banks.cc], Bank),
([categories.category_null, categories.category1, categories.category2], Category),
(
@ -36,6 +49,13 @@ params = [
([categories.tag_1], Tag),
]
not_serializable = [
(transactions.split, SplitTransaction),
(transactions.simple_transformed, TransactionCategory),
(transactions.tagged, TransactionTag),
(transactions.noted, Note),
]
class TestBackup:
@pytest.mark.parametrize("input, what", params)
@ -58,3 +78,28 @@ class TestBackup:
imported = other.select(what)
assert originals == imported
@pytest.mark.parametrize("input, what", not_serializable)
def test_try_backup_not_serializable(
self, tmp_path: Path, input: Sequence[Any], what: Type[Any]
):
file = tmp_path / "test.json"
client = MockClient()
client.insert(input)
originals = client.select(what)
assert originals
command = ExportCommand(client, what, file, ExportFormat.JSON)
with pytest.raises(AttributeError):
command.execute()
other = MockClient()
command = ImportCommand(other, what, file, ExportFormat.JSON)
with pytest.raises(ImportFailedError):
command.execute()
imported = other.select(what)
assert not imported