[Fix] Eager loading for subclasses

and CategoryGroup import/export.
This commit is contained in:
Luís Murta 2023-05-17 21:38:48 +01:00
parent e6622d1e19
commit d11f753aa0
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
2 changed files with 12 additions and 9 deletions

View File

@ -252,10 +252,6 @@ class CategoryGroup(Base, Serializable):
name: Mapped[str] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(primary_key=True)
categories: Mapped[list[Category]] = relationship(
default_factory=list, lazy="joined"
)
def serialize(self) -> Mapping[str, Any]: def serialize(self) -> Mapping[str, Any]:
return super().serialize() | dict(name=self.name) return super().serialize() | dict(name=self.name)
@ -524,6 +520,7 @@ class CategoryRule(Rule):
__mapper_args__ = { __mapper_args__ = {
"polymorphic_identity": "category_rule", "polymorphic_identity": "category_rule",
"polymorphic_load": "selectin",
} }
@ -542,4 +539,5 @@ class TagRule(Rule):
__mapper_args__ = { __mapper_args__ = {
"polymorphic_identity": "tag_rule", "polymorphic_identity": "tag_rule",
"polymorphic_load": "selectin",
} }

View File

@ -17,6 +17,7 @@ from pfbudget.db.client import Client
from pfbudget.db.model import ( from pfbudget.db.model import (
Bank, Bank,
BankTransaction, BankTransaction,
Base,
Category, Category,
CategoryGroup, CategoryGroup,
MoneyTransaction, MoneyTransaction,
@ -123,9 +124,6 @@ class TestBackup:
client = MockClient() client = MockClient()
client.insert([e for t in params for e in t[0]]) client.insert([e for t in params for e in t[0]])
originals = client.select(Transaction)
assert originals
command = BackupCommand(client, file, ExportFormat.JSON) command = BackupCommand(client, file, ExportFormat.JSON)
command.execute() command.execute()
@ -134,6 +132,13 @@ class TestBackup:
command = ImportBackupCommand(other, file, ExportFormat.JSON) command = ImportBackupCommand(other, file, ExportFormat.JSON)
command.execute() command.execute()
imported = other.select(Transaction) def subclasses(cls: Type[Any]) -> set[Type[Any]]:
return set(cls.__subclasses__()) | {
s for c in cls.__subclasses__() for s in subclasses(c)
}
assert originals == imported for t in [cls for cls in subclasses(Base)]:
originals = client.select(t)
imported = other.select(t)
assert originals == imported, f"{t}"