Clear up forge/dismantle logic
This commit is contained in:
parent
f7df033d58
commit
6f68d971ee
@ -237,7 +237,7 @@ if __name__ == "__main__":
|
||||
keys = {"original", "links"}
|
||||
assert args.keys() >= keys, f"missing {args.keys() - keys}"
|
||||
|
||||
params = [type.Link(args["original"][0], link) for link in args["links"]]
|
||||
params = [args["original"][0], args["links"]]
|
||||
|
||||
case (
|
||||
Operation.Export
|
||||
|
||||
@ -17,6 +17,7 @@ from pfbudget.db.model import (
|
||||
MoneyTransaction,
|
||||
Nordigen,
|
||||
Rule,
|
||||
Selector_T,
|
||||
SplitTransaction,
|
||||
Tag,
|
||||
TagRule,
|
||||
@ -168,8 +169,32 @@ class Manager:
|
||||
session.remove_by_name(CategoryGroup, params)
|
||||
|
||||
case Operation.Forge:
|
||||
if not (
|
||||
isinstance(params[0], int)
|
||||
and all(isinstance(p, int) for p in params[1])
|
||||
):
|
||||
raise TypeError("f{params} are not transaction ids")
|
||||
|
||||
with self.db.session() as session:
|
||||
session.add(params)
|
||||
original = session.get(Transaction, Transaction.id, params[0])[0]
|
||||
links = session.get(Transaction, Transaction.id, params[1])
|
||||
|
||||
if not original.category:
|
||||
original.category = self.askcategory(original)
|
||||
|
||||
for link in links:
|
||||
if (
|
||||
not link.category
|
||||
or link.category.name != original.category.name
|
||||
):
|
||||
print(
|
||||
f"{link} category will change to"
|
||||
f" {original.category.name}"
|
||||
)
|
||||
link.category = original.category
|
||||
|
||||
tobelinked = [Link(original.id, link.id) for link in links]
|
||||
session.add(tobelinked)
|
||||
|
||||
case Operation.Dismantle:
|
||||
assert all(isinstance(param, Link) for param in params)
|
||||
@ -202,7 +227,8 @@ class Manager:
|
||||
if originals[0].date != t.date:
|
||||
t.date = originals[0].date
|
||||
print(
|
||||
f"{t.date} is different from original date {originals[0].date}, using original"
|
||||
f"{t.date} is different from original date"
|
||||
f" {originals[0].date}, using original"
|
||||
)
|
||||
|
||||
splitted = SplitTransaction(
|
||||
@ -326,6 +352,17 @@ class Manager:
|
||||
def parse(self, filename: Path, args: dict):
|
||||
return parse_data(filename, args)
|
||||
|
||||
def askcategory(self, transaction: Transaction):
|
||||
selector = CategorySelector(Selector_T.manual)
|
||||
|
||||
with self.db.session() as session:
|
||||
categories = session.get(Category)
|
||||
|
||||
while True:
|
||||
category = input(f"{transaction}: ")
|
||||
if category in [c.name for c in categories]:
|
||||
return TransactionCategory(category, selector)
|
||||
|
||||
@staticmethod
|
||||
def dump(fn, sequence):
|
||||
with open(fn, "wb") as f:
|
||||
|
||||
@ -51,7 +51,10 @@ class DbClient:
|
||||
def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
|
||||
if column is not None:
|
||||
if values:
|
||||
if isinstance(values, Sequence):
|
||||
stmt = select(type).where(column.in_(values))
|
||||
else:
|
||||
stmt = select(type).where(column == values)
|
||||
else:
|
||||
stmt = select(type).where(column)
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user