Clear up forge/dismantle logic
This commit is contained in:
parent
f7df033d58
commit
6f68d971ee
@ -237,7 +237,7 @@ if __name__ == "__main__":
|
|||||||
keys = {"original", "links"}
|
keys = {"original", "links"}
|
||||||
assert args.keys() >= keys, f"missing {args.keys() - keys}"
|
assert args.keys() >= keys, f"missing {args.keys() - keys}"
|
||||||
|
|
||||||
params = [type.Link(args["original"][0], link) for link in args["links"]]
|
params = [args["original"][0], args["links"]]
|
||||||
|
|
||||||
case (
|
case (
|
||||||
Operation.Export
|
Operation.Export
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from pfbudget.db.model import (
|
|||||||
MoneyTransaction,
|
MoneyTransaction,
|
||||||
Nordigen,
|
Nordigen,
|
||||||
Rule,
|
Rule,
|
||||||
|
Selector_T,
|
||||||
SplitTransaction,
|
SplitTransaction,
|
||||||
Tag,
|
Tag,
|
||||||
TagRule,
|
TagRule,
|
||||||
@ -168,8 +169,32 @@ class Manager:
|
|||||||
session.remove_by_name(CategoryGroup, params)
|
session.remove_by_name(CategoryGroup, params)
|
||||||
|
|
||||||
case Operation.Forge:
|
case Operation.Forge:
|
||||||
|
if not (
|
||||||
|
isinstance(params[0], int)
|
||||||
|
and all(isinstance(p, int) for p in params[1])
|
||||||
|
):
|
||||||
|
raise TypeError("f{params} are not transaction ids")
|
||||||
|
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(params)
|
original = session.get(Transaction, Transaction.id, params[0])[0]
|
||||||
|
links = session.get(Transaction, Transaction.id, params[1])
|
||||||
|
|
||||||
|
if not original.category:
|
||||||
|
original.category = self.askcategory(original)
|
||||||
|
|
||||||
|
for link in links:
|
||||||
|
if (
|
||||||
|
not link.category
|
||||||
|
or link.category.name != original.category.name
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"{link} category will change to"
|
||||||
|
f" {original.category.name}"
|
||||||
|
)
|
||||||
|
link.category = original.category
|
||||||
|
|
||||||
|
tobelinked = [Link(original.id, link.id) for link in links]
|
||||||
|
session.add(tobelinked)
|
||||||
|
|
||||||
case Operation.Dismantle:
|
case Operation.Dismantle:
|
||||||
assert all(isinstance(param, Link) for param in params)
|
assert all(isinstance(param, Link) for param in params)
|
||||||
@ -202,7 +227,8 @@ class Manager:
|
|||||||
if originals[0].date != t.date:
|
if originals[0].date != t.date:
|
||||||
t.date = originals[0].date
|
t.date = originals[0].date
|
||||||
print(
|
print(
|
||||||
f"{t.date} is different from original date {originals[0].date}, using original"
|
f"{t.date} is different from original date"
|
||||||
|
f" {originals[0].date}, using original"
|
||||||
)
|
)
|
||||||
|
|
||||||
splitted = SplitTransaction(
|
splitted = SplitTransaction(
|
||||||
@ -326,6 +352,17 @@ class Manager:
|
|||||||
def parse(self, filename: Path, args: dict):
|
def parse(self, filename: Path, args: dict):
|
||||||
return parse_data(filename, args)
|
return parse_data(filename, args)
|
||||||
|
|
||||||
|
def askcategory(self, transaction: Transaction):
|
||||||
|
selector = CategorySelector(Selector_T.manual)
|
||||||
|
|
||||||
|
with self.db.session() as session:
|
||||||
|
categories = session.get(Category)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
category = input(f"{transaction}: ")
|
||||||
|
if category in [c.name for c in categories]:
|
||||||
|
return TransactionCategory(category, selector)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dump(fn, sequence):
|
def dump(fn, sequence):
|
||||||
with open(fn, "wb") as f:
|
with open(fn, "wb") as f:
|
||||||
|
|||||||
@ -51,7 +51,10 @@ class DbClient:
|
|||||||
def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
|
def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
|
||||||
if column is not None:
|
if column is not None:
|
||||||
if values:
|
if values:
|
||||||
|
if isinstance(values, Sequence):
|
||||||
stmt = select(type).where(column.in_(values))
|
stmt = select(type).where(column.in_(values))
|
||||||
|
else:
|
||||||
|
stmt = select(type).where(column == values)
|
||||||
else:
|
else:
|
||||||
stmt = select(type).where(column)
|
stmt = select(type).where(column)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user