Clear up forge/dismantle logic

This commit is contained in:
Luís Murta 2023-02-11 22:48:04 +00:00
parent f7df033d58
commit 6f68d971ee
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
3 changed files with 44 additions and 4 deletions

View File

@ -237,7 +237,7 @@ if __name__ == "__main__":
keys = {"original", "links"} keys = {"original", "links"}
assert args.keys() >= keys, f"missing {args.keys() - keys}" assert args.keys() >= keys, f"missing {args.keys() - keys}"
params = [type.Link(args["original"][0], link) for link in args["links"]] params = [args["original"][0], args["links"]]
case ( case (
Operation.Export Operation.Export

View File

@ -17,6 +17,7 @@ from pfbudget.db.model import (
MoneyTransaction, MoneyTransaction,
Nordigen, Nordigen,
Rule, Rule,
Selector_T,
SplitTransaction, SplitTransaction,
Tag, Tag,
TagRule, TagRule,
@ -168,8 +169,32 @@ class Manager:
session.remove_by_name(CategoryGroup, params) session.remove_by_name(CategoryGroup, params)
case Operation.Forge: case Operation.Forge:
if not (
isinstance(params[0], int)
and all(isinstance(p, int) for p in params[1])
):
raise TypeError("f{params} are not transaction ids")
with self.db.session() as session: with self.db.session() as session:
session.add(params) original = session.get(Transaction, Transaction.id, params[0])[0]
links = session.get(Transaction, Transaction.id, params[1])
if not original.category:
original.category = self.askcategory(original)
for link in links:
if (
not link.category
or link.category.name != original.category.name
):
print(
f"{link} category will change to"
f" {original.category.name}"
)
link.category = original.category
tobelinked = [Link(original.id, link.id) for link in links]
session.add(tobelinked)
case Operation.Dismantle: case Operation.Dismantle:
assert all(isinstance(param, Link) for param in params) assert all(isinstance(param, Link) for param in params)
@ -202,7 +227,8 @@ class Manager:
if originals[0].date != t.date: if originals[0].date != t.date:
t.date = originals[0].date t.date = originals[0].date
print( print(
f"{t.date} is different from original date {originals[0].date}, using original" f"{t.date} is different from original date"
f" {originals[0].date}, using original"
) )
splitted = SplitTransaction( splitted = SplitTransaction(
@ -326,6 +352,17 @@ class Manager:
def parse(self, filename: Path, args: dict): def parse(self, filename: Path, args: dict):
return parse_data(filename, args) return parse_data(filename, args)
def askcategory(self, transaction: Transaction):
selector = CategorySelector(Selector_T.manual)
with self.db.session() as session:
categories = session.get(Category)
while True:
category = input(f"{transaction}: ")
if category in [c.name for c in categories]:
return TransactionCategory(category, selector)
@staticmethod @staticmethod
def dump(fn, sequence): def dump(fn, sequence):
with open(fn, "wb") as f: with open(fn, "wb") as f:

View File

@ -51,7 +51,10 @@ class DbClient:
def get(self, type: Type[T], column=None, values=None) -> Sequence[T]: def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
if column is not None: if column is not None:
if values: if values:
stmt = select(type).where(column.in_(values)) if isinstance(values, Sequence):
stmt = select(type).where(column.in_(values))
else:
stmt = select(type).where(column == values)
else: else:
stmt = select(type).where(column) stmt = select(type).where(column)
else: else: