Full backup creation and import commands

Using the same logic as the single Export/Import commands, implement the
entire backup command by exporting all the serializable classes into a
single json file.
To select the correct class upon import, save a new property on the
backup json, the class_, which contains the name of the class to be
imported.

Fix the note serialization.
This commit is contained in:
Luís Murta 2023-05-16 22:08:21 +01:00
parent 2cf0ba4374
commit e6622d1e19
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
4 changed files with 111 additions and 13 deletions

View File

@ -6,7 +6,17 @@ from typing import Type
from pfbudget.common.types import ExportFormat
from pfbudget.db.client import Client
from pfbudget.db.model import Serializable
from pfbudget.db.model import (
Bank,
Category,
CategoryGroup,
Serializable,
Tag,
Transaction,
)
# required for the backup import
import pfbudget.db.model
class Command(ABC):
@ -68,3 +78,51 @@ class ImportCommand(Command):
class ImportFailedError(Exception):
pass
class BackupCommand(Command):
def __init__(self, client: Client, fn: Path, format: ExportFormat) -> None:
self.__client = client
self.fn = fn
self.format = format
def execute(self) -> None:
banks = self.__client.select(Bank)
groups = self.__client.select(CategoryGroup)
categories = self.__client.select(Category)
tags = self.__client.select(Tag)
transactions = self.__client.select(Transaction)
values = [*banks, *groups, *categories, *tags, *transactions]
match self.format:
case ExportFormat.JSON:
with open(self.fn, "w", newline="") as f:
json.dump([e.serialize() for e in values], f, indent=4)
case ExportFormat.pickle:
raise AttributeError("pickle export not working at the moment!")
class ImportBackupCommand(Command):
def __init__(self, client: Client, fn: Path, format: ExportFormat) -> None:
self.__client = client
self.fn = fn
self.format = format
def execute(self) -> None:
match self.format:
case ExportFormat.JSON:
with open(self.fn, "r") as f:
try:
values = json.load(f)
values = [
getattr(pfbudget.db.model, v["class_"]).deserialize(v)
for v in values
]
except json.JSONDecodeError as e:
raise ImportFailedError(e)
case ExportFormat.pickle:
raise AttributeError("pickle import not working at the moment!")
self.__client.insert(values)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from collections.abc import Mapping, MutableMapping, Sequence
from dataclasses import dataclass, fields
from dataclasses import dataclass
import datetime as dt
import decimal
import enum
@ -46,11 +46,11 @@ class Base(MappedAsDataclass, DeclarativeBase):
@dataclass
class Serializable:
def serialize(self) -> Mapping[str, Any]:
return {field.name: getattr(self, field.name) for field in fields(self)}
return dict(class_=type(self).__name__)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls(**map)
raise NotImplementedError
class AccountType(enum.Enum):
@ -80,7 +80,7 @@ class Bank(Base, Serializable):
"invert": self.nordigen.invert,
}
return dict(
return super().serialize() | dict(
name=self.name,
BIC=self.BIC,
type=self.type.name,
@ -137,7 +137,7 @@ class Transaction(Base, Serializable):
"selector": self.category.selector.name,
}
return dict(
return super().serialize() | dict(
id=self.id,
date=self.date.isoformat(),
description=self.description,
@ -145,7 +145,7 @@ class Transaction(Base, Serializable):
split=self.split,
category=category if category else None,
tags=[{"tag": tag.tag} for tag in self.tags],
note=self.note,
note={"note": self.note.note} if self.note else None,
type=self.type,
)
@ -175,6 +175,10 @@ class Transaction(Base, Serializable):
if map["tags"]:
tags = set(TransactionTag(t["tag"]) for t in map["tags"])
note = None
if map["note"]:
note = Note(map["note"]["note"])
result = cls(
dt.date.fromisoformat(map["date"]),
map["description"],
@ -182,7 +186,7 @@ class Transaction(Base, Serializable):
map["split"],
category,
tags,
map["note"],
note,
)
if map["id"]:
@ -248,6 +252,17 @@ class CategoryGroup(Base, Serializable):
name: Mapped[str] = mapped_column(primary_key=True)
categories: Mapped[list[Category]] = relationship(
default_factory=list, lazy="joined"
)
def serialize(self) -> Mapping[str, Any]:
return super().serialize() | dict(name=self.name)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls(map["name"])
class Category(Base, Serializable, repr=False):
__tablename__ = "categories"
@ -290,7 +305,7 @@ class Category(Base, Serializable, repr=False):
"amount": self.schedule.amount,
}
return dict(
return super().serialize() | dict(
name=self.name,
group=self.group,
rules=rules,
@ -398,7 +413,7 @@ class Tag(Base, Serializable):
}
)
return dict(name=self.name, rules=rules)
return super().serialize() | dict(name=self.name, rules=rules)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:

View File

@ -43,7 +43,7 @@ money = [
]
__original = Transaction(date(2023, 1, 1), "", Decimal("-10"), split=True)
__original.id = 1
__original.id = 9000
split = [
__original,

View File

@ -6,7 +6,13 @@ from mocks import banks, categories, transactions
from mocks.client import MockClient
from pfbudget.common.types import ExportFormat
from pfbudget.core.command import ExportCommand, ImportCommand, ImportFailedError
from pfbudget.core.command import (
BackupCommand,
ExportCommand,
ImportBackupCommand,
ImportCommand,
ImportFailedError,
)
from pfbudget.db.client import Client
from pfbudget.db.model import (
Bank,
@ -87,7 +93,6 @@ class TestBackup:
with pytest.raises(AttributeError):
command.execute()
@pytest.mark.parametrize("input, what", not_serializable)
def test_try_backup_not_serializable(
self, tmp_path: Path, input: Sequence[Any], what: Type[Any]
@ -112,3 +117,23 @@ class TestBackup:
imported = other.select(what)
assert not imported
def test_full_backup(self, tmp_path: Path):
file = tmp_path / "test.json"
client = MockClient()
client.insert([e for t in params for e in t[0]])
originals = client.select(Transaction)
assert originals
command = BackupCommand(client, file, ExportFormat.JSON)
command.execute()
other = MockClient()
command = ImportBackupCommand(other, file, ExportFormat.JSON)
command.execute()
imported = other.select(Transaction)
assert originals == imported