From 6f68d971ee285bb899f640ceb7b0b547e2135293 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Murta?= Date: Sat, 11 Feb 2023 22:48:04 +0000 Subject: [PATCH] Clear up forge/dismantle logic --- pfbudget/__main__.py | 2 +- pfbudget/core/manager.py | 41 ++++++++++++++++++++++++++++++++++++++-- pfbudget/db/client.py | 5 ++++- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/pfbudget/__main__.py b/pfbudget/__main__.py index 0132d84..98fcb65 100644 --- a/pfbudget/__main__.py +++ b/pfbudget/__main__.py @@ -237,7 +237,7 @@ if __name__ == "__main__": keys = {"original", "links"} assert args.keys() >= keys, f"missing {args.keys() - keys}" - params = [type.Link(args["original"][0], link) for link in args["links"]] + params = [args["original"][0], args["links"]] case ( Operation.Export diff --git a/pfbudget/core/manager.py b/pfbudget/core/manager.py index f620220..27df90b 100644 --- a/pfbudget/core/manager.py +++ b/pfbudget/core/manager.py @@ -17,6 +17,7 @@ from pfbudget.db.model import ( MoneyTransaction, Nordigen, Rule, + Selector_T, SplitTransaction, Tag, TagRule, @@ -168,8 +169,32 @@ class Manager: session.remove_by_name(CategoryGroup, params) case Operation.Forge: + if not ( + isinstance(params[0], int) + and all(isinstance(p, int) for p in params[1]) + ): + raise TypeError("f{params} are not transaction ids") + with self.db.session() as session: - session.add(params) + original = session.get(Transaction, Transaction.id, params[0])[0] + links = session.get(Transaction, Transaction.id, params[1]) + + if not original.category: + original.category = self.askcategory(original) + + for link in links: + if ( + not link.category + or link.category.name != original.category.name + ): + print( + f"{link} category will change to" + f" {original.category.name}" + ) + link.category = original.category + + tobelinked = [Link(original.id, link.id) for link in links] + session.add(tobelinked) case Operation.Dismantle: assert all(isinstance(param, Link) for param in params) @@ -202,7 +227,8 @@ class Manager: if originals[0].date != t.date: t.date = originals[0].date print( - f"{t.date} is different from original date {originals[0].date}, using original" + f"{t.date} is different from original date" + f" {originals[0].date}, using original" ) splitted = SplitTransaction( @@ -326,6 +352,17 @@ class Manager: def parse(self, filename: Path, args: dict): return parse_data(filename, args) + def askcategory(self, transaction: Transaction): + selector = CategorySelector(Selector_T.manual) + + with self.db.session() as session: + categories = session.get(Category) + + while True: + category = input(f"{transaction}: ") + if category in [c.name for c in categories]: + return TransactionCategory(category, selector) + @staticmethod def dump(fn, sequence): with open(fn, "wb") as f: diff --git a/pfbudget/db/client.py b/pfbudget/db/client.py index 7d1fbe5..4c52820 100644 --- a/pfbudget/db/client.py +++ b/pfbudget/db/client.py @@ -51,7 +51,10 @@ class DbClient: def get(self, type: Type[T], column=None, values=None) -> Sequence[T]: if column is not None: if values: - stmt = select(type).where(column.in_(values)) + if isinstance(values, Sequence): + stmt = select(type).where(column.in_(values)) + else: + stmt = select(type).where(column == values) else: stmt = select(type).where(column) else: