From 01df97ed46a7463d285bc5120507ecd5560ed711 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Murta?= Date: Mon, 1 May 2023 21:08:49 +0100 Subject: [PATCH] `back_populates` option on category relationship Due to the use of the dataclasses mixin on the SQLAlchemy types, a back_populates creates a RecursiveError when comparing two types. This occurs because the dataclass will overwrite the __eq__ operator, and it doesn't know when to stop comparing relationships. Removing the dataclasses isn't the best approach, since then __init__, __eq__ and __repr__ methods would have to be added to all types. Thus the solution was to remove the relationship on the child (on a one-to-one relationship) from the __eq__ operation, with the use of the compare parameter. Took the opportunity to define more logical __init__ methods on the `Rule` and child classes. Also revised the parameter options on some DB types. --- pfbudget/__main__.py | 28 ++++++++-------- pfbudget/db/model.py | 67 ++++++++++++++++++++++++--------------- tests/mocks/categories.py | 9 ++---- tests/test_database.py | 12 ++++--- tests/test_load.py | 4 +-- tests/test_psd2.py | 4 +-- tests/test_transform.py | 20 ++++++------ 7 files changed, 81 insertions(+), 63 deletions(-) diff --git a/pfbudget/__main__.py b/pfbudget/__main__.py index c7718ca..3f6c05b 100644 --- a/pfbudget/__main__.py +++ b/pfbudget/__main__.py @@ -163,14 +163,14 @@ if __name__ == "__main__": params = [ type.CategoryRule( - args["start"][0] if args["start"] else None, - args["end"][0] if args["end"] else None, - args["description"][0] if args["description"] else None, - args["regex"][0] if args["regex"] else None, - args["bank"][0] if args["bank"] else None, - args["min"][0] if args["min"] else None, - args["max"][0] if args["max"] else None, cat, + start=args["start"][0] if args["start"] else None, + end=args["end"][0] if args["end"] else None, + description=args["description"][0] if args["description"] else None, + regex=args["regex"][0] if args["regex"] else None, + bank=args["bank"][0] if args["bank"] else None, + min=args["min"][0] if args["min"] else None, + max=args["max"][0] if args["max"] else None, ) for cat in args["category"] ] @@ -215,14 +215,14 @@ if __name__ == "__main__": params = [ type.TagRule( - args["start"][0] if args["start"] else None, - args["end"][0] if args["end"] else None, - args["description"][0] if args["description"] else None, - args["regex"][0] if args["regex"] else None, - args["bank"][0] if args["bank"] else None, - args["min"][0] if args["min"] else None, - args["max"][0] if args["max"] else None, tag, + start=args["start"][0] if args["start"] else None, + end=args["end"][0] if args["end"] else None, + description=args["description"][0] if args["description"] else None, + regex=args["regex"][0] if args["regex"] else None, + bank=args["bank"][0] if args["bank"] else None, + min=args["min"][0] if args["min"] else None, + max=args["max"][0] if args["max"] else None, ) for tag in args["tag"] ] diff --git a/pfbudget/db/model.py b/pfbudget/db/model.py index 59f43a8..6337362 100644 --- a/pfbudget/db/model.py +++ b/pfbudget/db/model.py @@ -65,7 +65,7 @@ class Bank(Base, Export): BIC: Mapped[str] = mapped_column(String(8)) type: Mapped[accounttype] - nordigen: Mapped[Optional[Nordigen]] = relationship(lazy="joined", init=False) + nordigen: Mapped[Optional[Nordigen]] = relationship(init=False) @property def format(self) -> dict[str, Any]: @@ -98,16 +98,17 @@ class Transaction(Base, Export): description: Mapped[Optional[str]] amount: Mapped[money] - split: Mapped[bool] = mapped_column(init=False, default=False) + split: Mapped[bool] = mapped_column(default=False) + + category: Mapped[Optional[TransactionCategory]] = relationship( + back_populates="transaction", default=None + ) + tags: Mapped[set[TransactionTag]] = relationship(default_factory=set) + note: Mapped[Optional[Note]] = relationship( + cascade="all, delete-orphan", passive_deletes=True, default=None + ) type: Mapped[str] = mapped_column(init=False) - - category: Mapped[Optional[TransactionCategory]] = relationship(init=False) - note: Mapped[Optional[Note]] = relationship( - cascade="all, delete-orphan", init=False, passive_deletes=True - ) - tags: Mapped[set[TransactionTag]] = relationship(init=False) - __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "transaction"} @property @@ -134,7 +135,7 @@ idfk = Annotated[ class BankTransaction(Transaction): - bank: Mapped[bankfk] = mapped_column(nullable=True) + bank: Mapped[Optional[bankfk]] = mapped_column(default=None) __mapper_args__ = {"polymorphic_identity": "bank", "polymorphic_load": "inline"} @@ -148,7 +149,7 @@ class MoneyTransaction(Transaction): class SplitTransaction(Transaction): - original: Mapped[idfk] = mapped_column(nullable=True) + original: Mapped[Optional[idfk]] = mapped_column(default=None) __mapper_args__ = {"polymorphic_identity": "split", "polymorphic_load": "inline"} @@ -204,6 +205,15 @@ catfk = Annotated[ ] +class Selector_T(enum.Enum): + unknown = enum.auto() + nullifier = enum.auto() + vacations = enum.auto() + rules = enum.auto() + algorithm = enum.auto() + manual = enum.auto() + + class TransactionCategory(Base, Export): __tablename__ = "transactions_categorized" @@ -211,7 +221,11 @@ class TransactionCategory(Base, Export): name: Mapped[catfk] selector: Mapped[CategorySelector] = relationship( - cascade="all, delete-orphan", lazy="joined" + cascade="all, delete-orphan", default=Selector_T.unknown + ) + + transaction: Mapped[Transaction] = relationship( + back_populates="category", init=False, compare=False ) @property @@ -234,7 +248,7 @@ class Nordigen(Base, Export): name: Mapped[bankfk] = mapped_column(primary_key=True) bank_id: Mapped[Optional[str]] requisition_id: Mapped[Optional[str]] - invert: Mapped[Optional[bool]] + invert: Mapped[Optional[bool]] = mapped_column(default=None) @property def format(self) -> dict[str, Any]: @@ -270,18 +284,9 @@ class TransactionTag(Base, Export): return hash(self.id) -class Selector_T(enum.Enum): - unknown = enum.auto() - nullifier = enum.auto() - vacations = enum.auto() - rules = enum.auto() - algorithm = enum.auto() - manual = enum.auto() - - categoryselector = Annotated[ Selector_T, - mapped_column(Enum(Selector_T, inherit_schema=True), default=Selector_T.unknown), + mapped_column(Enum(Selector_T, inherit_schema=True)), ] @@ -294,7 +299,7 @@ class CategorySelector(Base, Export): primary_key=True, init=False, ) - selector: Mapped[categoryselector] + selector: Mapped[categoryselector] = mapped_column(default=Selector_T.unknown) @property def format(self): @@ -336,7 +341,7 @@ class Link(Base): link: Mapped[idfk] = mapped_column(primary_key=True) -class Rule(Base, Export): +class Rule(Base, Export, init=False): __tablename__ = "rules" id: Mapped[idpk] = mapped_column(init=False) @@ -355,6 +360,10 @@ class Rule(Base, Export): "polymorphic_on": "type", } + def __init__(self, **kwargs: Any) -> None: + for k, v in kwargs.items(): + setattr(self, k, v) + def matches(self, t: BankTransaction) -> bool: valid = None if self.regex: @@ -415,6 +424,10 @@ class CategoryRule(Rule): def format(self) -> dict[str, Any]: return super().format | dict(name=self.name) + def __init__(self, name: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.name = name + def __hash__(self): return hash(self.id) @@ -438,5 +451,9 @@ class TagRule(Rule): def format(self) -> dict[str, Any]: return super().format | dict(tag=self.tag) + def __init__(self, name: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.tag = name + def __hash__(self): return hash(self.id) diff --git a/tests/mocks/categories.py b/tests/mocks/categories.py index ad5b8a6..de83881 100644 --- a/tests/mocks/categories.py +++ b/tests/mocks/categories.py @@ -2,14 +2,11 @@ from decimal import Decimal from pfbudget.db.model import Category, CategoryRule, Tag, TagRule -category_null = Category("null", None, set()) +category_null = Category("null") category1 = Category( "cat#1", - None, - {CategoryRule(None, None, "desc#1", None, None, None, Decimal(0), "cat#1")}, + rules={CategoryRule("cat#1", description="desc#1", max=Decimal(0))}, ) -tag_1 = Tag( - "tag#1", {TagRule(None, None, "desc#1", None, None, None, Decimal(0), "tag#1")} -) +tag_1 = Tag("tag#1", rules={TagRule("tag#1", description="desc#1", max=Decimal(0))}) diff --git a/tests/test_database.py b/tests/test_database.py index e6c5ca1..a29e6fe 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -39,12 +39,16 @@ def banks(client: Client) -> list[Bank]: @pytest.fixture def transactions(client: Client) -> list[Transaction]: transactions = [ - Transaction(date(2023, 1, 1), "", Decimal("-10")), + Transaction( + date(2023, 1, 1), + "", + Decimal("-10"), + category=TransactionCategory( + "category", CategorySelector(Selector_T.algorithm) + ), + ), Transaction(date(2023, 1, 2), "", Decimal("-50")), ] - transactions[0].category = TransactionCategory( - "name", CategorySelector(Selector_T.algorithm) - ) client.insert(transactions) for i, transaction in enumerate(transactions): diff --git a/tests/test_load.py b/tests/test_load.py index 2e438fb..13e4a19 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -31,8 +31,8 @@ class TestDatabaseLoad: def test_insert(self, loader: Loader): transactions = [ - BankTransaction(date(2023, 1, 1), "", Decimal("-500"), "Bank#1"), - BankTransaction(date(2023, 1, 2), "", Decimal("500"), "Bank#2"), + BankTransaction(date(2023, 1, 1), "", Decimal("-500"), bank="Bank#1"), + BankTransaction(date(2023, 1, 2), "", Decimal("500"), bank="Bank#2"), ] loader.load(transactions) diff --git a/tests/test_psd2.py b/tests/test_psd2.py index efc25b2..e4f7cf6 100644 --- a/tests/test_psd2.py +++ b/tests/test_psd2.py @@ -91,9 +91,9 @@ class TestExtractPSD2: def test_extract(self, extractor: Extractor, bank: Bank): assert extractor.extract(bank) == [ BankTransaction( - dt.date(2023, 1, 14), "string", Decimal("328.18"), "Bank#1" + dt.date(2023, 1, 14), "string", Decimal("328.18"), bank="Bank#1" ), BankTransaction( - dt.date(2023, 2, 14), "string", Decimal("947.26"), "Bank#1" + dt.date(2023, 2, 14), "string", Decimal("947.26"), bank="Bank#1" ), ] diff --git a/tests/test_transform.py b/tests/test_transform.py index ae9eb81..6d24662 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -20,8 +20,8 @@ from pfbudget.transform.transform import Transformer class TestTransform: def test_nullifier(self): transactions = [ - BankTransaction(date(2023, 1, 1), "", Decimal("-500"), "Bank#1"), - BankTransaction(date(2023, 1, 2), "", Decimal("500"), "Bank#2"), + BankTransaction(date(2023, 1, 1), "", Decimal("-500"), bank="Bank#1"), + BankTransaction(date(2023, 1, 2), "", Decimal("500"), bank="Bank#2"), ] for t in transactions: @@ -37,8 +37,8 @@ class TestTransform: def test_nullifier_inplace(self): transactions = [ - BankTransaction(date(2023, 1, 1), "", Decimal("-500"), "Bank#1"), - BankTransaction(date(2023, 1, 2), "", Decimal("500"), "Bank#2"), + BankTransaction(date(2023, 1, 1), "", Decimal("-500"), bank="Bank#1"), + BankTransaction(date(2023, 1, 2), "", Decimal("500"), bank="Bank#2"), ] for t in transactions: @@ -54,14 +54,14 @@ class TestTransform: def test_nullifier_with_rules(self): transactions = [ - BankTransaction(date(2023, 1, 1), "", Decimal("-500"), "Bank#1"), - BankTransaction(date(2023, 1, 2), "", Decimal("500"), "Bank#2"), + BankTransaction(date(2023, 1, 1), "", Decimal("-500"), bank="Bank#1"), + BankTransaction(date(2023, 1, 2), "", Decimal("500"), bank="Bank#2"), ] for t in transactions: assert not t.category - rules = [CategoryRule(None, None, None, None, "Bank#1", None, None, "null")] + rules = [CategoryRule("null", bank="Bank#1")] categorizer: Transformer = Nullifier(rules) transactions = categorizer.transform(transactions) @@ -69,7 +69,7 @@ class TestTransform: for t in transactions: assert not t.category - rules.append(CategoryRule(None, None, None, None, "Bank#2", None, None, "null")) + rules.append(CategoryRule("null", bank="Bank#2")) categorizer = Nullifier(rules) transactions = categorizer.transform(transactions) @@ -80,7 +80,7 @@ class TestTransform: def test_tagger(self): transactions = [ - BankTransaction(date(2023, 1, 1), "desc#1", Decimal("-10"), "Bank#1") + BankTransaction(date(2023, 1, 1), "desc#1", Decimal("-10"), bank="Bank#1") ] for t in transactions: @@ -94,7 +94,7 @@ class TestTransform: def test_categorize(self): transactions = [ - BankTransaction(date(2023, 1, 1), "desc#1", Decimal("-10"), "Bank#1") + BankTransaction(date(2023, 1, 1), "desc#1", Decimal("-10"), bank="Bank#1") ] for t in transactions: