diff --git a/pfbudget/__main__.py b/pfbudget/__main__.py index 1f788c5..c7718ca 100644 --- a/pfbudget/__main__.py +++ b/pfbudget/__main__.py @@ -130,12 +130,12 @@ if __name__ == "__main__": keys = {"category", "group"} assert args.keys() >= keys, f"missing {args.keys() - keys}" - params = [type.Category(cat) for cat in args["category"]] - params.append(args["group"]) + params = [{"name": cat, "group": args["group"]} for cat in args["category"]] case Operation.CategoryRemove: assert "category" in args, "argparser ill defined" - params = [type.Category(cat) for cat in args["category"]] + + params = args["category"] case Operation.CategorySchedule: keys = {"category", "period", "frequency"} @@ -246,7 +246,7 @@ if __name__ == "__main__": case Operation.GroupRemove: assert "group" in args, "argparser ill defined" - params = [type.CategoryGroup(group) for group in args["group"]] + params = args["group"] case Operation.Forge | Operation.Dismantle: keys = {"original", "links"} diff --git a/pfbudget/core/manager.py b/pfbudget/core/manager.py index 3fac47f..51677d7 100644 --- a/pfbudget/core/manager.py +++ b/pfbudget/core/manager.py @@ -26,7 +26,6 @@ from pfbudget.db.model import ( Transaction, TransactionCategory, ) -from pfbudget.db.postgresql import DbClient from pfbudget.extract.nordigen import NordigenClient, NordigenCredentialsManager from pfbudget.extract.parsers import parse_data from pfbudget.extract.psd2 import PSD2Extractor @@ -111,20 +110,16 @@ class Manager: Tagger(rules).transform_inplace(uncategorized) case Operation.BankMod: - with self.db.session() as session: - session.update(Bank, params) + self.database.update(Bank, params) case Operation.PSD2Mod: - with self.db.session() as session: - session.update(Nordigen, params) + self.database.update(Nordigen, params) case Operation.BankDel: - with self.db.session() as session: - session.remove_by_name(Bank, params) + self.database.delete(Bank, Bank.name, params) case Operation.PSD2Del: - with self.db.session() as session: - session.remove_by_name(Nordigen, params) + self.database.delete(Nordigen, Nordigen.name, params) case Operation.Token: Manager.nordigen_client().generate_token() @@ -150,40 +145,28 @@ class Manager: self.database.insert(params) case Operation.CategoryUpdate: - with self.db.session() as session: - session.updategroup(*params) + self.database.update(Category, params) case Operation.CategoryRemove: - with self.db.session() as session: - session.remove_by_name(Category, params) + self.database.delete(Category, Category.name, params) case Operation.CategorySchedule: - with self.db.session() as session: - session.updateschedules(params) + raise NotImplementedError case Operation.RuleRemove: - assert all(isinstance(param, int) for param in params) - with self.db.session() as session: - session.remove_by_id(CategoryRule, params) + self.database.delete(CategoryRule, CategoryRule.id, params) case Operation.TagRemove: - with self.db.session() as session: - session.remove_by_name(Tag, params) + self.database.delete(Tag, Tag.name, params) case Operation.TagRuleRemove: - assert all(isinstance(param, int) for param in params) - with self.db.session() as session: - session.remove_by_id(TagRule, params) + self.database.delete(TagRule, TagRule.id, params) case Operation.RuleModify | Operation.TagRuleModify: - assert all(isinstance(param, dict) for param in params) - with self.db.session() as session: - session.update(Rule, params) + self.database.update(Rule, params) case Operation.GroupRemove: - assert all(isinstance(param, CategoryGroup) for param in params) - with self.db.session() as session: - session.remove_by_name(CategoryGroup, params) + self.database.delete(CategoryGroup, CategoryGroup.name, params) case Operation.Forge: if not ( @@ -192,9 +175,14 @@ class Manager: ): raise TypeError("f{params} are not transaction ids") - with self.db.session() as session: - original = session.get(Transaction, Transaction.id, params[0])[0] - links = session.get(Transaction, Transaction.id, params[1]) + with self.database.session as session: + id = params[0] + original = session.select( + Transaction, lambda: Transaction.id == id + )[0] + + ids = params[1] + links = session.select(Transaction, lambda: Transaction.id.in_(ids)) if not original.category: original.category = self.askcategory(original) @@ -214,12 +202,7 @@ class Manager: session.insert(tobelinked) case Operation.Dismantle: - assert all(isinstance(param, Link) for param in params) - - with self.db.session() as session: - original = params[0].original - links = [link.link for link in params] - session.remove_links(original, links) + raise NotImplementedError case Operation.Split: if len(params) < 1 and not all( @@ -234,8 +217,10 @@ class Manager: f"{original.amount}€ != {sum(v for v, _ in params[1:])}€" ) - with self.db.session() as session: - originals = session.get(Transaction, Transaction.id, [original.id]) + with self.database.session as session: + originals = session.select( + Transaction, lambda: Transaction.id == original.id + ) assert len(originals) == 1, ">1 transactions matched {original.id}!" originals[0].split = True @@ -293,8 +278,7 @@ class Manager: transactions.append(transaction) if self.certify(transactions): - with self.db.session() as session: - session.insert(transactions) + self.database.insert(transactions) case Operation.ExportBanks: with self.database.session as session: @@ -309,8 +293,7 @@ class Manager: banks.append(bank) if self.certify(banks): - with self.db.session() as session: - session.insert(banks) + self.database.insert(banks) case Operation.ExportCategoryRules: with self.database.session as session: @@ -324,8 +307,7 @@ class Manager: rules = [CategoryRule(**row) for row in self.load(params[0], params[1])] if self.certify(rules): - with self.db.session() as session: - session.insert(rules) + self.database.insert(rules) case Operation.ExportTagRules: with self.database.session as session: @@ -337,8 +319,7 @@ class Manager: rules = [TagRule(**row) for row in self.load(params[0], params[1])] if self.certify(rules): - with self.db.session() as session: - session.insert(rules) + self.database.insert(rules) case Operation.ExportCategories: with self.database.session as session: @@ -363,8 +344,7 @@ class Manager: categories.append(category) if self.certify(categories): - with self.db.session() as session: - session.insert(categories) + self.database.insert(categories) case Operation.ExportCategoryGroups: with self.database.session as session: @@ -380,8 +360,7 @@ class Manager: ] if self.certify(groups): - with self.db.session() as session: - session.insert(groups) + self.database.insert(groups) def parse(self, filename: Path, args: dict): return parse_data(filename, args) @@ -389,13 +368,12 @@ class Manager: def askcategory(self, transaction: Transaction): selector = CategorySelector(Selector_T.manual) - with self.db.session() as session: - categories = session.get(Category) + categories = self.database.select(Category) - while True: - category = input(f"{transaction}: ") - if category in [c.name for c in categories]: - return TransactionCategory(category, selector) + while True: + category = input(f"{transaction}: ") + if category in [c.name for c in categories]: + return TransactionCategory(category, selector) @staticmethod def dump(fn, format, sequence): @@ -428,20 +406,12 @@ class Manager: return True return False - @property - def db(self) -> DbClient: - return DbClient(self._db, self._verbosity > 2) - @property def database(self) -> Client: if not self._database: self._database = Client(self._db, echo=self._verbosity > 2) return self._database - @db.setter - def db(self, url: str): - self._db = url - @staticmethod def nordigen_client() -> NordigenClient: return NordigenClient(NordigenCredentialsManager.default) diff --git a/pfbudget/db/client.py b/pfbudget/db/client.py index 90339bd..e911e31 100644 --- a/pfbudget/db/client.py +++ b/pfbudget/db/client.py @@ -1,8 +1,8 @@ from collections.abc import Sequence from copy import deepcopy -from sqlalchemy import Engine, create_engine, select +from sqlalchemy import Engine, create_engine, delete, select, update from sqlalchemy.orm import Session, sessionmaker -from typing import Any, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Type, TypeVar # from pfbudget.db.exceptions import InsertError, SelectError @@ -52,6 +52,14 @@ class Client: def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]: return self.session.select(what, exists) + def update(self, what: Type[Any], values: Sequence[Mapping[str, Any]]) -> None: + with self._sessionmaker() as session, session.begin(): + session.execute(update(what), values) + + def delete(self, what: Type[Any], column: Any, values: Sequence[str]) -> None: + with self._sessionmaker() as session, session.begin(): + session.execute(delete(what).where(column.in_(values))) + @property def engine(self) -> Engine: return self._engine diff --git a/pfbudget/db/postgresql.py b/pfbudget/db/postgresql.py deleted file mode 100644 index 892c4e2..0000000 --- a/pfbudget/db/postgresql.py +++ /dev/null @@ -1,123 +0,0 @@ -from dataclasses import asdict -from sqlalchemy import create_engine, delete, select, update -from sqlalchemy.dialects.postgresql import insert -from sqlalchemy.orm import Session -from sqlalchemy.sql.expression import false -from typing import Sequence, Type, TypeVar - -from pfbudget.db.model import ( - Category, - CategoryGroup, - CategorySchedule, - Link, - Transaction, -) - - -class DbClient: - """ - General database client using sqlalchemy - """ - - __sessions: list[Session] - - def __init__(self, url: str, echo=False) -> None: - self._engine = create_engine(url, echo=echo) - - @property - def engine(self): - return self._engine - - class ClientSession: - def __init__(self, engine): - self.__engine = engine - - def __enter__(self): - self.__session = Session(self.__engine) - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - self.commit() - self.__session.close() - - def commit(self): - self.__session.commit() - - def expunge_all(self): - self.__session.expunge_all() - - T = TypeVar("T") - - def get(self, type: Type[T], column=None, values=None) -> Sequence[T]: - if column is not None: - if values: - if isinstance(values, Sequence): - stmt = select(type).where(column.in_(values)) - else: - stmt = select(type).where(column == values) - else: - stmt = select(type).where(column) - else: - stmt = select(type) - - return self.__session.scalars(stmt).all() - - def uncategorized(self) -> Sequence[Transaction]: - """Selects all valid uncategorized transactions - At this moment that includes: - - Categories w/o category - - AND non-split categories - - Returns: - Sequence[Transaction]: transactions left uncategorized - """ - stmt = ( - select(Transaction) - .where(~Transaction.category.has()) - .where(Transaction.split == false()) - ) - return self.__session.scalars(stmt).all() - - def insert(self, rows: list): - self.__session.add_all(rows) - - def remove_by_name(self, type, rows: list): - stmt = delete(type).where(type.name.in_([row.name for row in rows])) - self.__session.execute(stmt) - - def updategroup(self, categories: list[Category], group: CategoryGroup): - stmt = ( - update(Category) - .where(Category.name.in_([cat.name for cat in categories])) - .values(group=group) - ) - self.__session.execute(stmt) - - def updateschedules(self, schedules: list[CategorySchedule]): - stmt = insert(CategorySchedule).values([asdict(s) for s in schedules]) - stmt = stmt.on_conflict_do_update( - index_elements=[CategorySchedule.name], - set_=dict( - recurring=stmt.excluded.recurring, - period=stmt.excluded.period, - period_multiplier=stmt.excluded.period_multiplier, - ), - ) - self.__session.execute(stmt) - - def remove_by_id(self, type, ids: list[int]): - stmt = delete(type).where(type.id.in_(ids)) - self.__session.execute(stmt) - - def update(self, type, values: list[dict]): - print(type, values) - self.__session.execute(update(type), values) - - def remove_links(self, original: int, links: list[int]): - stmt = delete(Link).where( - Link.original == original, Link.link.in_(link for link in links) - ) - self.__session.execute(stmt) - - def session(self) -> ClientSession: - return self.ClientSession(self.engine) diff --git a/tests/test_database.py b/tests/test_database.py index 8fc8f9e..e6c5ca1 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -96,3 +96,54 @@ class TestDatabase: names = [banks[0].name, banks[1].name] result = client.select(Bank, lambda: Bank.name.in_(names)) assert result == [banks[0], banks[1]] + + def test_update_bank_with_session(self, client: Client, banks: list[Bank]): + with client.session as session: + name = banks[0].name + bank = session.select(Bank, lambda: Bank.name == name)[0] + bank.name = "anotherbank" + + result = client.select(Bank, lambda: Bank.name == "anotherbank") + assert len(result) == 1 + + def test_update_bank(self, client: Client, banks: list[Bank]): + name = banks[0].name + + result = client.select(Bank, lambda: Bank.name == name) + assert result[0].type == AccountType.checking + + update = {"name": name, "type": AccountType.savings} + client.update(Bank, [update]) + + result = client.select(Bank, lambda: Bank.name == name) + assert result[0].type == AccountType.savings + + def test_update_nordigen(self, client: Client, banks: list[Bank]): + name = banks[0].name + + result = client.select(Nordigen, lambda: Nordigen.name == name) + assert result[0].requisition_id == "req" + + update = {"name": name, "requisition_id": "anotherreq"} + client.update(Nordigen, [update]) + + result = client.select(Nordigen, lambda: Nordigen.name == name) + assert result[0].requisition_id == "anotherreq" + + result = client.select(Bank, lambda: Bank.name == name) + assert getattr(result[0].nordigen, "requisition_id", None) == "anotherreq" + + def test_remove_bank(self, client: Client, banks: list[Bank]): + name = banks[0].name + + result = client.select(Bank) + assert len(result) == 3 + + client.delete(Bank, Bank.name, [name]) + result = client.select(Bank) + assert len(result) == 2 + + names = [banks[1].name, banks[2].name] + client.delete(Bank, Bank.name, names) + result = client.select(Bank) + assert len(result) == 0