From 13c783ca0e8a1152679e1ac060d3ba93e67aa86d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Murta?= Date: Sun, 30 Apr 2023 00:38:15 +0100 Subject: [PATCH] [Refactor] Clean out old DB client class Swap almost all remaining calls to the old postgresql only DB class with the new DB client. Warning! Some operations are currently not implement, such as setting category schedules and dismantling links. `update` and `delete` methods added to DB `Client`. --- pfbudget/__main__.py | 8 +-- pfbudget/core/manager.py | 102 +++++++++++-------------------- pfbudget/db/client.py | 12 +++- pfbudget/db/postgresql.py | 123 -------------------------------------- tests/test_database.py | 51 ++++++++++++++++ 5 files changed, 101 insertions(+), 195 deletions(-) delete mode 100644 pfbudget/db/postgresql.py diff --git a/pfbudget/__main__.py b/pfbudget/__main__.py index 1f788c5..c7718ca 100644 --- a/pfbudget/__main__.py +++ b/pfbudget/__main__.py @@ -130,12 +130,12 @@ if __name__ == "__main__": keys = {"category", "group"} assert args.keys() >= keys, f"missing {args.keys() - keys}" - params = [type.Category(cat) for cat in args["category"]] - params.append(args["group"]) + params = [{"name": cat, "group": args["group"]} for cat in args["category"]] case Operation.CategoryRemove: assert "category" in args, "argparser ill defined" - params = [type.Category(cat) for cat in args["category"]] + + params = args["category"] case Operation.CategorySchedule: keys = {"category", "period", "frequency"} @@ -246,7 +246,7 @@ if __name__ == "__main__": case Operation.GroupRemove: assert "group" in args, "argparser ill defined" - params = [type.CategoryGroup(group) for group in args["group"]] + params = args["group"] case Operation.Forge | Operation.Dismantle: keys = {"original", "links"} diff --git a/pfbudget/core/manager.py b/pfbudget/core/manager.py index 3fac47f..51677d7 100644 --- a/pfbudget/core/manager.py +++ b/pfbudget/core/manager.py @@ -26,7 +26,6 @@ from pfbudget.db.model import ( Transaction, TransactionCategory, ) -from pfbudget.db.postgresql import DbClient from pfbudget.extract.nordigen import NordigenClient, NordigenCredentialsManager from pfbudget.extract.parsers import parse_data from pfbudget.extract.psd2 import PSD2Extractor @@ -111,20 +110,16 @@ class Manager: Tagger(rules).transform_inplace(uncategorized) case Operation.BankMod: - with self.db.session() as session: - session.update(Bank, params) + self.database.update(Bank, params) case Operation.PSD2Mod: - with self.db.session() as session: - session.update(Nordigen, params) + self.database.update(Nordigen, params) case Operation.BankDel: - with self.db.session() as session: - session.remove_by_name(Bank, params) + self.database.delete(Bank, Bank.name, params) case Operation.PSD2Del: - with self.db.session() as session: - session.remove_by_name(Nordigen, params) + self.database.delete(Nordigen, Nordigen.name, params) case Operation.Token: Manager.nordigen_client().generate_token() @@ -150,40 +145,28 @@ class Manager: self.database.insert(params) case Operation.CategoryUpdate: - with self.db.session() as session: - session.updategroup(*params) + self.database.update(Category, params) case Operation.CategoryRemove: - with self.db.session() as session: - session.remove_by_name(Category, params) + self.database.delete(Category, Category.name, params) case Operation.CategorySchedule: - with self.db.session() as session: - session.updateschedules(params) + raise NotImplementedError case Operation.RuleRemove: - assert all(isinstance(param, int) for param in params) - with self.db.session() as session: - session.remove_by_id(CategoryRule, params) + self.database.delete(CategoryRule, CategoryRule.id, params) case Operation.TagRemove: - with self.db.session() as session: - session.remove_by_name(Tag, params) + self.database.delete(Tag, Tag.name, params) case Operation.TagRuleRemove: - assert all(isinstance(param, int) for param in params) - with self.db.session() as session: - session.remove_by_id(TagRule, params) + self.database.delete(TagRule, TagRule.id, params) case Operation.RuleModify | Operation.TagRuleModify: - assert all(isinstance(param, dict) for param in params) - with self.db.session() as session: - session.update(Rule, params) + self.database.update(Rule, params) case Operation.GroupRemove: - assert all(isinstance(param, CategoryGroup) for param in params) - with self.db.session() as session: - session.remove_by_name(CategoryGroup, params) + self.database.delete(CategoryGroup, CategoryGroup.name, params) case Operation.Forge: if not ( @@ -192,9 +175,14 @@ class Manager: ): raise TypeError("f{params} are not transaction ids") - with self.db.session() as session: - original = session.get(Transaction, Transaction.id, params[0])[0] - links = session.get(Transaction, Transaction.id, params[1]) + with self.database.session as session: + id = params[0] + original = session.select( + Transaction, lambda: Transaction.id == id + )[0] + + ids = params[1] + links = session.select(Transaction, lambda: Transaction.id.in_(ids)) if not original.category: original.category = self.askcategory(original) @@ -214,12 +202,7 @@ class Manager: session.insert(tobelinked) case Operation.Dismantle: - assert all(isinstance(param, Link) for param in params) - - with self.db.session() as session: - original = params[0].original - links = [link.link for link in params] - session.remove_links(original, links) + raise NotImplementedError case Operation.Split: if len(params) < 1 and not all( @@ -234,8 +217,10 @@ class Manager: f"{original.amount}€ != {sum(v for v, _ in params[1:])}€" ) - with self.db.session() as session: - originals = session.get(Transaction, Transaction.id, [original.id]) + with self.database.session as session: + originals = session.select( + Transaction, lambda: Transaction.id == original.id + ) assert len(originals) == 1, ">1 transactions matched {original.id}!" originals[0].split = True @@ -293,8 +278,7 @@ class Manager: transactions.append(transaction) if self.certify(transactions): - with self.db.session() as session: - session.insert(transactions) + self.database.insert(transactions) case Operation.ExportBanks: with self.database.session as session: @@ -309,8 +293,7 @@ class Manager: banks.append(bank) if self.certify(banks): - with self.db.session() as session: - session.insert(banks) + self.database.insert(banks) case Operation.ExportCategoryRules: with self.database.session as session: @@ -324,8 +307,7 @@ class Manager: rules = [CategoryRule(**row) for row in self.load(params[0], params[1])] if self.certify(rules): - with self.db.session() as session: - session.insert(rules) + self.database.insert(rules) case Operation.ExportTagRules: with self.database.session as session: @@ -337,8 +319,7 @@ class Manager: rules = [TagRule(**row) for row in self.load(params[0], params[1])] if self.certify(rules): - with self.db.session() as session: - session.insert(rules) + self.database.insert(rules) case Operation.ExportCategories: with self.database.session as session: @@ -363,8 +344,7 @@ class Manager: categories.append(category) if self.certify(categories): - with self.db.session() as session: - session.insert(categories) + self.database.insert(categories) case Operation.ExportCategoryGroups: with self.database.session as session: @@ -380,8 +360,7 @@ class Manager: ] if self.certify(groups): - with self.db.session() as session: - session.insert(groups) + self.database.insert(groups) def parse(self, filename: Path, args: dict): return parse_data(filename, args) @@ -389,13 +368,12 @@ class Manager: def askcategory(self, transaction: Transaction): selector = CategorySelector(Selector_T.manual) - with self.db.session() as session: - categories = session.get(Category) + categories = self.database.select(Category) - while True: - category = input(f"{transaction}: ") - if category in [c.name for c in categories]: - return TransactionCategory(category, selector) + while True: + category = input(f"{transaction}: ") + if category in [c.name for c in categories]: + return TransactionCategory(category, selector) @staticmethod def dump(fn, format, sequence): @@ -428,20 +406,12 @@ class Manager: return True return False - @property - def db(self) -> DbClient: - return DbClient(self._db, self._verbosity > 2) - @property def database(self) -> Client: if not self._database: self._database = Client(self._db, echo=self._verbosity > 2) return self._database - @db.setter - def db(self, url: str): - self._db = url - @staticmethod def nordigen_client() -> NordigenClient: return NordigenClient(NordigenCredentialsManager.default) diff --git a/pfbudget/db/client.py b/pfbudget/db/client.py index 90339bd..e911e31 100644 --- a/pfbudget/db/client.py +++ b/pfbudget/db/client.py @@ -1,8 +1,8 @@ from collections.abc import Sequence from copy import deepcopy -from sqlalchemy import Engine, create_engine, select +from sqlalchemy import Engine, create_engine, delete, select, update from sqlalchemy.orm import Session, sessionmaker -from typing import Any, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Type, TypeVar # from pfbudget.db.exceptions import InsertError, SelectError @@ -52,6 +52,14 @@ class Client: def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]: return self.session.select(what, exists) + def update(self, what: Type[Any], values: Sequence[Mapping[str, Any]]) -> None: + with self._sessionmaker() as session, session.begin(): + session.execute(update(what), values) + + def delete(self, what: Type[Any], column: Any, values: Sequence[str]) -> None: + with self._sessionmaker() as session, session.begin(): + session.execute(delete(what).where(column.in_(values))) + @property def engine(self) -> Engine: return self._engine diff --git a/pfbudget/db/postgresql.py b/pfbudget/db/postgresql.py deleted file mode 100644 index 892c4e2..0000000 --- a/pfbudget/db/postgresql.py +++ /dev/null @@ -1,123 +0,0 @@ -from dataclasses import asdict -from sqlalchemy import create_engine, delete, select, update -from sqlalchemy.dialects.postgresql import insert -from sqlalchemy.orm import Session -from sqlalchemy.sql.expression import false -from typing import Sequence, Type, TypeVar - -from pfbudget.db.model import ( - Category, - CategoryGroup, - CategorySchedule, - Link, - Transaction, -) - - -class DbClient: - """ - General database client using sqlalchemy - """ - - __sessions: list[Session] - - def __init__(self, url: str, echo=False) -> None: - self._engine = create_engine(url, echo=echo) - - @property - def engine(self): - return self._engine - - class ClientSession: - def __init__(self, engine): - self.__engine = engine - - def __enter__(self): - self.__session = Session(self.__engine) - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - self.commit() - self.__session.close() - - def commit(self): - self.__session.commit() - - def expunge_all(self): - self.__session.expunge_all() - - T = TypeVar("T") - - def get(self, type: Type[T], column=None, values=None) -> Sequence[T]: - if column is not None: - if values: - if isinstance(values, Sequence): - stmt = select(type).where(column.in_(values)) - else: - stmt = select(type).where(column == values) - else: - stmt = select(type).where(column) - else: - stmt = select(type) - - return self.__session.scalars(stmt).all() - - def uncategorized(self) -> Sequence[Transaction]: - """Selects all valid uncategorized transactions - At this moment that includes: - - Categories w/o category - - AND non-split categories - - Returns: - Sequence[Transaction]: transactions left uncategorized - """ - stmt = ( - select(Transaction) - .where(~Transaction.category.has()) - .where(Transaction.split == false()) - ) - return self.__session.scalars(stmt).all() - - def insert(self, rows: list): - self.__session.add_all(rows) - - def remove_by_name(self, type, rows: list): - stmt = delete(type).where(type.name.in_([row.name for row in rows])) - self.__session.execute(stmt) - - def updategroup(self, categories: list[Category], group: CategoryGroup): - stmt = ( - update(Category) - .where(Category.name.in_([cat.name for cat in categories])) - .values(group=group) - ) - self.__session.execute(stmt) - - def updateschedules(self, schedules: list[CategorySchedule]): - stmt = insert(CategorySchedule).values([asdict(s) for s in schedules]) - stmt = stmt.on_conflict_do_update( - index_elements=[CategorySchedule.name], - set_=dict( - recurring=stmt.excluded.recurring, - period=stmt.excluded.period, - period_multiplier=stmt.excluded.period_multiplier, - ), - ) - self.__session.execute(stmt) - - def remove_by_id(self, type, ids: list[int]): - stmt = delete(type).where(type.id.in_(ids)) - self.__session.execute(stmt) - - def update(self, type, values: list[dict]): - print(type, values) - self.__session.execute(update(type), values) - - def remove_links(self, original: int, links: list[int]): - stmt = delete(Link).where( - Link.original == original, Link.link.in_(link for link in links) - ) - self.__session.execute(stmt) - - def session(self) -> ClientSession: - return self.ClientSession(self.engine) diff --git a/tests/test_database.py b/tests/test_database.py index 8fc8f9e..e6c5ca1 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -96,3 +96,54 @@ class TestDatabase: names = [banks[0].name, banks[1].name] result = client.select(Bank, lambda: Bank.name.in_(names)) assert result == [banks[0], banks[1]] + + def test_update_bank_with_session(self, client: Client, banks: list[Bank]): + with client.session as session: + name = banks[0].name + bank = session.select(Bank, lambda: Bank.name == name)[0] + bank.name = "anotherbank" + + result = client.select(Bank, lambda: Bank.name == "anotherbank") + assert len(result) == 1 + + def test_update_bank(self, client: Client, banks: list[Bank]): + name = banks[0].name + + result = client.select(Bank, lambda: Bank.name == name) + assert result[0].type == AccountType.checking + + update = {"name": name, "type": AccountType.savings} + client.update(Bank, [update]) + + result = client.select(Bank, lambda: Bank.name == name) + assert result[0].type == AccountType.savings + + def test_update_nordigen(self, client: Client, banks: list[Bank]): + name = banks[0].name + + result = client.select(Nordigen, lambda: Nordigen.name == name) + assert result[0].requisition_id == "req" + + update = {"name": name, "requisition_id": "anotherreq"} + client.update(Nordigen, [update]) + + result = client.select(Nordigen, lambda: Nordigen.name == name) + assert result[0].requisition_id == "anotherreq" + + result = client.select(Bank, lambda: Bank.name == name) + assert getattr(result[0].nordigen, "requisition_id", None) == "anotherreq" + + def test_remove_bank(self, client: Client, banks: list[Bank]): + name = banks[0].name + + result = client.select(Bank) + assert len(result) == 3 + + client.delete(Bank, Bank.name, [name]) + result = client.select(Bank) + assert len(result) == 2 + + names = [banks[1].name, banks[2].name] + client.delete(Bank, Bank.name, names) + result = client.select(Bank) + assert len(result) == 0