diff --git a/pfbudget/__main__.py b/pfbudget/__main__.py index 19e1f42..1f788c5 100644 --- a/pfbudget/__main__.py +++ b/pfbudget/__main__.py @@ -53,7 +53,7 @@ if __name__ == "__main__": if not args["all"]: params.append(args["banks"]) else: - params.append([]) + params.append(None) case Operation.BankAdd: keys = {"bank", "bic", "type"} diff --git a/pfbudget/core/manager.py b/pfbudget/core/manager.py index d1d0fa3..3fac47f 100644 --- a/pfbudget/core/manager.py +++ b/pfbudget/core/manager.py @@ -72,18 +72,16 @@ class Manager: len(transactions) > 0 and input(f"{transactions[:5]}\nCommit? (y/n)") == "y" ): - with self.db.session() as session: - session.insert(sorted(transactions)) + self.database.insert(sorted(transactions)) case Operation.Download: - client = Manager.nordigen_client() - with self.db.session() as session: - if len(params[3]) == 0: - banks = session.get(Bank, Bank.nordigen) - else: - banks = session.get(Bank, Bank.name, params[3]) - session.expunge_all() + if params[3]: + values = params[3] + banks = self.database.select(Bank, lambda: Bank.name.in_(values)) + else: + banks = self.database.select(Bank, Bank.nordigen) + client = Manager.nordigen_client() extractor = PSD2Extractor(client) transactions = [] for bank in banks: @@ -91,18 +89,17 @@ class Manager: # dry-run if not params[2]: - with self.db.session() as session: - session.insert(sorted(transactions)) + self.database.insert(sorted(transactions)) else: print(sorted(transactions)) case Operation.Categorize: - with self.db.session() as session: - uncategorized = session.get( - BankTransaction, ~BankTransaction.category.has() + with self.database.session as session: + uncategorized = session.select( + BankTransaction, lambda: ~BankTransaction.category.has() ) - categories = session.get(Category) - tags = session.get(Tag) + categories = session.select(Category) + tags = session.select(Tag) rules = [cat.rules for cat in categories if cat.name == "null"] Nullifier(rules).transform_inplace(uncategorized) @@ -144,13 +141,13 @@ class Manager: case ( Operation.BankAdd | Operation.CategoryAdd + | Operation.GroupAdd | Operation.PSD2Add | Operation.RuleAdd | Operation.TagAdd | Operation.TagRuleAdd ): - with self.db.session() as session: - session.insert(params) + self.database.insert(params) case Operation.CategoryUpdate: with self.db.session() as session: @@ -183,10 +180,6 @@ class Manager: with self.db.session() as session: session.update(Rule, params) - case Operation.GroupAdd: - with self.db.session() as session: - session.insert(params) - case Operation.GroupRemove: assert all(isinstance(param, CategoryGroup) for param in params) with self.db.session() as session: @@ -436,7 +429,7 @@ class Manager: return False @property - def db(self) -> Client: + def db(self) -> DbClient: return DbClient(self._db, self._verbosity > 2) @property diff --git a/pfbudget/db/client.py b/pfbudget/db/client.py index 6de48e0..90339bd 100644 --- a/pfbudget/db/client.py +++ b/pfbudget/db/client.py @@ -1,10 +1,10 @@ +from collections.abc import Sequence from copy import deepcopy from sqlalchemy import Engine, create_engine, select from sqlalchemy.orm import Session, sessionmaker -from typing import Any, Optional, Sequence, Type, TypeVar +from typing import Any, Optional, Type, TypeVar # from pfbudget.db.exceptions import InsertError, SelectError -from pfbudget.db.model import Transaction class DatabaseSession: @@ -22,15 +22,14 @@ class DatabaseSession: self.__session.commit() self.__session.close() - def insert(self, transactions: Sequence[Transaction]) -> None: - self.__session.add_all(transactions) + def insert(self, sequence: Sequence[Any]) -> None: + self.__session.add_all(sequence) T = TypeVar("T") - C = TypeVar("C") - def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]: + def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]: if exists: - stmt = select(what).where(exists) + stmt = select(what).filter(exists) else: stmt = select(what) @@ -41,17 +40,16 @@ class Client: def __init__(self, url: str, **kwargs: Any): assert url, "Database URL is empty!" self._engine = create_engine(url, **kwargs) - self._sessionmaker: Optional[sessionmaker[Session]] = None + self._sessionmaker = sessionmaker(self._engine) - def insert(self, transactions: Sequence[Transaction]) -> None: - new = deepcopy(transactions) + def insert(self, sequence: Sequence[Any]) -> None: + new = deepcopy(sequence) with self.session as session: session.insert(new) T = TypeVar("T") - C = TypeVar("C") - def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]: + def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]: return self.session.select(what, exists) @property @@ -60,7 +58,4 @@ class Client: @property def session(self) -> DatabaseSession: - if not self._sessionmaker: - self._sessionmaker = sessionmaker(self._engine) - return DatabaseSession(self._sessionmaker()) diff --git a/tests/test_database.py b/tests/test_database.py index 327b32b..8fc8f9e 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -3,7 +3,16 @@ from decimal import Decimal import pytest from pfbudget.db.client import Client -from pfbudget.db.model import Base, Transaction +from pfbudget.db.model import ( + AccountType, + Bank, + Base, + CategorySelector, + Nordigen, + Selector_T, + Transaction, + TransactionCategory, +) @pytest.fixture @@ -14,29 +23,76 @@ def client() -> Client: return client +@pytest.fixture +def banks(client: Client) -> list[Bank]: + banks = [ + Bank("bank", "BANK", AccountType.checking), + Bank("broker", "BROKER", AccountType.investment), + Bank("creditcard", "CC", AccountType.MASTERCARD), + ] + banks[0].nordigen = Nordigen("bank", None, "req", None) + + client.insert(banks) + return banks + + +@pytest.fixture +def transactions(client: Client) -> list[Transaction]: + transactions = [ + Transaction(date(2023, 1, 1), "", Decimal("-10")), + Transaction(date(2023, 1, 2), "", Decimal("-50")), + ] + transactions[0].category = TransactionCategory( + "name", CategorySelector(Selector_T.algorithm) + ) + + client.insert(transactions) + for i, transaction in enumerate(transactions): + transaction.id = i + 1 + transaction.split = False # default + transactions[0].category.id = 1 + transactions[0].category.selector.id = 1 + + return transactions + + class TestDatabase: def test_initialization(self, client: Client): pass - def test_insert_transactions(self, client: Client): + def test_insert_with_session(self, client: Client): transactions = [ - Transaction(date(2023, 1, 1), "", Decimal("-500")), - Transaction(date(2023, 1, 2), "", Decimal("500")), + Transaction(date(2023, 1, 1), "", Decimal("-10")), + Transaction(date(2023, 1, 2), "", Decimal("-50")), ] with client.session as session: session.insert(transactions) assert session.select(Transaction) == transactions - def test_insert_transactions_independent_sessions(self, client: Client): - transactions = [ - Transaction(date(2023, 1, 1), "", Decimal("-500")), - Transaction(date(2023, 1, 2), "", Decimal("500")), - ] - - client.insert(transactions) + def test_insert_transactions(self, client: Client, transactions: list[Transaction]): result = client.select(Transaction) - for i, transaction in enumerate(result): - assert transactions[i].date == transaction.date - assert transactions[i].description == transaction.description - assert transactions[i].amount == transaction.amount + assert result == transactions + + def test_select_transactions_without_category( + self, client: Client, transactions: list[Transaction] + ): + result = client.select(Transaction, lambda: ~Transaction.category.has()) + assert result == [transactions[1]] + + def test_select_banks(self, client: Client, banks: list[Bank]): + result = client.select(Bank) + assert result == banks + + def test_select_banks_with_nordigen(self, client: Client, banks: list[Bank]): + result = client.select(Bank, Bank.nordigen) + assert result == [banks[0]] + + def test_select_banks_by_name(self, client: Client, banks: list[Bank]): + name = banks[0].name + result = client.select(Bank, lambda: Bank.name == name) + assert result == [banks[0]] + + names = [banks[0].name, banks[1].name] + result = client.select(Bank, lambda: Bank.name.in_(names)) + assert result == [banks[0], banks[1]]