diff --git a/pfbudget/db/client.py b/pfbudget/db/client.py index 0c1a5c5..6de48e0 100644 --- a/pfbudget/db/client.py +++ b/pfbudget/db/client.py @@ -3,66 +3,64 @@ from sqlalchemy import Engine, create_engine, select from sqlalchemy.orm import Session, sessionmaker from typing import Any, Optional, Sequence, Type, TypeVar -from pfbudget.db.exceptions import InsertError, SelectError +# from pfbudget.db.exceptions import InsertError, SelectError from pfbudget.db.model import Transaction +class DatabaseSession: + def __init__(self, session: Session): + self.__session = session + + def __enter__(self): + self.__session.begin() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): + if exc_type: + self.__session.rollback() + else: + self.__session.commit() + self.__session.close() + + def insert(self, transactions: Sequence[Transaction]) -> None: + self.__session.add_all(transactions) + + T = TypeVar("T") + C = TypeVar("C") + + def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]: + if exists: + stmt = select(what).where(exists) + else: + stmt = select(what) + + return self.__session.scalars(stmt).all() + + class Client: - def __init__(self, url: str, **kwargs: Any) -> None: + def __init__(self, url: str, **kwargs: Any): assert url, "Database URL is empty!" self._engine = create_engine(url, **kwargs) self._sessionmaker: Optional[sessionmaker[Session]] = None - def insert( - self, transactions: Sequence[Transaction], session: Optional[Session] = None - ) -> None: - if not session: - new = deepcopy(transactions) - with self.session as session_: - try: - session_.add_all(new) - except Exception as e: - session_.rollback() - raise InsertError(e) - else: - session_.commit() - else: - try: - session.add_all(transactions) - except Exception as e: - session.rollback() - raise InsertError(e) - else: - session.commit() + def insert(self, transactions: Sequence[Transaction]) -> None: + new = deepcopy(transactions) + with self.session as session: + session.insert(new) T = TypeVar("T") + C = TypeVar("C") - def select(self, what: Type[T], session: Optional[Session] = None) -> Sequence[T]: - stmt = select(what) - result: Sequence[what] = [] - - if not session: - with self.session as session_: - try: - result = session_.scalars(stmt).all() - except Exception as e: - session_.rollback() - raise SelectError(e) - else: - try: - result = session.scalars(stmt).all() - except Exception as e: - session.rollback() - raise SelectError(e) - - return result + def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]: + return self.session.select(what, exists) @property def engine(self) -> Engine: return self._engine @property - def session(self) -> Session: + def session(self) -> DatabaseSession: if not self._sessionmaker: self._sessionmaker = sessionmaker(self._engine) - return self._sessionmaker() + + return DatabaseSession(self._sessionmaker()) diff --git a/tests/test_database.py b/tests/test_database.py index a8741bd..327b32b 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -25,8 +25,8 @@ class TestDatabase: ] with client.session as session: - client.insert(transactions, session) - assert client.select(Transaction, session) == transactions + session.insert(transactions) + assert session.select(Transaction) == transactions def test_insert_transactions_independent_sessions(self, client: Client): transactions = [