DatabaseSession wrapper for orm.Session

For DB connections that want to keep a session alive, there's a new
`DatabaseSession` class that holds a SQLAlchemy session inside and
offers methods similar to the `Database` class.

The `Database` moves to use the `DatabaseSession` to remove duplicated
code.
This commit is contained in:
Luís Murta 2023-04-28 22:18:29 +01:00
parent 9f39836083
commit 78ff6faa12
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
2 changed files with 44 additions and 46 deletions

View File

@ -3,66 +3,64 @@ from sqlalchemy import Engine, create_engine, select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from typing import Any, Optional, Sequence, Type, TypeVar from typing import Any, Optional, Sequence, Type, TypeVar
from pfbudget.db.exceptions import InsertError, SelectError # from pfbudget.db.exceptions import InsertError, SelectError
from pfbudget.db.model import Transaction from pfbudget.db.model import Transaction
class DatabaseSession:
def __init__(self, session: Session):
self.__session = session
def __enter__(self):
self.__session.begin()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
if exc_type:
self.__session.rollback()
else:
self.__session.commit()
self.__session.close()
def insert(self, transactions: Sequence[Transaction]) -> None:
self.__session.add_all(transactions)
T = TypeVar("T")
C = TypeVar("C")
def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]:
if exists:
stmt = select(what).where(exists)
else:
stmt = select(what)
return self.__session.scalars(stmt).all()
class Client: class Client:
def __init__(self, url: str, **kwargs: Any) -> None: def __init__(self, url: str, **kwargs: Any):
assert url, "Database URL is empty!" assert url, "Database URL is empty!"
self._engine = create_engine(url, **kwargs) self._engine = create_engine(url, **kwargs)
self._sessionmaker: Optional[sessionmaker[Session]] = None self._sessionmaker: Optional[sessionmaker[Session]] = None
def insert( def insert(self, transactions: Sequence[Transaction]) -> None:
self, transactions: Sequence[Transaction], session: Optional[Session] = None new = deepcopy(transactions)
) -> None: with self.session as session:
if not session: session.insert(new)
new = deepcopy(transactions)
with self.session as session_:
try:
session_.add_all(new)
except Exception as e:
session_.rollback()
raise InsertError(e)
else:
session_.commit()
else:
try:
session.add_all(transactions)
except Exception as e:
session.rollback()
raise InsertError(e)
else:
session.commit()
T = TypeVar("T") T = TypeVar("T")
C = TypeVar("C")
def select(self, what: Type[T], session: Optional[Session] = None) -> Sequence[T]: def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]:
stmt = select(what) return self.session.select(what, exists)
result: Sequence[what] = []
if not session:
with self.session as session_:
try:
result = session_.scalars(stmt).all()
except Exception as e:
session_.rollback()
raise SelectError(e)
else:
try:
result = session.scalars(stmt).all()
except Exception as e:
session.rollback()
raise SelectError(e)
return result
@property @property
def engine(self) -> Engine: def engine(self) -> Engine:
return self._engine return self._engine
@property @property
def session(self) -> Session: def session(self) -> DatabaseSession:
if not self._sessionmaker: if not self._sessionmaker:
self._sessionmaker = sessionmaker(self._engine) self._sessionmaker = sessionmaker(self._engine)
return self._sessionmaker()
return DatabaseSession(self._sessionmaker())

View File

@ -25,8 +25,8 @@ class TestDatabase:
] ]
with client.session as session: with client.session as session:
client.insert(transactions, session) session.insert(transactions)
assert client.select(Transaction, session) == transactions assert session.select(Transaction) == transactions
def test_insert_transactions_independent_sessions(self, client: Client): def test_insert_transactions_independent_sessions(self, client: Client):
transactions = [ transactions = [