[Refactor] Clean out old DB client class
Swap almost all remaining calls to the old postgresql only DB class with the new DB client. Warning! Some operations are currently not implement, such as setting category schedules and dismantling links. `update` and `delete` methods added to DB `Client`.
This commit is contained in:
parent
da44ba5306
commit
13c783ca0e
@ -130,12 +130,12 @@ if __name__ == "__main__":
|
||||
keys = {"category", "group"}
|
||||
assert args.keys() >= keys, f"missing {args.keys() - keys}"
|
||||
|
||||
params = [type.Category(cat) for cat in args["category"]]
|
||||
params.append(args["group"])
|
||||
params = [{"name": cat, "group": args["group"]} for cat in args["category"]]
|
||||
|
||||
case Operation.CategoryRemove:
|
||||
assert "category" in args, "argparser ill defined"
|
||||
params = [type.Category(cat) for cat in args["category"]]
|
||||
|
||||
params = args["category"]
|
||||
|
||||
case Operation.CategorySchedule:
|
||||
keys = {"category", "period", "frequency"}
|
||||
@ -246,7 +246,7 @@ if __name__ == "__main__":
|
||||
|
||||
case Operation.GroupRemove:
|
||||
assert "group" in args, "argparser ill defined"
|
||||
params = [type.CategoryGroup(group) for group in args["group"]]
|
||||
params = args["group"]
|
||||
|
||||
case Operation.Forge | Operation.Dismantle:
|
||||
keys = {"original", "links"}
|
||||
|
||||
@ -26,7 +26,6 @@ from pfbudget.db.model import (
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
)
|
||||
from pfbudget.db.postgresql import DbClient
|
||||
from pfbudget.extract.nordigen import NordigenClient, NordigenCredentialsManager
|
||||
from pfbudget.extract.parsers import parse_data
|
||||
from pfbudget.extract.psd2 import PSD2Extractor
|
||||
@ -111,20 +110,16 @@ class Manager:
|
||||
Tagger(rules).transform_inplace(uncategorized)
|
||||
|
||||
case Operation.BankMod:
|
||||
with self.db.session() as session:
|
||||
session.update(Bank, params)
|
||||
self.database.update(Bank, params)
|
||||
|
||||
case Operation.PSD2Mod:
|
||||
with self.db.session() as session:
|
||||
session.update(Nordigen, params)
|
||||
self.database.update(Nordigen, params)
|
||||
|
||||
case Operation.BankDel:
|
||||
with self.db.session() as session:
|
||||
session.remove_by_name(Bank, params)
|
||||
self.database.delete(Bank, Bank.name, params)
|
||||
|
||||
case Operation.PSD2Del:
|
||||
with self.db.session() as session:
|
||||
session.remove_by_name(Nordigen, params)
|
||||
self.database.delete(Nordigen, Nordigen.name, params)
|
||||
|
||||
case Operation.Token:
|
||||
Manager.nordigen_client().generate_token()
|
||||
@ -150,40 +145,28 @@ class Manager:
|
||||
self.database.insert(params)
|
||||
|
||||
case Operation.CategoryUpdate:
|
||||
with self.db.session() as session:
|
||||
session.updategroup(*params)
|
||||
self.database.update(Category, params)
|
||||
|
||||
case Operation.CategoryRemove:
|
||||
with self.db.session() as session:
|
||||
session.remove_by_name(Category, params)
|
||||
self.database.delete(Category, Category.name, params)
|
||||
|
||||
case Operation.CategorySchedule:
|
||||
with self.db.session() as session:
|
||||
session.updateschedules(params)
|
||||
raise NotImplementedError
|
||||
|
||||
case Operation.RuleRemove:
|
||||
assert all(isinstance(param, int) for param in params)
|
||||
with self.db.session() as session:
|
||||
session.remove_by_id(CategoryRule, params)
|
||||
self.database.delete(CategoryRule, CategoryRule.id, params)
|
||||
|
||||
case Operation.TagRemove:
|
||||
with self.db.session() as session:
|
||||
session.remove_by_name(Tag, params)
|
||||
self.database.delete(Tag, Tag.name, params)
|
||||
|
||||
case Operation.TagRuleRemove:
|
||||
assert all(isinstance(param, int) for param in params)
|
||||
with self.db.session() as session:
|
||||
session.remove_by_id(TagRule, params)
|
||||
self.database.delete(TagRule, TagRule.id, params)
|
||||
|
||||
case Operation.RuleModify | Operation.TagRuleModify:
|
||||
assert all(isinstance(param, dict) for param in params)
|
||||
with self.db.session() as session:
|
||||
session.update(Rule, params)
|
||||
self.database.update(Rule, params)
|
||||
|
||||
case Operation.GroupRemove:
|
||||
assert all(isinstance(param, CategoryGroup) for param in params)
|
||||
with self.db.session() as session:
|
||||
session.remove_by_name(CategoryGroup, params)
|
||||
self.database.delete(CategoryGroup, CategoryGroup.name, params)
|
||||
|
||||
case Operation.Forge:
|
||||
if not (
|
||||
@ -192,9 +175,14 @@ class Manager:
|
||||
):
|
||||
raise TypeError("f{params} are not transaction ids")
|
||||
|
||||
with self.db.session() as session:
|
||||
original = session.get(Transaction, Transaction.id, params[0])[0]
|
||||
links = session.get(Transaction, Transaction.id, params[1])
|
||||
with self.database.session as session:
|
||||
id = params[0]
|
||||
original = session.select(
|
||||
Transaction, lambda: Transaction.id == id
|
||||
)[0]
|
||||
|
||||
ids = params[1]
|
||||
links = session.select(Transaction, lambda: Transaction.id.in_(ids))
|
||||
|
||||
if not original.category:
|
||||
original.category = self.askcategory(original)
|
||||
@ -214,12 +202,7 @@ class Manager:
|
||||
session.insert(tobelinked)
|
||||
|
||||
case Operation.Dismantle:
|
||||
assert all(isinstance(param, Link) for param in params)
|
||||
|
||||
with self.db.session() as session:
|
||||
original = params[0].original
|
||||
links = [link.link for link in params]
|
||||
session.remove_links(original, links)
|
||||
raise NotImplementedError
|
||||
|
||||
case Operation.Split:
|
||||
if len(params) < 1 and not all(
|
||||
@ -234,8 +217,10 @@ class Manager:
|
||||
f"{original.amount}€ != {sum(v for v, _ in params[1:])}€"
|
||||
)
|
||||
|
||||
with self.db.session() as session:
|
||||
originals = session.get(Transaction, Transaction.id, [original.id])
|
||||
with self.database.session as session:
|
||||
originals = session.select(
|
||||
Transaction, lambda: Transaction.id == original.id
|
||||
)
|
||||
assert len(originals) == 1, ">1 transactions matched {original.id}!"
|
||||
|
||||
originals[0].split = True
|
||||
@ -293,8 +278,7 @@ class Manager:
|
||||
transactions.append(transaction)
|
||||
|
||||
if self.certify(transactions):
|
||||
with self.db.session() as session:
|
||||
session.insert(transactions)
|
||||
self.database.insert(transactions)
|
||||
|
||||
case Operation.ExportBanks:
|
||||
with self.database.session as session:
|
||||
@ -309,8 +293,7 @@ class Manager:
|
||||
banks.append(bank)
|
||||
|
||||
if self.certify(banks):
|
||||
with self.db.session() as session:
|
||||
session.insert(banks)
|
||||
self.database.insert(banks)
|
||||
|
||||
case Operation.ExportCategoryRules:
|
||||
with self.database.session as session:
|
||||
@ -324,8 +307,7 @@ class Manager:
|
||||
rules = [CategoryRule(**row) for row in self.load(params[0], params[1])]
|
||||
|
||||
if self.certify(rules):
|
||||
with self.db.session() as session:
|
||||
session.insert(rules)
|
||||
self.database.insert(rules)
|
||||
|
||||
case Operation.ExportTagRules:
|
||||
with self.database.session as session:
|
||||
@ -337,8 +319,7 @@ class Manager:
|
||||
rules = [TagRule(**row) for row in self.load(params[0], params[1])]
|
||||
|
||||
if self.certify(rules):
|
||||
with self.db.session() as session:
|
||||
session.insert(rules)
|
||||
self.database.insert(rules)
|
||||
|
||||
case Operation.ExportCategories:
|
||||
with self.database.session as session:
|
||||
@ -363,8 +344,7 @@ class Manager:
|
||||
categories.append(category)
|
||||
|
||||
if self.certify(categories):
|
||||
with self.db.session() as session:
|
||||
session.insert(categories)
|
||||
self.database.insert(categories)
|
||||
|
||||
case Operation.ExportCategoryGroups:
|
||||
with self.database.session as session:
|
||||
@ -380,8 +360,7 @@ class Manager:
|
||||
]
|
||||
|
||||
if self.certify(groups):
|
||||
with self.db.session() as session:
|
||||
session.insert(groups)
|
||||
self.database.insert(groups)
|
||||
|
||||
def parse(self, filename: Path, args: dict):
|
||||
return parse_data(filename, args)
|
||||
@ -389,13 +368,12 @@ class Manager:
|
||||
def askcategory(self, transaction: Transaction):
|
||||
selector = CategorySelector(Selector_T.manual)
|
||||
|
||||
with self.db.session() as session:
|
||||
categories = session.get(Category)
|
||||
categories = self.database.select(Category)
|
||||
|
||||
while True:
|
||||
category = input(f"{transaction}: ")
|
||||
if category in [c.name for c in categories]:
|
||||
return TransactionCategory(category, selector)
|
||||
while True:
|
||||
category = input(f"{transaction}: ")
|
||||
if category in [c.name for c in categories]:
|
||||
return TransactionCategory(category, selector)
|
||||
|
||||
@staticmethod
|
||||
def dump(fn, format, sequence):
|
||||
@ -428,20 +406,12 @@ class Manager:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def db(self) -> DbClient:
|
||||
return DbClient(self._db, self._verbosity > 2)
|
||||
|
||||
@property
|
||||
def database(self) -> Client:
|
||||
if not self._database:
|
||||
self._database = Client(self._db, echo=self._verbosity > 2)
|
||||
return self._database
|
||||
|
||||
@db.setter
|
||||
def db(self, url: str):
|
||||
self._db = url
|
||||
|
||||
@staticmethod
|
||||
def nordigen_client() -> NordigenClient:
|
||||
return NordigenClient(NordigenCredentialsManager.default)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from sqlalchemy import Engine, create_engine, select
|
||||
from sqlalchemy import Engine, create_engine, delete, select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from typing import Any, Optional, Type, TypeVar
|
||||
from typing import Any, Mapping, Optional, Type, TypeVar
|
||||
|
||||
# from pfbudget.db.exceptions import InsertError, SelectError
|
||||
|
||||
@ -52,6 +52,14 @@ class Client:
|
||||
def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]:
|
||||
return self.session.select(what, exists)
|
||||
|
||||
def update(self, what: Type[Any], values: Sequence[Mapping[str, Any]]) -> None:
|
||||
with self._sessionmaker() as session, session.begin():
|
||||
session.execute(update(what), values)
|
||||
|
||||
def delete(self, what: Type[Any], column: Any, values: Sequence[str]) -> None:
|
||||
with self._sessionmaker() as session, session.begin():
|
||||
session.execute(delete(what).where(column.in_(values)))
|
||||
|
||||
@property
|
||||
def engine(self) -> Engine:
|
||||
return self._engine
|
||||
|
||||
@ -1,123 +0,0 @@
|
||||
from dataclasses import asdict
|
||||
from sqlalchemy import create_engine, delete, select, update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import false
|
||||
from typing import Sequence, Type, TypeVar
|
||||
|
||||
from pfbudget.db.model import (
|
||||
Category,
|
||||
CategoryGroup,
|
||||
CategorySchedule,
|
||||
Link,
|
||||
Transaction,
|
||||
)
|
||||
|
||||
|
||||
class DbClient:
|
||||
"""
|
||||
General database client using sqlalchemy
|
||||
"""
|
||||
|
||||
__sessions: list[Session]
|
||||
|
||||
def __init__(self, url: str, echo=False) -> None:
|
||||
self._engine = create_engine(url, echo=echo)
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
return self._engine
|
||||
|
||||
class ClientSession:
|
||||
def __init__(self, engine):
|
||||
self.__engine = engine
|
||||
|
||||
def __enter__(self):
|
||||
self.__session = Session(self.__engine)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_tb):
|
||||
self.commit()
|
||||
self.__session.close()
|
||||
|
||||
def commit(self):
|
||||
self.__session.commit()
|
||||
|
||||
def expunge_all(self):
|
||||
self.__session.expunge_all()
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
|
||||
if column is not None:
|
||||
if values:
|
||||
if isinstance(values, Sequence):
|
||||
stmt = select(type).where(column.in_(values))
|
||||
else:
|
||||
stmt = select(type).where(column == values)
|
||||
else:
|
||||
stmt = select(type).where(column)
|
||||
else:
|
||||
stmt = select(type)
|
||||
|
||||
return self.__session.scalars(stmt).all()
|
||||
|
||||
def uncategorized(self) -> Sequence[Transaction]:
|
||||
"""Selects all valid uncategorized transactions
|
||||
At this moment that includes:
|
||||
- Categories w/o category
|
||||
- AND non-split categories
|
||||
|
||||
Returns:
|
||||
Sequence[Transaction]: transactions left uncategorized
|
||||
"""
|
||||
stmt = (
|
||||
select(Transaction)
|
||||
.where(~Transaction.category.has())
|
||||
.where(Transaction.split == false())
|
||||
)
|
||||
return self.__session.scalars(stmt).all()
|
||||
|
||||
def insert(self, rows: list):
|
||||
self.__session.add_all(rows)
|
||||
|
||||
def remove_by_name(self, type, rows: list):
|
||||
stmt = delete(type).where(type.name.in_([row.name for row in rows]))
|
||||
self.__session.execute(stmt)
|
||||
|
||||
def updategroup(self, categories: list[Category], group: CategoryGroup):
|
||||
stmt = (
|
||||
update(Category)
|
||||
.where(Category.name.in_([cat.name for cat in categories]))
|
||||
.values(group=group)
|
||||
)
|
||||
self.__session.execute(stmt)
|
||||
|
||||
def updateschedules(self, schedules: list[CategorySchedule]):
|
||||
stmt = insert(CategorySchedule).values([asdict(s) for s in schedules])
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=[CategorySchedule.name],
|
||||
set_=dict(
|
||||
recurring=stmt.excluded.recurring,
|
||||
period=stmt.excluded.period,
|
||||
period_multiplier=stmt.excluded.period_multiplier,
|
||||
),
|
||||
)
|
||||
self.__session.execute(stmt)
|
||||
|
||||
def remove_by_id(self, type, ids: list[int]):
|
||||
stmt = delete(type).where(type.id.in_(ids))
|
||||
self.__session.execute(stmt)
|
||||
|
||||
def update(self, type, values: list[dict]):
|
||||
print(type, values)
|
||||
self.__session.execute(update(type), values)
|
||||
|
||||
def remove_links(self, original: int, links: list[int]):
|
||||
stmt = delete(Link).where(
|
||||
Link.original == original, Link.link.in_(link for link in links)
|
||||
)
|
||||
self.__session.execute(stmt)
|
||||
|
||||
def session(self) -> ClientSession:
|
||||
return self.ClientSession(self.engine)
|
||||
@ -96,3 +96,54 @@ class TestDatabase:
|
||||
names = [banks[0].name, banks[1].name]
|
||||
result = client.select(Bank, lambda: Bank.name.in_(names))
|
||||
assert result == [banks[0], banks[1]]
|
||||
|
||||
def test_update_bank_with_session(self, client: Client, banks: list[Bank]):
|
||||
with client.session as session:
|
||||
name = banks[0].name
|
||||
bank = session.select(Bank, lambda: Bank.name == name)[0]
|
||||
bank.name = "anotherbank"
|
||||
|
||||
result = client.select(Bank, lambda: Bank.name == "anotherbank")
|
||||
assert len(result) == 1
|
||||
|
||||
def test_update_bank(self, client: Client, banks: list[Bank]):
|
||||
name = banks[0].name
|
||||
|
||||
result = client.select(Bank, lambda: Bank.name == name)
|
||||
assert result[0].type == AccountType.checking
|
||||
|
||||
update = {"name": name, "type": AccountType.savings}
|
||||
client.update(Bank, [update])
|
||||
|
||||
result = client.select(Bank, lambda: Bank.name == name)
|
||||
assert result[0].type == AccountType.savings
|
||||
|
||||
def test_update_nordigen(self, client: Client, banks: list[Bank]):
|
||||
name = banks[0].name
|
||||
|
||||
result = client.select(Nordigen, lambda: Nordigen.name == name)
|
||||
assert result[0].requisition_id == "req"
|
||||
|
||||
update = {"name": name, "requisition_id": "anotherreq"}
|
||||
client.update(Nordigen, [update])
|
||||
|
||||
result = client.select(Nordigen, lambda: Nordigen.name == name)
|
||||
assert result[0].requisition_id == "anotherreq"
|
||||
|
||||
result = client.select(Bank, lambda: Bank.name == name)
|
||||
assert getattr(result[0].nordigen, "requisition_id", None) == "anotherreq"
|
||||
|
||||
def test_remove_bank(self, client: Client, banks: list[Bank]):
|
||||
name = banks[0].name
|
||||
|
||||
result = client.select(Bank)
|
||||
assert len(result) == 3
|
||||
|
||||
client.delete(Bank, Bank.name, [name])
|
||||
result = client.select(Bank)
|
||||
assert len(result) == 2
|
||||
|
||||
names = [banks[1].name, banks[2].name]
|
||||
client.delete(Bank, Bank.name, names)
|
||||
result = client.select(Bank)
|
||||
assert len(result) == 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user