Generalize insert in new Database interface

Move over all inserts on manager.py to new interface.
This commit is contained in:
Luís Murta 2023-04-29 01:21:52 +01:00
parent 78ff6faa12
commit 94322ae542
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
4 changed files with 98 additions and 54 deletions

View File

@ -53,7 +53,7 @@ if __name__ == "__main__":
if not args["all"]:
params.append(args["banks"])
else:
params.append([])
params.append(None)
case Operation.BankAdd:
keys = {"bank", "bic", "type"}

View File

@ -72,18 +72,16 @@ class Manager:
len(transactions) > 0
and input(f"{transactions[:5]}\nCommit? (y/n)") == "y"
):
with self.db.session() as session:
session.insert(sorted(transactions))
self.database.insert(sorted(transactions))
case Operation.Download:
client = Manager.nordigen_client()
with self.db.session() as session:
if len(params[3]) == 0:
banks = session.get(Bank, Bank.nordigen)
if params[3]:
values = params[3]
banks = self.database.select(Bank, lambda: Bank.name.in_(values))
else:
banks = session.get(Bank, Bank.name, params[3])
session.expunge_all()
banks = self.database.select(Bank, Bank.nordigen)
client = Manager.nordigen_client()
extractor = PSD2Extractor(client)
transactions = []
for bank in banks:
@ -91,18 +89,17 @@ class Manager:
# dry-run
if not params[2]:
with self.db.session() as session:
session.insert(sorted(transactions))
self.database.insert(sorted(transactions))
else:
print(sorted(transactions))
case Operation.Categorize:
with self.db.session() as session:
uncategorized = session.get(
BankTransaction, ~BankTransaction.category.has()
with self.database.session as session:
uncategorized = session.select(
BankTransaction, lambda: ~BankTransaction.category.has()
)
categories = session.get(Category)
tags = session.get(Tag)
categories = session.select(Category)
tags = session.select(Tag)
rules = [cat.rules for cat in categories if cat.name == "null"]
Nullifier(rules).transform_inplace(uncategorized)
@ -144,13 +141,13 @@ class Manager:
case (
Operation.BankAdd
| Operation.CategoryAdd
| Operation.GroupAdd
| Operation.PSD2Add
| Operation.RuleAdd
| Operation.TagAdd
| Operation.TagRuleAdd
):
with self.db.session() as session:
session.insert(params)
self.database.insert(params)
case Operation.CategoryUpdate:
with self.db.session() as session:
@ -183,10 +180,6 @@ class Manager:
with self.db.session() as session:
session.update(Rule, params)
case Operation.GroupAdd:
with self.db.session() as session:
session.insert(params)
case Operation.GroupRemove:
assert all(isinstance(param, CategoryGroup) for param in params)
with self.db.session() as session:
@ -436,7 +429,7 @@ class Manager:
return False
@property
def db(self) -> Client:
def db(self) -> DbClient:
return DbClient(self._db, self._verbosity > 2)
@property

View File

@ -1,10 +1,10 @@
from collections.abc import Sequence
from copy import deepcopy
from sqlalchemy import Engine, create_engine, select
from sqlalchemy.orm import Session, sessionmaker
from typing import Any, Optional, Sequence, Type, TypeVar
from typing import Any, Optional, Type, TypeVar
# from pfbudget.db.exceptions import InsertError, SelectError
from pfbudget.db.model import Transaction
class DatabaseSession:
@ -22,15 +22,14 @@ class DatabaseSession:
self.__session.commit()
self.__session.close()
def insert(self, transactions: Sequence[Transaction]) -> None:
self.__session.add_all(transactions)
def insert(self, sequence: Sequence[Any]) -> None:
self.__session.add_all(sequence)
T = TypeVar("T")
C = TypeVar("C")
def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]:
def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]:
if exists:
stmt = select(what).where(exists)
stmt = select(what).filter(exists)
else:
stmt = select(what)
@ -41,17 +40,16 @@ class Client:
def __init__(self, url: str, **kwargs: Any):
assert url, "Database URL is empty!"
self._engine = create_engine(url, **kwargs)
self._sessionmaker: Optional[sessionmaker[Session]] = None
self._sessionmaker = sessionmaker(self._engine)
def insert(self, transactions: Sequence[Transaction]) -> None:
new = deepcopy(transactions)
def insert(self, sequence: Sequence[Any]) -> None:
new = deepcopy(sequence)
with self.session as session:
session.insert(new)
T = TypeVar("T")
C = TypeVar("C")
def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]:
def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]:
return self.session.select(what, exists)
@property
@ -60,7 +58,4 @@ class Client:
@property
def session(self) -> DatabaseSession:
if not self._sessionmaker:
self._sessionmaker = sessionmaker(self._engine)
return DatabaseSession(self._sessionmaker())

View File

@ -3,7 +3,16 @@ from decimal import Decimal
import pytest
from pfbudget.db.client import Client
from pfbudget.db.model import Base, Transaction
from pfbudget.db.model import (
AccountType,
Bank,
Base,
CategorySelector,
Nordigen,
Selector_T,
Transaction,
TransactionCategory,
)
@pytest.fixture
@ -14,29 +23,76 @@ def client() -> Client:
return client
@pytest.fixture
def banks(client: Client) -> list[Bank]:
banks = [
Bank("bank", "BANK", AccountType.checking),
Bank("broker", "BROKER", AccountType.investment),
Bank("creditcard", "CC", AccountType.MASTERCARD),
]
banks[0].nordigen = Nordigen("bank", None, "req", None)
client.insert(banks)
return banks
@pytest.fixture
def transactions(client: Client) -> list[Transaction]:
transactions = [
Transaction(date(2023, 1, 1), "", Decimal("-10")),
Transaction(date(2023, 1, 2), "", Decimal("-50")),
]
transactions[0].category = TransactionCategory(
"name", CategorySelector(Selector_T.algorithm)
)
client.insert(transactions)
for i, transaction in enumerate(transactions):
transaction.id = i + 1
transaction.split = False # default
transactions[0].category.id = 1
transactions[0].category.selector.id = 1
return transactions
class TestDatabase:
def test_initialization(self, client: Client):
pass
def test_insert_transactions(self, client: Client):
def test_insert_with_session(self, client: Client):
transactions = [
Transaction(date(2023, 1, 1), "", Decimal("-500")),
Transaction(date(2023, 1, 2), "", Decimal("500")),
Transaction(date(2023, 1, 1), "", Decimal("-10")),
Transaction(date(2023, 1, 2), "", Decimal("-50")),
]
with client.session as session:
session.insert(transactions)
assert session.select(Transaction) == transactions
def test_insert_transactions_independent_sessions(self, client: Client):
transactions = [
Transaction(date(2023, 1, 1), "", Decimal("-500")),
Transaction(date(2023, 1, 2), "", Decimal("500")),
]
client.insert(transactions)
def test_insert_transactions(self, client: Client, transactions: list[Transaction]):
result = client.select(Transaction)
for i, transaction in enumerate(result):
assert transactions[i].date == transaction.date
assert transactions[i].description == transaction.description
assert transactions[i].amount == transaction.amount
assert result == transactions
def test_select_transactions_without_category(
self, client: Client, transactions: list[Transaction]
):
result = client.select(Transaction, lambda: ~Transaction.category.has())
assert result == [transactions[1]]
def test_select_banks(self, client: Client, banks: list[Bank]):
result = client.select(Bank)
assert result == banks
def test_select_banks_with_nordigen(self, client: Client, banks: list[Bank]):
result = client.select(Bank, Bank.nordigen)
assert result == [banks[0]]
def test_select_banks_by_name(self, client: Client, banks: list[Bank]):
name = banks[0].name
result = client.select(Bank, lambda: Bank.name == name)
assert result == [banks[0]]
names = [banks[0].name, banks[1].name]
result = client.select(Bank, lambda: Bank.name.in_(names))
assert result == [banks[0], banks[1]]