[Refactor] Database client interface changed

`add` method replaced with `insert`.
`insert` and `select` implemented for new database base class.
Database unit test added.

Due to SQLite implementation of the primary key autoinc, the type of the
IDs on the database for SQLite changed to Integer.
https://www.sqlite.org/autoinc.html
This commit is contained in:
Luís Murta 2023-04-27 17:47:13 +01:00
parent 9c7c06c181
commit e7abae0d17
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
8 changed files with 125 additions and 22 deletions

View File

@ -57,7 +57,7 @@ class Interactive:
case "split":
new = self.split(next)
session.add(new)
session.insert(new)
case other:
if not other:
@ -84,7 +84,7 @@ class Interactive:
)
for tag in tags:
if tag not in [t.name for t in self.tags]:
session.add([Tag(tag)])
session.insert([Tag(tag)])
self.tags = session.get(Tag)
next.tags.add(TransactionTag(tag))

View File

@ -72,7 +72,7 @@ class Manager:
and input(f"{transactions[:5]}\nCommit? (y/n)") == "y"
):
with self.db.session() as session:
session.add(sorted(transactions))
session.insert(sorted(transactions))
case Operation.Download:
client = Manager.nordigen_client()
@ -91,7 +91,7 @@ class Manager:
# dry-run
if not params[2]:
with self.db.session() as session:
session.add(sorted(transactions))
session.insert(sorted(transactions))
else:
print(sorted(transactions))
@ -149,7 +149,7 @@ class Manager:
| Operation.TagRuleAdd
):
with self.db.session() as session:
session.add(params)
session.insert(params)
case Operation.CategoryUpdate:
with self.db.session() as session:
@ -184,7 +184,7 @@ class Manager:
case Operation.GroupAdd:
with self.db.session() as session:
session.add(params)
session.insert(params)
case Operation.GroupRemove:
assert all(isinstance(param, CategoryGroup) for param in params)
@ -217,7 +217,7 @@ class Manager:
link.category = original.category
tobelinked = [Link(original.id, link.id) for link in links]
session.add(tobelinked)
session.insert(tobelinked)
case Operation.Dismantle:
assert all(isinstance(param, Link) for param in params)
@ -260,7 +260,7 @@ class Manager:
splitted.category = t.category
transactions.append(splitted)
session.add(transactions)
session.insert(transactions)
case Operation.Export:
with self.db.session() as session:
@ -298,7 +298,7 @@ class Manager:
if self.certify(transactions):
with self.db.session() as session:
session.add(transactions)
session.insert(transactions)
case Operation.ExportBanks:
with self.db.session() as session:
@ -314,7 +314,7 @@ class Manager:
if self.certify(banks):
with self.db.session() as session:
session.add(banks)
session.insert(banks)
case Operation.ExportCategoryRules:
with self.db.session() as session:
@ -325,7 +325,7 @@ class Manager:
if self.certify(rules):
with self.db.session() as session:
session.add(rules)
session.insert(rules)
case Operation.ExportTagRules:
with self.db.session() as session:
@ -336,7 +336,7 @@ class Manager:
if self.certify(rules):
with self.db.session() as session:
session.add(rules)
session.insert(rules)
case Operation.ExportCategories:
with self.db.session() as session:
@ -360,7 +360,7 @@ class Manager:
if self.certify(categories):
with self.db.session() as session:
session.add(categories)
session.insert(categories)
case Operation.ExportCategoryGroups:
with self.db.session() as session:
@ -373,7 +373,7 @@ class Manager:
if self.certify(groups):
with self.db.session() as session:
session.add(groups)
session.insert(groups)
def parse(self, filename: Path, args: dict):
return parse_data(filename, args)

View File

@ -1,11 +1,70 @@
from typing import Sequence
from sqlalchemy import Engine, create_engine, select, text
from sqlalchemy.orm import Session, sessionmaker
from typing import Any, Optional, Sequence, Type, TypeVar
from pfbudget.db.exceptions import InsertError, SelectError
from pfbudget.db.model import Transaction
class Client:
def __init__(self, url: str) -> None:
self.url = url
def __init__(self, url: str, **kwargs: dict[str, Any]) -> None:
assert url, "Database URL is empty!"
self._engine = create_engine(url, **kwargs)
self._sessionmaker: Optional[sessionmaker[Session]] = None
def insert(self, transactions: Sequence[Transaction]) -> None:
raise NotImplementedError
def insert(
self, transactions: Sequence[Transaction], session: Optional[Session] = None
) -> None:
if not session:
with self.session as session_:
try:
session_.add_all(transactions)
except Exception as e:
session_.rollback()
raise InsertError(e)
else:
session_.commit()
else:
try:
session.add_all(transactions)
except Exception as e:
session.rollback()
raise InsertError(e)
else:
session.commit()
T = TypeVar("T")
def select(self, what: Type[T], session: Optional[Session] = None) -> Sequence[T]:
stmt = select(what)
result: Sequence[what] = []
if not session:
with self.session as session_:
try:
result = session_.scalars(stmt).all()
except Exception as e:
session_.rollback()
raise SelectError(e)
else:
session_.commit()
else:
try:
result = session.scalars(stmt).all()
except Exception as e:
session.rollback()
raise SelectError(e)
else:
session.commit()
return result
@property
def engine(self) -> Engine:
return self._engine
@property
def session(self) -> Session:
if not self._sessionmaker:
self._sessionmaker = sessionmaker(self._engine)
return self._sessionmaker()

View File

@ -0,0 +1,6 @@
class InsertError(Exception):
pass
class SelectError(Exception):
pass

View File

@ -9,6 +9,7 @@ from sqlalchemy import (
BigInteger,
Enum,
ForeignKey,
Integer,
MetaData,
Numeric,
String,
@ -78,7 +79,14 @@ class Bank(Base, Export):
bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))]
idpk = Annotated[int, mapped_column(BigInteger, primary_key=True, autoincrement=True)]
idpk = Annotated[
int,
mapped_column(
BigInteger().with_variant(Integer, "sqlite"),
primary_key=True,
autoincrement=True,
),
]
money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))]

View File

@ -78,7 +78,7 @@ class DbClient:
)
return self.__session.scalars(stmt).all()
def add(self, rows: list):
def insert(self, rows: list):
self.__session.add_all(rows)
def remove_by_name(self, type, rows: list):

29
tests/test_database.py Normal file
View File

@ -0,0 +1,29 @@
from datetime import date
from decimal import Decimal
import pytest
from pfbudget.db.client import Client
from pfbudget.db.model import Base, Transaction
@pytest.fixture
def client() -> Client:
url = "sqlite://"
client = Client(url, execution_options={"schema_translate_map": {"pfbudget": None}})
Base.metadata.create_all(client.engine)
return client
class TestDatabase:
def test_initialization(self, client: Client):
pass
def test_insert_transactions(self, client: Client):
transactions = [
Transaction(date(2023, 1, 1), "", Decimal("-500")),
Transaction(date(2023, 1, 2), "", Decimal("500")),
]
with client.session as session:
client.insert(transactions, session)
assert client.select(Transaction, session) == transactions

View File

@ -26,6 +26,7 @@ def loader() -> Loader:
class TestDatabaseLoad:
def test_empty_url(self):
with pytest.raises(AssertionError):
_ = FakeDatabaseClient("")
def test_insert(self, loader: Loader):