[Refactor] Database client interface changed

`add` method replaced with `insert`.
`insert` and `select` implemented for new database base class.
Database unit test added.

Due to SQLite implementation of the primary key autoinc, the type of the
IDs on the database for SQLite changed to Integer.
https://www.sqlite.org/autoinc.html
This commit is contained in:
Luís Murta 2023-04-27 17:47:13 +01:00
parent 9c7c06c181
commit e7abae0d17
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
8 changed files with 125 additions and 22 deletions

View File

@ -57,7 +57,7 @@ class Interactive:
case "split": case "split":
new = self.split(next) new = self.split(next)
session.add(new) session.insert(new)
case other: case other:
if not other: if not other:
@ -84,7 +84,7 @@ class Interactive:
) )
for tag in tags: for tag in tags:
if tag not in [t.name for t in self.tags]: if tag not in [t.name for t in self.tags]:
session.add([Tag(tag)]) session.insert([Tag(tag)])
self.tags = session.get(Tag) self.tags = session.get(Tag)
next.tags.add(TransactionTag(tag)) next.tags.add(TransactionTag(tag))

View File

@ -72,7 +72,7 @@ class Manager:
and input(f"{transactions[:5]}\nCommit? (y/n)") == "y" and input(f"{transactions[:5]}\nCommit? (y/n)") == "y"
): ):
with self.db.session() as session: with self.db.session() as session:
session.add(sorted(transactions)) session.insert(sorted(transactions))
case Operation.Download: case Operation.Download:
client = Manager.nordigen_client() client = Manager.nordigen_client()
@ -91,7 +91,7 @@ class Manager:
# dry-run # dry-run
if not params[2]: if not params[2]:
with self.db.session() as session: with self.db.session() as session:
session.add(sorted(transactions)) session.insert(sorted(transactions))
else: else:
print(sorted(transactions)) print(sorted(transactions))
@ -149,7 +149,7 @@ class Manager:
| Operation.TagRuleAdd | Operation.TagRuleAdd
): ):
with self.db.session() as session: with self.db.session() as session:
session.add(params) session.insert(params)
case Operation.CategoryUpdate: case Operation.CategoryUpdate:
with self.db.session() as session: with self.db.session() as session:
@ -184,7 +184,7 @@ class Manager:
case Operation.GroupAdd: case Operation.GroupAdd:
with self.db.session() as session: with self.db.session() as session:
session.add(params) session.insert(params)
case Operation.GroupRemove: case Operation.GroupRemove:
assert all(isinstance(param, CategoryGroup) for param in params) assert all(isinstance(param, CategoryGroup) for param in params)
@ -217,7 +217,7 @@ class Manager:
link.category = original.category link.category = original.category
tobelinked = [Link(original.id, link.id) for link in links] tobelinked = [Link(original.id, link.id) for link in links]
session.add(tobelinked) session.insert(tobelinked)
case Operation.Dismantle: case Operation.Dismantle:
assert all(isinstance(param, Link) for param in params) assert all(isinstance(param, Link) for param in params)
@ -260,7 +260,7 @@ class Manager:
splitted.category = t.category splitted.category = t.category
transactions.append(splitted) transactions.append(splitted)
session.add(transactions) session.insert(transactions)
case Operation.Export: case Operation.Export:
with self.db.session() as session: with self.db.session() as session:
@ -298,7 +298,7 @@ class Manager:
if self.certify(transactions): if self.certify(transactions):
with self.db.session() as session: with self.db.session() as session:
session.add(transactions) session.insert(transactions)
case Operation.ExportBanks: case Operation.ExportBanks:
with self.db.session() as session: with self.db.session() as session:
@ -314,7 +314,7 @@ class Manager:
if self.certify(banks): if self.certify(banks):
with self.db.session() as session: with self.db.session() as session:
session.add(banks) session.insert(banks)
case Operation.ExportCategoryRules: case Operation.ExportCategoryRules:
with self.db.session() as session: with self.db.session() as session:
@ -325,7 +325,7 @@ class Manager:
if self.certify(rules): if self.certify(rules):
with self.db.session() as session: with self.db.session() as session:
session.add(rules) session.insert(rules)
case Operation.ExportTagRules: case Operation.ExportTagRules:
with self.db.session() as session: with self.db.session() as session:
@ -336,7 +336,7 @@ class Manager:
if self.certify(rules): if self.certify(rules):
with self.db.session() as session: with self.db.session() as session:
session.add(rules) session.insert(rules)
case Operation.ExportCategories: case Operation.ExportCategories:
with self.db.session() as session: with self.db.session() as session:
@ -360,7 +360,7 @@ class Manager:
if self.certify(categories): if self.certify(categories):
with self.db.session() as session: with self.db.session() as session:
session.add(categories) session.insert(categories)
case Operation.ExportCategoryGroups: case Operation.ExportCategoryGroups:
with self.db.session() as session: with self.db.session() as session:
@ -373,7 +373,7 @@ class Manager:
if self.certify(groups): if self.certify(groups):
with self.db.session() as session: with self.db.session() as session:
session.add(groups) session.insert(groups)
def parse(self, filename: Path, args: dict): def parse(self, filename: Path, args: dict):
return parse_data(filename, args) return parse_data(filename, args)

View File

@ -1,11 +1,70 @@
from typing import Sequence from sqlalchemy import Engine, create_engine, select, text
from sqlalchemy.orm import Session, sessionmaker
from typing import Any, Optional, Sequence, Type, TypeVar
from pfbudget.db.exceptions import InsertError, SelectError
from pfbudget.db.model import Transaction from pfbudget.db.model import Transaction
class Client: class Client:
def __init__(self, url: str) -> None: def __init__(self, url: str, **kwargs: dict[str, Any]) -> None:
self.url = url assert url, "Database URL is empty!"
self._engine = create_engine(url, **kwargs)
self._sessionmaker: Optional[sessionmaker[Session]] = None
def insert(self, transactions: Sequence[Transaction]) -> None: def insert(
raise NotImplementedError self, transactions: Sequence[Transaction], session: Optional[Session] = None
) -> None:
if not session:
with self.session as session_:
try:
session_.add_all(transactions)
except Exception as e:
session_.rollback()
raise InsertError(e)
else:
session_.commit()
else:
try:
session.add_all(transactions)
except Exception as e:
session.rollback()
raise InsertError(e)
else:
session.commit()
T = TypeVar("T")
def select(self, what: Type[T], session: Optional[Session] = None) -> Sequence[T]:
stmt = select(what)
result: Sequence[what] = []
if not session:
with self.session as session_:
try:
result = session_.scalars(stmt).all()
except Exception as e:
session_.rollback()
raise SelectError(e)
else:
session_.commit()
else:
try:
result = session.scalars(stmt).all()
except Exception as e:
session.rollback()
raise SelectError(e)
else:
session.commit()
return result
@property
def engine(self) -> Engine:
return self._engine
@property
def session(self) -> Session:
if not self._sessionmaker:
self._sessionmaker = sessionmaker(self._engine)
return self._sessionmaker()

View File

@ -0,0 +1,6 @@
class InsertError(Exception):
pass
class SelectError(Exception):
pass

View File

@ -9,6 +9,7 @@ from sqlalchemy import (
BigInteger, BigInteger,
Enum, Enum,
ForeignKey, ForeignKey,
Integer,
MetaData, MetaData,
Numeric, Numeric,
String, String,
@ -78,7 +79,14 @@ class Bank(Base, Export):
bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))] bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))]
idpk = Annotated[int, mapped_column(BigInteger, primary_key=True, autoincrement=True)] idpk = Annotated[
int,
mapped_column(
BigInteger().with_variant(Integer, "sqlite"),
primary_key=True,
autoincrement=True,
),
]
money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))] money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))]

View File

@ -78,7 +78,7 @@ class DbClient:
) )
return self.__session.scalars(stmt).all() return self.__session.scalars(stmt).all()
def add(self, rows: list): def insert(self, rows: list):
self.__session.add_all(rows) self.__session.add_all(rows)
def remove_by_name(self, type, rows: list): def remove_by_name(self, type, rows: list):

29
tests/test_database.py Normal file
View File

@ -0,0 +1,29 @@
from datetime import date
from decimal import Decimal
import pytest
from pfbudget.db.client import Client
from pfbudget.db.model import Base, Transaction
@pytest.fixture
def client() -> Client:
url = "sqlite://"
client = Client(url, execution_options={"schema_translate_map": {"pfbudget": None}})
Base.metadata.create_all(client.engine)
return client
class TestDatabase:
def test_initialization(self, client: Client):
pass
def test_insert_transactions(self, client: Client):
transactions = [
Transaction(date(2023, 1, 1), "", Decimal("-500")),
Transaction(date(2023, 1, 2), "", Decimal("500")),
]
with client.session as session:
client.insert(transactions, session)
assert client.select(Transaction, session) == transactions

View File

@ -26,7 +26,8 @@ def loader() -> Loader:
class TestDatabaseLoad: class TestDatabaseLoad:
def test_empty_url(self): def test_empty_url(self):
_ = FakeDatabaseClient("") with pytest.raises(AssertionError):
_ = FakeDatabaseClient("")
def test_insert(self, loader: Loader): def test_insert(self, loader: Loader):
transactions = [ transactions = [