[Refactor] Clean out old DB client class

Swap almost all remaining calls to the old postgresql only DB class with
the new DB client.

Warning! Some operations are currently not implement, such as setting
category schedules and dismantling links.

`update` and `delete` methods added to DB `Client`.
This commit is contained in:
Luís Murta 2023-04-30 00:38:15 +01:00
parent da44ba5306
commit 13c783ca0e
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
5 changed files with 101 additions and 195 deletions

View File

@ -130,12 +130,12 @@ if __name__ == "__main__":
keys = {"category", "group"}
assert args.keys() >= keys, f"missing {args.keys() - keys}"
params = [type.Category(cat) for cat in args["category"]]
params.append(args["group"])
params = [{"name": cat, "group": args["group"]} for cat in args["category"]]
case Operation.CategoryRemove:
assert "category" in args, "argparser ill defined"
params = [type.Category(cat) for cat in args["category"]]
params = args["category"]
case Operation.CategorySchedule:
keys = {"category", "period", "frequency"}
@ -246,7 +246,7 @@ if __name__ == "__main__":
case Operation.GroupRemove:
assert "group" in args, "argparser ill defined"
params = [type.CategoryGroup(group) for group in args["group"]]
params = args["group"]
case Operation.Forge | Operation.Dismantle:
keys = {"original", "links"}

View File

@ -26,7 +26,6 @@ from pfbudget.db.model import (
Transaction,
TransactionCategory,
)
from pfbudget.db.postgresql import DbClient
from pfbudget.extract.nordigen import NordigenClient, NordigenCredentialsManager
from pfbudget.extract.parsers import parse_data
from pfbudget.extract.psd2 import PSD2Extractor
@ -111,20 +110,16 @@ class Manager:
Tagger(rules).transform_inplace(uncategorized)
case Operation.BankMod:
with self.db.session() as session:
session.update(Bank, params)
self.database.update(Bank, params)
case Operation.PSD2Mod:
with self.db.session() as session:
session.update(Nordigen, params)
self.database.update(Nordigen, params)
case Operation.BankDel:
with self.db.session() as session:
session.remove_by_name(Bank, params)
self.database.delete(Bank, Bank.name, params)
case Operation.PSD2Del:
with self.db.session() as session:
session.remove_by_name(Nordigen, params)
self.database.delete(Nordigen, Nordigen.name, params)
case Operation.Token:
Manager.nordigen_client().generate_token()
@ -150,40 +145,28 @@ class Manager:
self.database.insert(params)
case Operation.CategoryUpdate:
with self.db.session() as session:
session.updategroup(*params)
self.database.update(Category, params)
case Operation.CategoryRemove:
with self.db.session() as session:
session.remove_by_name(Category, params)
self.database.delete(Category, Category.name, params)
case Operation.CategorySchedule:
with self.db.session() as session:
session.updateschedules(params)
raise NotImplementedError
case Operation.RuleRemove:
assert all(isinstance(param, int) for param in params)
with self.db.session() as session:
session.remove_by_id(CategoryRule, params)
self.database.delete(CategoryRule, CategoryRule.id, params)
case Operation.TagRemove:
with self.db.session() as session:
session.remove_by_name(Tag, params)
self.database.delete(Tag, Tag.name, params)
case Operation.TagRuleRemove:
assert all(isinstance(param, int) for param in params)
with self.db.session() as session:
session.remove_by_id(TagRule, params)
self.database.delete(TagRule, TagRule.id, params)
case Operation.RuleModify | Operation.TagRuleModify:
assert all(isinstance(param, dict) for param in params)
with self.db.session() as session:
session.update(Rule, params)
self.database.update(Rule, params)
case Operation.GroupRemove:
assert all(isinstance(param, CategoryGroup) for param in params)
with self.db.session() as session:
session.remove_by_name(CategoryGroup, params)
self.database.delete(CategoryGroup, CategoryGroup.name, params)
case Operation.Forge:
if not (
@ -192,9 +175,14 @@ class Manager:
):
raise TypeError("f{params} are not transaction ids")
with self.db.session() as session:
original = session.get(Transaction, Transaction.id, params[0])[0]
links = session.get(Transaction, Transaction.id, params[1])
with self.database.session as session:
id = params[0]
original = session.select(
Transaction, lambda: Transaction.id == id
)[0]
ids = params[1]
links = session.select(Transaction, lambda: Transaction.id.in_(ids))
if not original.category:
original.category = self.askcategory(original)
@ -214,12 +202,7 @@ class Manager:
session.insert(tobelinked)
case Operation.Dismantle:
assert all(isinstance(param, Link) for param in params)
with self.db.session() as session:
original = params[0].original
links = [link.link for link in params]
session.remove_links(original, links)
raise NotImplementedError
case Operation.Split:
if len(params) < 1 and not all(
@ -234,8 +217,10 @@ class Manager:
f"{original.amount}€ != {sum(v for v, _ in params[1:])}"
)
with self.db.session() as session:
originals = session.get(Transaction, Transaction.id, [original.id])
with self.database.session as session:
originals = session.select(
Transaction, lambda: Transaction.id == original.id
)
assert len(originals) == 1, ">1 transactions matched {original.id}!"
originals[0].split = True
@ -293,8 +278,7 @@ class Manager:
transactions.append(transaction)
if self.certify(transactions):
with self.db.session() as session:
session.insert(transactions)
self.database.insert(transactions)
case Operation.ExportBanks:
with self.database.session as session:
@ -309,8 +293,7 @@ class Manager:
banks.append(bank)
if self.certify(banks):
with self.db.session() as session:
session.insert(banks)
self.database.insert(banks)
case Operation.ExportCategoryRules:
with self.database.session as session:
@ -324,8 +307,7 @@ class Manager:
rules = [CategoryRule(**row) for row in self.load(params[0], params[1])]
if self.certify(rules):
with self.db.session() as session:
session.insert(rules)
self.database.insert(rules)
case Operation.ExportTagRules:
with self.database.session as session:
@ -337,8 +319,7 @@ class Manager:
rules = [TagRule(**row) for row in self.load(params[0], params[1])]
if self.certify(rules):
with self.db.session() as session:
session.insert(rules)
self.database.insert(rules)
case Operation.ExportCategories:
with self.database.session as session:
@ -363,8 +344,7 @@ class Manager:
categories.append(category)
if self.certify(categories):
with self.db.session() as session:
session.insert(categories)
self.database.insert(categories)
case Operation.ExportCategoryGroups:
with self.database.session as session:
@ -380,8 +360,7 @@ class Manager:
]
if self.certify(groups):
with self.db.session() as session:
session.insert(groups)
self.database.insert(groups)
def parse(self, filename: Path, args: dict):
return parse_data(filename, args)
@ -389,8 +368,7 @@ class Manager:
def askcategory(self, transaction: Transaction):
selector = CategorySelector(Selector_T.manual)
with self.db.session() as session:
categories = session.get(Category)
categories = self.database.select(Category)
while True:
category = input(f"{transaction}: ")
@ -428,20 +406,12 @@ class Manager:
return True
return False
@property
def db(self) -> DbClient:
return DbClient(self._db, self._verbosity > 2)
@property
def database(self) -> Client:
if not self._database:
self._database = Client(self._db, echo=self._verbosity > 2)
return self._database
@db.setter
def db(self, url: str):
self._db = url
@staticmethod
def nordigen_client() -> NordigenClient:
return NordigenClient(NordigenCredentialsManager.default)

View File

@ -1,8 +1,8 @@
from collections.abc import Sequence
from copy import deepcopy
from sqlalchemy import Engine, create_engine, select
from sqlalchemy import Engine, create_engine, delete, select, update
from sqlalchemy.orm import Session, sessionmaker
from typing import Any, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Type, TypeVar
# from pfbudget.db.exceptions import InsertError, SelectError
@ -52,6 +52,14 @@ class Client:
def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]:
return self.session.select(what, exists)
def update(self, what: Type[Any], values: Sequence[Mapping[str, Any]]) -> None:
with self._sessionmaker() as session, session.begin():
session.execute(update(what), values)
def delete(self, what: Type[Any], column: Any, values: Sequence[str]) -> None:
with self._sessionmaker() as session, session.begin():
session.execute(delete(what).where(column.in_(values)))
@property
def engine(self) -> Engine:
return self._engine

View File

@ -1,123 +0,0 @@
from dataclasses import asdict
from sqlalchemy import create_engine, delete, select, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import false
from typing import Sequence, Type, TypeVar
from pfbudget.db.model import (
Category,
CategoryGroup,
CategorySchedule,
Link,
Transaction,
)
class DbClient:
"""
General database client using sqlalchemy
"""
__sessions: list[Session]
def __init__(self, url: str, echo=False) -> None:
self._engine = create_engine(url, echo=echo)
@property
def engine(self):
return self._engine
class ClientSession:
def __init__(self, engine):
self.__engine = engine
def __enter__(self):
self.__session = Session(self.__engine)
return self
def __exit__(self, exc_type, exc_value, exc_tb):
self.commit()
self.__session.close()
def commit(self):
self.__session.commit()
def expunge_all(self):
self.__session.expunge_all()
T = TypeVar("T")
def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
if column is not None:
if values:
if isinstance(values, Sequence):
stmt = select(type).where(column.in_(values))
else:
stmt = select(type).where(column == values)
else:
stmt = select(type).where(column)
else:
stmt = select(type)
return self.__session.scalars(stmt).all()
def uncategorized(self) -> Sequence[Transaction]:
"""Selects all valid uncategorized transactions
At this moment that includes:
- Categories w/o category
- AND non-split categories
Returns:
Sequence[Transaction]: transactions left uncategorized
"""
stmt = (
select(Transaction)
.where(~Transaction.category.has())
.where(Transaction.split == false())
)
return self.__session.scalars(stmt).all()
def insert(self, rows: list):
self.__session.add_all(rows)
def remove_by_name(self, type, rows: list):
stmt = delete(type).where(type.name.in_([row.name for row in rows]))
self.__session.execute(stmt)
def updategroup(self, categories: list[Category], group: CategoryGroup):
stmt = (
update(Category)
.where(Category.name.in_([cat.name for cat in categories]))
.values(group=group)
)
self.__session.execute(stmt)
def updateschedules(self, schedules: list[CategorySchedule]):
stmt = insert(CategorySchedule).values([asdict(s) for s in schedules])
stmt = stmt.on_conflict_do_update(
index_elements=[CategorySchedule.name],
set_=dict(
recurring=stmt.excluded.recurring,
period=stmt.excluded.period,
period_multiplier=stmt.excluded.period_multiplier,
),
)
self.__session.execute(stmt)
def remove_by_id(self, type, ids: list[int]):
stmt = delete(type).where(type.id.in_(ids))
self.__session.execute(stmt)
def update(self, type, values: list[dict]):
print(type, values)
self.__session.execute(update(type), values)
def remove_links(self, original: int, links: list[int]):
stmt = delete(Link).where(
Link.original == original, Link.link.in_(link for link in links)
)
self.__session.execute(stmt)
def session(self) -> ClientSession:
return self.ClientSession(self.engine)

View File

@ -96,3 +96,54 @@ class TestDatabase:
names = [banks[0].name, banks[1].name]
result = client.select(Bank, lambda: Bank.name.in_(names))
assert result == [banks[0], banks[1]]
def test_update_bank_with_session(self, client: Client, banks: list[Bank]):
with client.session as session:
name = banks[0].name
bank = session.select(Bank, lambda: Bank.name == name)[0]
bank.name = "anotherbank"
result = client.select(Bank, lambda: Bank.name == "anotherbank")
assert len(result) == 1
def test_update_bank(self, client: Client, banks: list[Bank]):
name = banks[0].name
result = client.select(Bank, lambda: Bank.name == name)
assert result[0].type == AccountType.checking
update = {"name": name, "type": AccountType.savings}
client.update(Bank, [update])
result = client.select(Bank, lambda: Bank.name == name)
assert result[0].type == AccountType.savings
def test_update_nordigen(self, client: Client, banks: list[Bank]):
name = banks[0].name
result = client.select(Nordigen, lambda: Nordigen.name == name)
assert result[0].requisition_id == "req"
update = {"name": name, "requisition_id": "anotherreq"}
client.update(Nordigen, [update])
result = client.select(Nordigen, lambda: Nordigen.name == name)
assert result[0].requisition_id == "anotherreq"
result = client.select(Bank, lambda: Bank.name == name)
assert getattr(result[0].nordigen, "requisition_id", None) == "anotherreq"
def test_remove_bank(self, client: Client, banks: list[Bank]):
name = banks[0].name
result = client.select(Bank)
assert len(result) == 3
client.delete(Bank, Bank.name, [name])
result = client.select(Bank)
assert len(result) == 2
names = [banks[1].name, banks[2].name]
client.delete(Bank, Bank.name, names)
result = client.select(Bank)
assert len(result) == 0