[Refactor] Database client interface changed
`add` method replaced with `insert`. `insert` and `select` implemented for new database base class. Database unit test added. Due to SQLite implementation of the primary key autoinc, the type of the IDs on the database for SQLite changed to Integer. https://www.sqlite.org/autoinc.html
This commit is contained in:
parent
9c7c06c181
commit
e7abae0d17
@ -57,7 +57,7 @@ class Interactive:
|
|||||||
|
|
||||||
case "split":
|
case "split":
|
||||||
new = self.split(next)
|
new = self.split(next)
|
||||||
session.add(new)
|
session.insert(new)
|
||||||
|
|
||||||
case other:
|
case other:
|
||||||
if not other:
|
if not other:
|
||||||
@ -84,7 +84,7 @@ class Interactive:
|
|||||||
)
|
)
|
||||||
for tag in tags:
|
for tag in tags:
|
||||||
if tag not in [t.name for t in self.tags]:
|
if tag not in [t.name for t in self.tags]:
|
||||||
session.add([Tag(tag)])
|
session.insert([Tag(tag)])
|
||||||
self.tags = session.get(Tag)
|
self.tags = session.get(Tag)
|
||||||
|
|
||||||
next.tags.add(TransactionTag(tag))
|
next.tags.add(TransactionTag(tag))
|
||||||
|
|||||||
@ -72,7 +72,7 @@ class Manager:
|
|||||||
and input(f"{transactions[:5]}\nCommit? (y/n)") == "y"
|
and input(f"{transactions[:5]}\nCommit? (y/n)") == "y"
|
||||||
):
|
):
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(sorted(transactions))
|
session.insert(sorted(transactions))
|
||||||
|
|
||||||
case Operation.Download:
|
case Operation.Download:
|
||||||
client = Manager.nordigen_client()
|
client = Manager.nordigen_client()
|
||||||
@ -91,7 +91,7 @@ class Manager:
|
|||||||
# dry-run
|
# dry-run
|
||||||
if not params[2]:
|
if not params[2]:
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(sorted(transactions))
|
session.insert(sorted(transactions))
|
||||||
else:
|
else:
|
||||||
print(sorted(transactions))
|
print(sorted(transactions))
|
||||||
|
|
||||||
@ -149,7 +149,7 @@ class Manager:
|
|||||||
| Operation.TagRuleAdd
|
| Operation.TagRuleAdd
|
||||||
):
|
):
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(params)
|
session.insert(params)
|
||||||
|
|
||||||
case Operation.CategoryUpdate:
|
case Operation.CategoryUpdate:
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
@ -184,7 +184,7 @@ class Manager:
|
|||||||
|
|
||||||
case Operation.GroupAdd:
|
case Operation.GroupAdd:
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(params)
|
session.insert(params)
|
||||||
|
|
||||||
case Operation.GroupRemove:
|
case Operation.GroupRemove:
|
||||||
assert all(isinstance(param, CategoryGroup) for param in params)
|
assert all(isinstance(param, CategoryGroup) for param in params)
|
||||||
@ -217,7 +217,7 @@ class Manager:
|
|||||||
link.category = original.category
|
link.category = original.category
|
||||||
|
|
||||||
tobelinked = [Link(original.id, link.id) for link in links]
|
tobelinked = [Link(original.id, link.id) for link in links]
|
||||||
session.add(tobelinked)
|
session.insert(tobelinked)
|
||||||
|
|
||||||
case Operation.Dismantle:
|
case Operation.Dismantle:
|
||||||
assert all(isinstance(param, Link) for param in params)
|
assert all(isinstance(param, Link) for param in params)
|
||||||
@ -260,7 +260,7 @@ class Manager:
|
|||||||
splitted.category = t.category
|
splitted.category = t.category
|
||||||
transactions.append(splitted)
|
transactions.append(splitted)
|
||||||
|
|
||||||
session.add(transactions)
|
session.insert(transactions)
|
||||||
|
|
||||||
case Operation.Export:
|
case Operation.Export:
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
@ -298,7 +298,7 @@ class Manager:
|
|||||||
|
|
||||||
if self.certify(transactions):
|
if self.certify(transactions):
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(transactions)
|
session.insert(transactions)
|
||||||
|
|
||||||
case Operation.ExportBanks:
|
case Operation.ExportBanks:
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
@ -314,7 +314,7 @@ class Manager:
|
|||||||
|
|
||||||
if self.certify(banks):
|
if self.certify(banks):
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(banks)
|
session.insert(banks)
|
||||||
|
|
||||||
case Operation.ExportCategoryRules:
|
case Operation.ExportCategoryRules:
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
@ -325,7 +325,7 @@ class Manager:
|
|||||||
|
|
||||||
if self.certify(rules):
|
if self.certify(rules):
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(rules)
|
session.insert(rules)
|
||||||
|
|
||||||
case Operation.ExportTagRules:
|
case Operation.ExportTagRules:
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
@ -336,7 +336,7 @@ class Manager:
|
|||||||
|
|
||||||
if self.certify(rules):
|
if self.certify(rules):
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(rules)
|
session.insert(rules)
|
||||||
|
|
||||||
case Operation.ExportCategories:
|
case Operation.ExportCategories:
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
@ -360,7 +360,7 @@ class Manager:
|
|||||||
|
|
||||||
if self.certify(categories):
|
if self.certify(categories):
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(categories)
|
session.insert(categories)
|
||||||
|
|
||||||
case Operation.ExportCategoryGroups:
|
case Operation.ExportCategoryGroups:
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
@ -373,7 +373,7 @@ class Manager:
|
|||||||
|
|
||||||
if self.certify(groups):
|
if self.certify(groups):
|
||||||
with self.db.session() as session:
|
with self.db.session() as session:
|
||||||
session.add(groups)
|
session.insert(groups)
|
||||||
|
|
||||||
def parse(self, filename: Path, args: dict):
|
def parse(self, filename: Path, args: dict):
|
||||||
return parse_data(filename, args)
|
return parse_data(filename, args)
|
||||||
|
|||||||
@ -1,11 +1,70 @@
|
|||||||
from typing import Sequence
|
from sqlalchemy import Engine, create_engine, select, text
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
from typing import Any, Optional, Sequence, Type, TypeVar
|
||||||
|
from pfbudget.db.exceptions import InsertError, SelectError
|
||||||
|
|
||||||
from pfbudget.db.model import Transaction
|
from pfbudget.db.model import Transaction
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
def __init__(self, url: str) -> None:
|
def __init__(self, url: str, **kwargs: dict[str, Any]) -> None:
|
||||||
self.url = url
|
assert url, "Database URL is empty!"
|
||||||
|
self._engine = create_engine(url, **kwargs)
|
||||||
|
self._sessionmaker: Optional[sessionmaker[Session]] = None
|
||||||
|
|
||||||
def insert(self, transactions: Sequence[Transaction]) -> None:
|
def insert(
|
||||||
raise NotImplementedError
|
self, transactions: Sequence[Transaction], session: Optional[Session] = None
|
||||||
|
) -> None:
|
||||||
|
if not session:
|
||||||
|
with self.session as session_:
|
||||||
|
try:
|
||||||
|
session_.add_all(transactions)
|
||||||
|
except Exception as e:
|
||||||
|
session_.rollback()
|
||||||
|
raise InsertError(e)
|
||||||
|
else:
|
||||||
|
session_.commit()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
session.add_all(transactions)
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
raise InsertError(e)
|
||||||
|
else:
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
def select(self, what: Type[T], session: Optional[Session] = None) -> Sequence[T]:
|
||||||
|
stmt = select(what)
|
||||||
|
result: Sequence[what] = []
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
with self.session as session_:
|
||||||
|
try:
|
||||||
|
result = session_.scalars(stmt).all()
|
||||||
|
except Exception as e:
|
||||||
|
session_.rollback()
|
||||||
|
raise SelectError(e)
|
||||||
|
else:
|
||||||
|
session_.commit()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
result = session.scalars(stmt).all()
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
raise SelectError(e)
|
||||||
|
else:
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def engine(self) -> Engine:
|
||||||
|
return self._engine
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self) -> Session:
|
||||||
|
if not self._sessionmaker:
|
||||||
|
self._sessionmaker = sessionmaker(self._engine)
|
||||||
|
return self._sessionmaker()
|
||||||
|
|||||||
6
pfbudget/db/exceptions.py
Normal file
6
pfbudget/db/exceptions.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
class InsertError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SelectError(Exception):
|
||||||
|
pass
|
||||||
@ -9,6 +9,7 @@ from sqlalchemy import (
|
|||||||
BigInteger,
|
BigInteger,
|
||||||
Enum,
|
Enum,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
MetaData,
|
MetaData,
|
||||||
Numeric,
|
Numeric,
|
||||||
String,
|
String,
|
||||||
@ -78,7 +79,14 @@ class Bank(Base, Export):
|
|||||||
|
|
||||||
bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))]
|
bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))]
|
||||||
|
|
||||||
idpk = Annotated[int, mapped_column(BigInteger, primary_key=True, autoincrement=True)]
|
idpk = Annotated[
|
||||||
|
int,
|
||||||
|
mapped_column(
|
||||||
|
BigInteger().with_variant(Integer, "sqlite"),
|
||||||
|
primary_key=True,
|
||||||
|
autoincrement=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))]
|
money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -78,7 +78,7 @@ class DbClient:
|
|||||||
)
|
)
|
||||||
return self.__session.scalars(stmt).all()
|
return self.__session.scalars(stmt).all()
|
||||||
|
|
||||||
def add(self, rows: list):
|
def insert(self, rows: list):
|
||||||
self.__session.add_all(rows)
|
self.__session.add_all(rows)
|
||||||
|
|
||||||
def remove_by_name(self, type, rows: list):
|
def remove_by_name(self, type, rows: list):
|
||||||
|
|||||||
29
tests/test_database.py
Normal file
29
tests/test_database.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from datetime import date
|
||||||
|
from decimal import Decimal
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pfbudget.db.client import Client
|
||||||
|
from pfbudget.db.model import Base, Transaction
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client() -> Client:
|
||||||
|
url = "sqlite://"
|
||||||
|
client = Client(url, execution_options={"schema_translate_map": {"pfbudget": None}})
|
||||||
|
Base.metadata.create_all(client.engine)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatabase:
|
||||||
|
def test_initialization(self, client: Client):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_insert_transactions(self, client: Client):
|
||||||
|
transactions = [
|
||||||
|
Transaction(date(2023, 1, 1), "", Decimal("-500")),
|
||||||
|
Transaction(date(2023, 1, 2), "", Decimal("500")),
|
||||||
|
]
|
||||||
|
|
||||||
|
with client.session as session:
|
||||||
|
client.insert(transactions, session)
|
||||||
|
assert client.select(Transaction, session) == transactions
|
||||||
@ -26,6 +26,7 @@ def loader() -> Loader:
|
|||||||
|
|
||||||
class TestDatabaseLoad:
|
class TestDatabaseLoad:
|
||||||
def test_empty_url(self):
|
def test_empty_url(self):
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
_ = FakeDatabaseClient("")
|
_ = FakeDatabaseClient("")
|
||||||
|
|
||||||
def test_insert(self, loader: Loader):
|
def test_insert(self, loader: Loader):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user