From e7abae0d1747b79d8fbe4aa1d17fa902c0cd8ff1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Murta?= Date: Thu, 27 Apr 2023 17:47:13 +0100 Subject: [PATCH] [Refactor] Database client interface changed `add` method replaced with `insert`. `insert` and `select` implemented for new database base class. Database unit test added. Due to SQLite implementation of the primary key autoinc, the type of the IDs on the database for SQLite changed to Integer. https://www.sqlite.org/autoinc.html --- pfbudget/cli/interactive.py | 4 +-- pfbudget/core/manager.py | 24 ++++++------- pfbudget/db/client.py | 69 ++++++++++++++++++++++++++++++++++--- pfbudget/db/exceptions.py | 6 ++++ pfbudget/db/model.py | 10 +++++- pfbudget/db/postgresql.py | 2 +- tests/test_database.py | 29 ++++++++++++++++ tests/test_load.py | 3 +- 8 files changed, 125 insertions(+), 22 deletions(-) create mode 100644 pfbudget/db/exceptions.py create mode 100644 tests/test_database.py diff --git a/pfbudget/cli/interactive.py b/pfbudget/cli/interactive.py index cd392c1..9696f17 100644 --- a/pfbudget/cli/interactive.py +++ b/pfbudget/cli/interactive.py @@ -57,7 +57,7 @@ class Interactive: case "split": new = self.split(next) - session.add(new) + session.insert(new) case other: if not other: @@ -84,7 +84,7 @@ class Interactive: ) for tag in tags: if tag not in [t.name for t in self.tags]: - session.add([Tag(tag)]) + session.insert([Tag(tag)]) self.tags = session.get(Tag) next.tags.add(TransactionTag(tag)) diff --git a/pfbudget/core/manager.py b/pfbudget/core/manager.py index d43ff43..08ee0b1 100644 --- a/pfbudget/core/manager.py +++ b/pfbudget/core/manager.py @@ -72,7 +72,7 @@ class Manager: and input(f"{transactions[:5]}\nCommit? (y/n)") == "y" ): with self.db.session() as session: - session.add(sorted(transactions)) + session.insert(sorted(transactions)) case Operation.Download: client = Manager.nordigen_client() @@ -91,7 +91,7 @@ class Manager: # dry-run if not params[2]: with self.db.session() as session: - session.add(sorted(transactions)) + session.insert(sorted(transactions)) else: print(sorted(transactions)) @@ -149,7 +149,7 @@ class Manager: | Operation.TagRuleAdd ): with self.db.session() as session: - session.add(params) + session.insert(params) case Operation.CategoryUpdate: with self.db.session() as session: @@ -184,7 +184,7 @@ class Manager: case Operation.GroupAdd: with self.db.session() as session: - session.add(params) + session.insert(params) case Operation.GroupRemove: assert all(isinstance(param, CategoryGroup) for param in params) @@ -217,7 +217,7 @@ class Manager: link.category = original.category tobelinked = [Link(original.id, link.id) for link in links] - session.add(tobelinked) + session.insert(tobelinked) case Operation.Dismantle: assert all(isinstance(param, Link) for param in params) @@ -260,7 +260,7 @@ class Manager: splitted.category = t.category transactions.append(splitted) - session.add(transactions) + session.insert(transactions) case Operation.Export: with self.db.session() as session: @@ -298,7 +298,7 @@ class Manager: if self.certify(transactions): with self.db.session() as session: - session.add(transactions) + session.insert(transactions) case Operation.ExportBanks: with self.db.session() as session: @@ -314,7 +314,7 @@ class Manager: if self.certify(banks): with self.db.session() as session: - session.add(banks) + session.insert(banks) case Operation.ExportCategoryRules: with self.db.session() as session: @@ -325,7 +325,7 @@ class Manager: if self.certify(rules): with self.db.session() as session: - session.add(rules) + session.insert(rules) case Operation.ExportTagRules: with self.db.session() as session: @@ -336,7 +336,7 @@ class Manager: if self.certify(rules): with self.db.session() as session: - session.add(rules) + session.insert(rules) case Operation.ExportCategories: with self.db.session() as session: @@ -360,7 +360,7 @@ class Manager: if self.certify(categories): with self.db.session() as session: - session.add(categories) + session.insert(categories) case Operation.ExportCategoryGroups: with self.db.session() as session: @@ -373,7 +373,7 @@ class Manager: if self.certify(groups): with self.db.session() as session: - session.add(groups) + session.insert(groups) def parse(self, filename: Path, args: dict): return parse_data(filename, args) diff --git a/pfbudget/db/client.py b/pfbudget/db/client.py index 4f4f0e0..82b06e7 100644 --- a/pfbudget/db/client.py +++ b/pfbudget/db/client.py @@ -1,11 +1,70 @@ -from typing import Sequence +from sqlalchemy import Engine, create_engine, select, text +from sqlalchemy.orm import Session, sessionmaker +from typing import Any, Optional, Sequence, Type, TypeVar +from pfbudget.db.exceptions import InsertError, SelectError from pfbudget.db.model import Transaction class Client: - def __init__(self, url: str) -> None: - self.url = url + def __init__(self, url: str, **kwargs: dict[str, Any]) -> None: + assert url, "Database URL is empty!" + self._engine = create_engine(url, **kwargs) + self._sessionmaker: Optional[sessionmaker[Session]] = None - def insert(self, transactions: Sequence[Transaction]) -> None: - raise NotImplementedError + def insert( + self, transactions: Sequence[Transaction], session: Optional[Session] = None + ) -> None: + if not session: + with self.session as session_: + try: + session_.add_all(transactions) + except Exception as e: + session_.rollback() + raise InsertError(e) + else: + session_.commit() + else: + try: + session.add_all(transactions) + except Exception as e: + session.rollback() + raise InsertError(e) + else: + session.commit() + + T = TypeVar("T") + + def select(self, what: Type[T], session: Optional[Session] = None) -> Sequence[T]: + stmt = select(what) + result: Sequence[what] = [] + + if not session: + with self.session as session_: + try: + result = session_.scalars(stmt).all() + except Exception as e: + session_.rollback() + raise SelectError(e) + else: + session_.commit() + else: + try: + result = session.scalars(stmt).all() + except Exception as e: + session.rollback() + raise SelectError(e) + else: + session.commit() + + return result + + @property + def engine(self) -> Engine: + return self._engine + + @property + def session(self) -> Session: + if not self._sessionmaker: + self._sessionmaker = sessionmaker(self._engine) + return self._sessionmaker() diff --git a/pfbudget/db/exceptions.py b/pfbudget/db/exceptions.py new file mode 100644 index 0000000..7c1e394 --- /dev/null +++ b/pfbudget/db/exceptions.py @@ -0,0 +1,6 @@ +class InsertError(Exception): + pass + + +class SelectError(Exception): + pass diff --git a/pfbudget/db/model.py b/pfbudget/db/model.py index 0f38567..6cafb49 100644 --- a/pfbudget/db/model.py +++ b/pfbudget/db/model.py @@ -9,6 +9,7 @@ from sqlalchemy import ( BigInteger, Enum, ForeignKey, + Integer, MetaData, Numeric, String, @@ -78,7 +79,14 @@ class Bank(Base, Export): bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))] -idpk = Annotated[int, mapped_column(BigInteger, primary_key=True, autoincrement=True)] +idpk = Annotated[ + int, + mapped_column( + BigInteger().with_variant(Integer, "sqlite"), + primary_key=True, + autoincrement=True, + ), +] money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))] diff --git a/pfbudget/db/postgresql.py b/pfbudget/db/postgresql.py index 4c52820..892c4e2 100644 --- a/pfbudget/db/postgresql.py +++ b/pfbudget/db/postgresql.py @@ -78,7 +78,7 @@ class DbClient: ) return self.__session.scalars(stmt).all() - def add(self, rows: list): + def insert(self, rows: list): self.__session.add_all(rows) def remove_by_name(self, type, rows: list): diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..e28dee3 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,29 @@ +from datetime import date +from decimal import Decimal +import pytest + +from pfbudget.db.client import Client +from pfbudget.db.model import Base, Transaction + + +@pytest.fixture +def client() -> Client: + url = "sqlite://" + client = Client(url, execution_options={"schema_translate_map": {"pfbudget": None}}) + Base.metadata.create_all(client.engine) + return client + + +class TestDatabase: + def test_initialization(self, client: Client): + pass + + def test_insert_transactions(self, client: Client): + transactions = [ + Transaction(date(2023, 1, 1), "", Decimal("-500")), + Transaction(date(2023, 1, 2), "", Decimal("500")), + ] + + with client.session as session: + client.insert(transactions, session) + assert client.select(Transaction, session) == transactions diff --git a/tests/test_load.py b/tests/test_load.py index e4ec3af..2e438fb 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -26,7 +26,8 @@ def loader() -> Loader: class TestDatabaseLoad: def test_empty_url(self): - _ = FakeDatabaseClient("") + with pytest.raises(AssertionError): + _ = FakeDatabaseClient("") def test_insert(self, loader: Loader): transactions = [