[Refactor] Clean out old DB client class

Swap almost all remaining calls to the old postgresql only DB class with
the new DB client.

Warning! Some operations are currently not implement, such as setting
category schedules and dismantling links.

`update` and `delete` methods added to DB `Client`.
This commit is contained in:
Luís Murta 2023-04-30 00:38:15 +01:00
parent da44ba5306
commit 13c783ca0e
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
5 changed files with 101 additions and 195 deletions

View File

@ -130,12 +130,12 @@ if __name__ == "__main__":
keys = {"category", "group"} keys = {"category", "group"}
assert args.keys() >= keys, f"missing {args.keys() - keys}" assert args.keys() >= keys, f"missing {args.keys() - keys}"
params = [type.Category(cat) for cat in args["category"]] params = [{"name": cat, "group": args["group"]} for cat in args["category"]]
params.append(args["group"])
case Operation.CategoryRemove: case Operation.CategoryRemove:
assert "category" in args, "argparser ill defined" assert "category" in args, "argparser ill defined"
params = [type.Category(cat) for cat in args["category"]]
params = args["category"]
case Operation.CategorySchedule: case Operation.CategorySchedule:
keys = {"category", "period", "frequency"} keys = {"category", "period", "frequency"}
@ -246,7 +246,7 @@ if __name__ == "__main__":
case Operation.GroupRemove: case Operation.GroupRemove:
assert "group" in args, "argparser ill defined" assert "group" in args, "argparser ill defined"
params = [type.CategoryGroup(group) for group in args["group"]] params = args["group"]
case Operation.Forge | Operation.Dismantle: case Operation.Forge | Operation.Dismantle:
keys = {"original", "links"} keys = {"original", "links"}

View File

@ -26,7 +26,6 @@ from pfbudget.db.model import (
Transaction, Transaction,
TransactionCategory, TransactionCategory,
) )
from pfbudget.db.postgresql import DbClient
from pfbudget.extract.nordigen import NordigenClient, NordigenCredentialsManager from pfbudget.extract.nordigen import NordigenClient, NordigenCredentialsManager
from pfbudget.extract.parsers import parse_data from pfbudget.extract.parsers import parse_data
from pfbudget.extract.psd2 import PSD2Extractor from pfbudget.extract.psd2 import PSD2Extractor
@ -111,20 +110,16 @@ class Manager:
Tagger(rules).transform_inplace(uncategorized) Tagger(rules).transform_inplace(uncategorized)
case Operation.BankMod: case Operation.BankMod:
with self.db.session() as session: self.database.update(Bank, params)
session.update(Bank, params)
case Operation.PSD2Mod: case Operation.PSD2Mod:
with self.db.session() as session: self.database.update(Nordigen, params)
session.update(Nordigen, params)
case Operation.BankDel: case Operation.BankDel:
with self.db.session() as session: self.database.delete(Bank, Bank.name, params)
session.remove_by_name(Bank, params)
case Operation.PSD2Del: case Operation.PSD2Del:
with self.db.session() as session: self.database.delete(Nordigen, Nordigen.name, params)
session.remove_by_name(Nordigen, params)
case Operation.Token: case Operation.Token:
Manager.nordigen_client().generate_token() Manager.nordigen_client().generate_token()
@ -150,40 +145,28 @@ class Manager:
self.database.insert(params) self.database.insert(params)
case Operation.CategoryUpdate: case Operation.CategoryUpdate:
with self.db.session() as session: self.database.update(Category, params)
session.updategroup(*params)
case Operation.CategoryRemove: case Operation.CategoryRemove:
with self.db.session() as session: self.database.delete(Category, Category.name, params)
session.remove_by_name(Category, params)
case Operation.CategorySchedule: case Operation.CategorySchedule:
with self.db.session() as session: raise NotImplementedError
session.updateschedules(params)
case Operation.RuleRemove: case Operation.RuleRemove:
assert all(isinstance(param, int) for param in params) self.database.delete(CategoryRule, CategoryRule.id, params)
with self.db.session() as session:
session.remove_by_id(CategoryRule, params)
case Operation.TagRemove: case Operation.TagRemove:
with self.db.session() as session: self.database.delete(Tag, Tag.name, params)
session.remove_by_name(Tag, params)
case Operation.TagRuleRemove: case Operation.TagRuleRemove:
assert all(isinstance(param, int) for param in params) self.database.delete(TagRule, TagRule.id, params)
with self.db.session() as session:
session.remove_by_id(TagRule, params)
case Operation.RuleModify | Operation.TagRuleModify: case Operation.RuleModify | Operation.TagRuleModify:
assert all(isinstance(param, dict) for param in params) self.database.update(Rule, params)
with self.db.session() as session:
session.update(Rule, params)
case Operation.GroupRemove: case Operation.GroupRemove:
assert all(isinstance(param, CategoryGroup) for param in params) self.database.delete(CategoryGroup, CategoryGroup.name, params)
with self.db.session() as session:
session.remove_by_name(CategoryGroup, params)
case Operation.Forge: case Operation.Forge:
if not ( if not (
@ -192,9 +175,14 @@ class Manager:
): ):
raise TypeError("f{params} are not transaction ids") raise TypeError("f{params} are not transaction ids")
with self.db.session() as session: with self.database.session as session:
original = session.get(Transaction, Transaction.id, params[0])[0] id = params[0]
links = session.get(Transaction, Transaction.id, params[1]) original = session.select(
Transaction, lambda: Transaction.id == id
)[0]
ids = params[1]
links = session.select(Transaction, lambda: Transaction.id.in_(ids))
if not original.category: if not original.category:
original.category = self.askcategory(original) original.category = self.askcategory(original)
@ -214,12 +202,7 @@ class Manager:
session.insert(tobelinked) session.insert(tobelinked)
case Operation.Dismantle: case Operation.Dismantle:
assert all(isinstance(param, Link) for param in params) raise NotImplementedError
with self.db.session() as session:
original = params[0].original
links = [link.link for link in params]
session.remove_links(original, links)
case Operation.Split: case Operation.Split:
if len(params) < 1 and not all( if len(params) < 1 and not all(
@ -234,8 +217,10 @@ class Manager:
f"{original.amount}€ != {sum(v for v, _ in params[1:])}" f"{original.amount}€ != {sum(v for v, _ in params[1:])}"
) )
with self.db.session() as session: with self.database.session as session:
originals = session.get(Transaction, Transaction.id, [original.id]) originals = session.select(
Transaction, lambda: Transaction.id == original.id
)
assert len(originals) == 1, ">1 transactions matched {original.id}!" assert len(originals) == 1, ">1 transactions matched {original.id}!"
originals[0].split = True originals[0].split = True
@ -293,8 +278,7 @@ class Manager:
transactions.append(transaction) transactions.append(transaction)
if self.certify(transactions): if self.certify(transactions):
with self.db.session() as session: self.database.insert(transactions)
session.insert(transactions)
case Operation.ExportBanks: case Operation.ExportBanks:
with self.database.session as session: with self.database.session as session:
@ -309,8 +293,7 @@ class Manager:
banks.append(bank) banks.append(bank)
if self.certify(banks): if self.certify(banks):
with self.db.session() as session: self.database.insert(banks)
session.insert(banks)
case Operation.ExportCategoryRules: case Operation.ExportCategoryRules:
with self.database.session as session: with self.database.session as session:
@ -324,8 +307,7 @@ class Manager:
rules = [CategoryRule(**row) for row in self.load(params[0], params[1])] rules = [CategoryRule(**row) for row in self.load(params[0], params[1])]
if self.certify(rules): if self.certify(rules):
with self.db.session() as session: self.database.insert(rules)
session.insert(rules)
case Operation.ExportTagRules: case Operation.ExportTagRules:
with self.database.session as session: with self.database.session as session:
@ -337,8 +319,7 @@ class Manager:
rules = [TagRule(**row) for row in self.load(params[0], params[1])] rules = [TagRule(**row) for row in self.load(params[0], params[1])]
if self.certify(rules): if self.certify(rules):
with self.db.session() as session: self.database.insert(rules)
session.insert(rules)
case Operation.ExportCategories: case Operation.ExportCategories:
with self.database.session as session: with self.database.session as session:
@ -363,8 +344,7 @@ class Manager:
categories.append(category) categories.append(category)
if self.certify(categories): if self.certify(categories):
with self.db.session() as session: self.database.insert(categories)
session.insert(categories)
case Operation.ExportCategoryGroups: case Operation.ExportCategoryGroups:
with self.database.session as session: with self.database.session as session:
@ -380,8 +360,7 @@ class Manager:
] ]
if self.certify(groups): if self.certify(groups):
with self.db.session() as session: self.database.insert(groups)
session.insert(groups)
def parse(self, filename: Path, args: dict): def parse(self, filename: Path, args: dict):
return parse_data(filename, args) return parse_data(filename, args)
@ -389,13 +368,12 @@ class Manager:
def askcategory(self, transaction: Transaction): def askcategory(self, transaction: Transaction):
selector = CategorySelector(Selector_T.manual) selector = CategorySelector(Selector_T.manual)
with self.db.session() as session: categories = self.database.select(Category)
categories = session.get(Category)
while True: while True:
category = input(f"{transaction}: ") category = input(f"{transaction}: ")
if category in [c.name for c in categories]: if category in [c.name for c in categories]:
return TransactionCategory(category, selector) return TransactionCategory(category, selector)
@staticmethod @staticmethod
def dump(fn, format, sequence): def dump(fn, format, sequence):
@ -428,20 +406,12 @@ class Manager:
return True return True
return False return False
@property
def db(self) -> DbClient:
return DbClient(self._db, self._verbosity > 2)
@property @property
def database(self) -> Client: def database(self) -> Client:
if not self._database: if not self._database:
self._database = Client(self._db, echo=self._verbosity > 2) self._database = Client(self._db, echo=self._verbosity > 2)
return self._database return self._database
@db.setter
def db(self, url: str):
self._db = url
@staticmethod @staticmethod
def nordigen_client() -> NordigenClient: def nordigen_client() -> NordigenClient:
return NordigenClient(NordigenCredentialsManager.default) return NordigenClient(NordigenCredentialsManager.default)

View File

@ -1,8 +1,8 @@
from collections.abc import Sequence from collections.abc import Sequence
from copy import deepcopy from copy import deepcopy
from sqlalchemy import Engine, create_engine, select from sqlalchemy import Engine, create_engine, delete, select, update
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from typing import Any, Optional, Type, TypeVar from typing import Any, Mapping, Optional, Type, TypeVar
# from pfbudget.db.exceptions import InsertError, SelectError # from pfbudget.db.exceptions import InsertError, SelectError
@ -52,6 +52,14 @@ class Client:
def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]: def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]:
return self.session.select(what, exists) return self.session.select(what, exists)
def update(self, what: Type[Any], values: Sequence[Mapping[str, Any]]) -> None:
with self._sessionmaker() as session, session.begin():
session.execute(update(what), values)
def delete(self, what: Type[Any], column: Any, values: Sequence[str]) -> None:
with self._sessionmaker() as session, session.begin():
session.execute(delete(what).where(column.in_(values)))
@property @property
def engine(self) -> Engine: def engine(self) -> Engine:
return self._engine return self._engine

View File

@ -1,123 +0,0 @@
from dataclasses import asdict
from sqlalchemy import create_engine, delete, select, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import false
from typing import Sequence, Type, TypeVar
from pfbudget.db.model import (
Category,
CategoryGroup,
CategorySchedule,
Link,
Transaction,
)
class DbClient:
"""
General database client using sqlalchemy
"""
__sessions: list[Session]
def __init__(self, url: str, echo=False) -> None:
self._engine = create_engine(url, echo=echo)
@property
def engine(self):
return self._engine
class ClientSession:
def __init__(self, engine):
self.__engine = engine
def __enter__(self):
self.__session = Session(self.__engine)
return self
def __exit__(self, exc_type, exc_value, exc_tb):
self.commit()
self.__session.close()
def commit(self):
self.__session.commit()
def expunge_all(self):
self.__session.expunge_all()
T = TypeVar("T")
def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
if column is not None:
if values:
if isinstance(values, Sequence):
stmt = select(type).where(column.in_(values))
else:
stmt = select(type).where(column == values)
else:
stmt = select(type).where(column)
else:
stmt = select(type)
return self.__session.scalars(stmt).all()
def uncategorized(self) -> Sequence[Transaction]:
"""Selects all valid uncategorized transactions
At this moment that includes:
- Categories w/o category
- AND non-split categories
Returns:
Sequence[Transaction]: transactions left uncategorized
"""
stmt = (
select(Transaction)
.where(~Transaction.category.has())
.where(Transaction.split == false())
)
return self.__session.scalars(stmt).all()
def insert(self, rows: list):
self.__session.add_all(rows)
def remove_by_name(self, type, rows: list):
stmt = delete(type).where(type.name.in_([row.name for row in rows]))
self.__session.execute(stmt)
def updategroup(self, categories: list[Category], group: CategoryGroup):
stmt = (
update(Category)
.where(Category.name.in_([cat.name for cat in categories]))
.values(group=group)
)
self.__session.execute(stmt)
def updateschedules(self, schedules: list[CategorySchedule]):
stmt = insert(CategorySchedule).values([asdict(s) for s in schedules])
stmt = stmt.on_conflict_do_update(
index_elements=[CategorySchedule.name],
set_=dict(
recurring=stmt.excluded.recurring,
period=stmt.excluded.period,
period_multiplier=stmt.excluded.period_multiplier,
),
)
self.__session.execute(stmt)
def remove_by_id(self, type, ids: list[int]):
stmt = delete(type).where(type.id.in_(ids))
self.__session.execute(stmt)
def update(self, type, values: list[dict]):
print(type, values)
self.__session.execute(update(type), values)
def remove_links(self, original: int, links: list[int]):
stmt = delete(Link).where(
Link.original == original, Link.link.in_(link for link in links)
)
self.__session.execute(stmt)
def session(self) -> ClientSession:
return self.ClientSession(self.engine)

View File

@ -96,3 +96,54 @@ class TestDatabase:
names = [banks[0].name, banks[1].name] names = [banks[0].name, banks[1].name]
result = client.select(Bank, lambda: Bank.name.in_(names)) result = client.select(Bank, lambda: Bank.name.in_(names))
assert result == [banks[0], banks[1]] assert result == [banks[0], banks[1]]
def test_update_bank_with_session(self, client: Client, banks: list[Bank]):
with client.session as session:
name = banks[0].name
bank = session.select(Bank, lambda: Bank.name == name)[0]
bank.name = "anotherbank"
result = client.select(Bank, lambda: Bank.name == "anotherbank")
assert len(result) == 1
def test_update_bank(self, client: Client, banks: list[Bank]):
name = banks[0].name
result = client.select(Bank, lambda: Bank.name == name)
assert result[0].type == AccountType.checking
update = {"name": name, "type": AccountType.savings}
client.update(Bank, [update])
result = client.select(Bank, lambda: Bank.name == name)
assert result[0].type == AccountType.savings
def test_update_nordigen(self, client: Client, banks: list[Bank]):
name = banks[0].name
result = client.select(Nordigen, lambda: Nordigen.name == name)
assert result[0].requisition_id == "req"
update = {"name": name, "requisition_id": "anotherreq"}
client.update(Nordigen, [update])
result = client.select(Nordigen, lambda: Nordigen.name == name)
assert result[0].requisition_id == "anotherreq"
result = client.select(Bank, lambda: Bank.name == name)
assert getattr(result[0].nordigen, "requisition_id", None) == "anotherreq"
def test_remove_bank(self, client: Client, banks: list[Bank]):
name = banks[0].name
result = client.select(Bank)
assert len(result) == 3
client.delete(Bank, Bank.name, [name])
result = client.select(Bank)
assert len(result) == 2
names = [banks[1].name, banks[2].name]
client.delete(Bank, Bank.name, names)
result = client.select(Bank)
assert len(result) == 0