diff --git a/pfbudget/cli/interactive.py b/pfbudget/cli/interactive.py index cd392c1..9696f17 100644 --- a/pfbudget/cli/interactive.py +++ b/pfbudget/cli/interactive.py @@ -57,7 +57,7 @@ class Interactive: case "split": new = self.split(next) - session.add(new) + session.insert(new) case other: if not other: @@ -84,7 +84,7 @@ class Interactive: ) for tag in tags: if tag not in [t.name for t in self.tags]: - session.add([Tag(tag)]) + session.insert([Tag(tag)]) self.tags = session.get(Tag) next.tags.add(TransactionTag(tag)) diff --git a/pfbudget/core/manager.py b/pfbudget/core/manager.py index d43ff43..08ee0b1 100644 --- a/pfbudget/core/manager.py +++ b/pfbudget/core/manager.py @@ -72,7 +72,7 @@ class Manager: and input(f"{transactions[:5]}\nCommit? (y/n)") == "y" ): with self.db.session() as session: - session.add(sorted(transactions)) + session.insert(sorted(transactions)) case Operation.Download: client = Manager.nordigen_client() @@ -91,7 +91,7 @@ class Manager: # dry-run if not params[2]: with self.db.session() as session: - session.add(sorted(transactions)) + session.insert(sorted(transactions)) else: print(sorted(transactions)) @@ -149,7 +149,7 @@ class Manager: | Operation.TagRuleAdd ): with self.db.session() as session: - session.add(params) + session.insert(params) case Operation.CategoryUpdate: with self.db.session() as session: @@ -184,7 +184,7 @@ class Manager: case Operation.GroupAdd: with self.db.session() as session: - session.add(params) + session.insert(params) case Operation.GroupRemove: assert all(isinstance(param, CategoryGroup) for param in params) @@ -217,7 +217,7 @@ class Manager: link.category = original.category tobelinked = [Link(original.id, link.id) for link in links] - session.add(tobelinked) + session.insert(tobelinked) case Operation.Dismantle: assert all(isinstance(param, Link) for param in params) @@ -260,7 +260,7 @@ class Manager: splitted.category = t.category transactions.append(splitted) - session.add(transactions) + session.insert(transactions) case Operation.Export: with self.db.session() as session: @@ -298,7 +298,7 @@ class Manager: if self.certify(transactions): with self.db.session() as session: - session.add(transactions) + session.insert(transactions) case Operation.ExportBanks: with self.db.session() as session: @@ -314,7 +314,7 @@ class Manager: if self.certify(banks): with self.db.session() as session: - session.add(banks) + session.insert(banks) case Operation.ExportCategoryRules: with self.db.session() as session: @@ -325,7 +325,7 @@ class Manager: if self.certify(rules): with self.db.session() as session: - session.add(rules) + session.insert(rules) case Operation.ExportTagRules: with self.db.session() as session: @@ -336,7 +336,7 @@ class Manager: if self.certify(rules): with self.db.session() as session: - session.add(rules) + session.insert(rules) case Operation.ExportCategories: with self.db.session() as session: @@ -360,7 +360,7 @@ class Manager: if self.certify(categories): with self.db.session() as session: - session.add(categories) + session.insert(categories) case Operation.ExportCategoryGroups: with self.db.session() as session: @@ -373,7 +373,7 @@ class Manager: if self.certify(groups): with self.db.session() as session: - session.add(groups) + session.insert(groups) def parse(self, filename: Path, args: dict): return parse_data(filename, args) diff --git a/pfbudget/db/client.py b/pfbudget/db/client.py index 4f4f0e0..82b06e7 100644 --- a/pfbudget/db/client.py +++ b/pfbudget/db/client.py @@ -1,11 +1,70 @@ -from typing import Sequence +from sqlalchemy import Engine, create_engine, select, text +from sqlalchemy.orm import Session, sessionmaker +from typing import Any, Optional, Sequence, Type, TypeVar +from pfbudget.db.exceptions import InsertError, SelectError from pfbudget.db.model import Transaction class Client: - def __init__(self, url: str) -> None: - self.url = url + def __init__(self, url: str, **kwargs: dict[str, Any]) -> None: + assert url, "Database URL is empty!" + self._engine = create_engine(url, **kwargs) + self._sessionmaker: Optional[sessionmaker[Session]] = None - def insert(self, transactions: Sequence[Transaction]) -> None: - raise NotImplementedError + def insert( + self, transactions: Sequence[Transaction], session: Optional[Session] = None + ) -> None: + if not session: + with self.session as session_: + try: + session_.add_all(transactions) + except Exception as e: + session_.rollback() + raise InsertError(e) + else: + session_.commit() + else: + try: + session.add_all(transactions) + except Exception as e: + session.rollback() + raise InsertError(e) + else: + session.commit() + + T = TypeVar("T") + + def select(self, what: Type[T], session: Optional[Session] = None) -> Sequence[T]: + stmt = select(what) + result: Sequence[what] = [] + + if not session: + with self.session as session_: + try: + result = session_.scalars(stmt).all() + except Exception as e: + session_.rollback() + raise SelectError(e) + else: + session_.commit() + else: + try: + result = session.scalars(stmt).all() + except Exception as e: + session.rollback() + raise SelectError(e) + else: + session.commit() + + return result + + @property + def engine(self) -> Engine: + return self._engine + + @property + def session(self) -> Session: + if not self._sessionmaker: + self._sessionmaker = sessionmaker(self._engine) + return self._sessionmaker() diff --git a/pfbudget/db/exceptions.py b/pfbudget/db/exceptions.py new file mode 100644 index 0000000..7c1e394 --- /dev/null +++ b/pfbudget/db/exceptions.py @@ -0,0 +1,6 @@ +class InsertError(Exception): + pass + + +class SelectError(Exception): + pass diff --git a/pfbudget/db/model.py b/pfbudget/db/model.py index 0f38567..6cafb49 100644 --- a/pfbudget/db/model.py +++ b/pfbudget/db/model.py @@ -9,6 +9,7 @@ from sqlalchemy import ( BigInteger, Enum, ForeignKey, + Integer, MetaData, Numeric, String, @@ -78,7 +79,14 @@ class Bank(Base, Export): bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))] -idpk = Annotated[int, mapped_column(BigInteger, primary_key=True, autoincrement=True)] +idpk = Annotated[ + int, + mapped_column( + BigInteger().with_variant(Integer, "sqlite"), + primary_key=True, + autoincrement=True, + ), +] money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))] diff --git a/pfbudget/db/postgresql.py b/pfbudget/db/postgresql.py index 4c52820..892c4e2 100644 --- a/pfbudget/db/postgresql.py +++ b/pfbudget/db/postgresql.py @@ -78,7 +78,7 @@ class DbClient: ) return self.__session.scalars(stmt).all() - def add(self, rows: list): + def insert(self, rows: list): self.__session.add_all(rows) def remove_by_name(self, type, rows: list): diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..e28dee3 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,29 @@ +from datetime import date +from decimal import Decimal +import pytest + +from pfbudget.db.client import Client +from pfbudget.db.model import Base, Transaction + + +@pytest.fixture +def client() -> Client: + url = "sqlite://" + client = Client(url, execution_options={"schema_translate_map": {"pfbudget": None}}) + Base.metadata.create_all(client.engine) + return client + + +class TestDatabase: + def test_initialization(self, client: Client): + pass + + def test_insert_transactions(self, client: Client): + transactions = [ + Transaction(date(2023, 1, 1), "", Decimal("-500")), + Transaction(date(2023, 1, 2), "", Decimal("500")), + ] + + with client.session as session: + client.insert(transactions, session) + assert client.select(Transaction, session) == transactions diff --git a/tests/test_load.py b/tests/test_load.py index e4ec3af..2e438fb 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -26,7 +26,8 @@ def loader() -> Loader: class TestDatabaseLoad: def test_empty_url(self): - _ = FakeDatabaseClient("") + with pytest.raises(AssertionError): + _ = FakeDatabaseClient("") def test_insert(self, loader: Loader): transactions = [