Generalize insert in new Database interface
Move over all inserts on manager.py to new interface.
This commit is contained in:
parent
78ff6faa12
commit
94322ae542
@ -53,7 +53,7 @@ if __name__ == "__main__":
|
||||
if not args["all"]:
|
||||
params.append(args["banks"])
|
||||
else:
|
||||
params.append([])
|
||||
params.append(None)
|
||||
|
||||
case Operation.BankAdd:
|
||||
keys = {"bank", "bic", "type"}
|
||||
|
||||
@ -72,18 +72,16 @@ class Manager:
|
||||
len(transactions) > 0
|
||||
and input(f"{transactions[:5]}\nCommit? (y/n)") == "y"
|
||||
):
|
||||
with self.db.session() as session:
|
||||
session.insert(sorted(transactions))
|
||||
self.database.insert(sorted(transactions))
|
||||
|
||||
case Operation.Download:
|
||||
client = Manager.nordigen_client()
|
||||
with self.db.session() as session:
|
||||
if len(params[3]) == 0:
|
||||
banks = session.get(Bank, Bank.nordigen)
|
||||
if params[3]:
|
||||
values = params[3]
|
||||
banks = self.database.select(Bank, lambda: Bank.name.in_(values))
|
||||
else:
|
||||
banks = session.get(Bank, Bank.name, params[3])
|
||||
session.expunge_all()
|
||||
banks = self.database.select(Bank, Bank.nordigen)
|
||||
|
||||
client = Manager.nordigen_client()
|
||||
extractor = PSD2Extractor(client)
|
||||
transactions = []
|
||||
for bank in banks:
|
||||
@ -91,18 +89,17 @@ class Manager:
|
||||
|
||||
# dry-run
|
||||
if not params[2]:
|
||||
with self.db.session() as session:
|
||||
session.insert(sorted(transactions))
|
||||
self.database.insert(sorted(transactions))
|
||||
else:
|
||||
print(sorted(transactions))
|
||||
|
||||
case Operation.Categorize:
|
||||
with self.db.session() as session:
|
||||
uncategorized = session.get(
|
||||
BankTransaction, ~BankTransaction.category.has()
|
||||
with self.database.session as session:
|
||||
uncategorized = session.select(
|
||||
BankTransaction, lambda: ~BankTransaction.category.has()
|
||||
)
|
||||
categories = session.get(Category)
|
||||
tags = session.get(Tag)
|
||||
categories = session.select(Category)
|
||||
tags = session.select(Tag)
|
||||
|
||||
rules = [cat.rules for cat in categories if cat.name == "null"]
|
||||
Nullifier(rules).transform_inplace(uncategorized)
|
||||
@ -144,13 +141,13 @@ class Manager:
|
||||
case (
|
||||
Operation.BankAdd
|
||||
| Operation.CategoryAdd
|
||||
| Operation.GroupAdd
|
||||
| Operation.PSD2Add
|
||||
| Operation.RuleAdd
|
||||
| Operation.TagAdd
|
||||
| Operation.TagRuleAdd
|
||||
):
|
||||
with self.db.session() as session:
|
||||
session.insert(params)
|
||||
self.database.insert(params)
|
||||
|
||||
case Operation.CategoryUpdate:
|
||||
with self.db.session() as session:
|
||||
@ -183,10 +180,6 @@ class Manager:
|
||||
with self.db.session() as session:
|
||||
session.update(Rule, params)
|
||||
|
||||
case Operation.GroupAdd:
|
||||
with self.db.session() as session:
|
||||
session.insert(params)
|
||||
|
||||
case Operation.GroupRemove:
|
||||
assert all(isinstance(param, CategoryGroup) for param in params)
|
||||
with self.db.session() as session:
|
||||
@ -436,7 +429,7 @@ class Manager:
|
||||
return False
|
||||
|
||||
@property
|
||||
def db(self) -> Client:
|
||||
def db(self) -> DbClient:
|
||||
return DbClient(self._db, self._verbosity > 2)
|
||||
|
||||
@property
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from sqlalchemy import Engine, create_engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from typing import Any, Optional, Sequence, Type, TypeVar
|
||||
from typing import Any, Optional, Type, TypeVar
|
||||
|
||||
# from pfbudget.db.exceptions import InsertError, SelectError
|
||||
from pfbudget.db.model import Transaction
|
||||
|
||||
|
||||
class DatabaseSession:
|
||||
@ -22,15 +22,14 @@ class DatabaseSession:
|
||||
self.__session.commit()
|
||||
self.__session.close()
|
||||
|
||||
def insert(self, transactions: Sequence[Transaction]) -> None:
|
||||
self.__session.add_all(transactions)
|
||||
def insert(self, sequence: Sequence[Any]) -> None:
|
||||
self.__session.add_all(sequence)
|
||||
|
||||
T = TypeVar("T")
|
||||
C = TypeVar("C")
|
||||
|
||||
def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]:
|
||||
def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]:
|
||||
if exists:
|
||||
stmt = select(what).where(exists)
|
||||
stmt = select(what).filter(exists)
|
||||
else:
|
||||
stmt = select(what)
|
||||
|
||||
@ -41,17 +40,16 @@ class Client:
|
||||
def __init__(self, url: str, **kwargs: Any):
|
||||
assert url, "Database URL is empty!"
|
||||
self._engine = create_engine(url, **kwargs)
|
||||
self._sessionmaker: Optional[sessionmaker[Session]] = None
|
||||
self._sessionmaker = sessionmaker(self._engine)
|
||||
|
||||
def insert(self, transactions: Sequence[Transaction]) -> None:
|
||||
new = deepcopy(transactions)
|
||||
def insert(self, sequence: Sequence[Any]) -> None:
|
||||
new = deepcopy(sequence)
|
||||
with self.session as session:
|
||||
session.insert(new)
|
||||
|
||||
T = TypeVar("T")
|
||||
C = TypeVar("C")
|
||||
|
||||
def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]:
|
||||
def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]:
|
||||
return self.session.select(what, exists)
|
||||
|
||||
@property
|
||||
@ -60,7 +58,4 @@ class Client:
|
||||
|
||||
@property
|
||||
def session(self) -> DatabaseSession:
|
||||
if not self._sessionmaker:
|
||||
self._sessionmaker = sessionmaker(self._engine)
|
||||
|
||||
return DatabaseSession(self._sessionmaker())
|
||||
|
||||
@ -3,7 +3,16 @@ from decimal import Decimal
|
||||
import pytest
|
||||
|
||||
from pfbudget.db.client import Client
|
||||
from pfbudget.db.model import Base, Transaction
|
||||
from pfbudget.db.model import (
|
||||
AccountType,
|
||||
Bank,
|
||||
Base,
|
||||
CategorySelector,
|
||||
Nordigen,
|
||||
Selector_T,
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -14,29 +23,76 @@ def client() -> Client:
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def banks(client: Client) -> list[Bank]:
|
||||
banks = [
|
||||
Bank("bank", "BANK", AccountType.checking),
|
||||
Bank("broker", "BROKER", AccountType.investment),
|
||||
Bank("creditcard", "CC", AccountType.MASTERCARD),
|
||||
]
|
||||
banks[0].nordigen = Nordigen("bank", None, "req", None)
|
||||
|
||||
client.insert(banks)
|
||||
return banks
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transactions(client: Client) -> list[Transaction]:
|
||||
transactions = [
|
||||
Transaction(date(2023, 1, 1), "", Decimal("-10")),
|
||||
Transaction(date(2023, 1, 2), "", Decimal("-50")),
|
||||
]
|
||||
transactions[0].category = TransactionCategory(
|
||||
"name", CategorySelector(Selector_T.algorithm)
|
||||
)
|
||||
|
||||
client.insert(transactions)
|
||||
for i, transaction in enumerate(transactions):
|
||||
transaction.id = i + 1
|
||||
transaction.split = False # default
|
||||
transactions[0].category.id = 1
|
||||
transactions[0].category.selector.id = 1
|
||||
|
||||
return transactions
|
||||
|
||||
|
||||
class TestDatabase:
|
||||
def test_initialization(self, client: Client):
|
||||
pass
|
||||
|
||||
def test_insert_transactions(self, client: Client):
|
||||
def test_insert_with_session(self, client: Client):
|
||||
transactions = [
|
||||
Transaction(date(2023, 1, 1), "", Decimal("-500")),
|
||||
Transaction(date(2023, 1, 2), "", Decimal("500")),
|
||||
Transaction(date(2023, 1, 1), "", Decimal("-10")),
|
||||
Transaction(date(2023, 1, 2), "", Decimal("-50")),
|
||||
]
|
||||
|
||||
with client.session as session:
|
||||
session.insert(transactions)
|
||||
assert session.select(Transaction) == transactions
|
||||
|
||||
def test_insert_transactions_independent_sessions(self, client: Client):
|
||||
transactions = [
|
||||
Transaction(date(2023, 1, 1), "", Decimal("-500")),
|
||||
Transaction(date(2023, 1, 2), "", Decimal("500")),
|
||||
]
|
||||
|
||||
client.insert(transactions)
|
||||
def test_insert_transactions(self, client: Client, transactions: list[Transaction]):
|
||||
result = client.select(Transaction)
|
||||
for i, transaction in enumerate(result):
|
||||
assert transactions[i].date == transaction.date
|
||||
assert transactions[i].description == transaction.description
|
||||
assert transactions[i].amount == transaction.amount
|
||||
assert result == transactions
|
||||
|
||||
def test_select_transactions_without_category(
|
||||
self, client: Client, transactions: list[Transaction]
|
||||
):
|
||||
result = client.select(Transaction, lambda: ~Transaction.category.has())
|
||||
assert result == [transactions[1]]
|
||||
|
||||
def test_select_banks(self, client: Client, banks: list[Bank]):
|
||||
result = client.select(Bank)
|
||||
assert result == banks
|
||||
|
||||
def test_select_banks_with_nordigen(self, client: Client, banks: list[Bank]):
|
||||
result = client.select(Bank, Bank.nordigen)
|
||||
assert result == [banks[0]]
|
||||
|
||||
def test_select_banks_by_name(self, client: Client, banks: list[Bank]):
|
||||
name = banks[0].name
|
||||
result = client.select(Bank, lambda: Bank.name == name)
|
||||
assert result == [banks[0]]
|
||||
|
||||
names = [banks[0].name, banks[1].name]
|
||||
result = client.select(Bank, lambda: Bank.name.in_(names))
|
||||
assert result == [banks[0], banks[1]]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user