DatabaseSession wrapper for orm.Session

For DB connections that want to keep a session alive, there's a new
`DatabaseSession` class that holds a SQLAlchemy session inside and
offers methods similar to the `Database` class.

The `Database` moves to use the `DatabaseSession` to remove duplicated
code.
This commit is contained in:
Luís Murta 2023-04-28 22:18:29 +01:00
parent 9f39836083
commit 78ff6faa12
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
2 changed files with 44 additions and 46 deletions

View File

@ -3,66 +3,64 @@ from sqlalchemy import Engine, create_engine, select
from sqlalchemy.orm import Session, sessionmaker
from typing import Any, Optional, Sequence, Type, TypeVar
from pfbudget.db.exceptions import InsertError, SelectError
# from pfbudget.db.exceptions import InsertError, SelectError
from pfbudget.db.model import Transaction
class DatabaseSession:
def __init__(self, session: Session):
self.__session = session
def __enter__(self):
self.__session.begin()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
if exc_type:
self.__session.rollback()
else:
self.__session.commit()
self.__session.close()
def insert(self, transactions: Sequence[Transaction]) -> None:
self.__session.add_all(transactions)
T = TypeVar("T")
C = TypeVar("C")
def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]:
if exists:
stmt = select(what).where(exists)
else:
stmt = select(what)
return self.__session.scalars(stmt).all()
class Client:
def __init__(self, url: str, **kwargs: Any) -> None:
def __init__(self, url: str, **kwargs: Any):
assert url, "Database URL is empty!"
self._engine = create_engine(url, **kwargs)
self._sessionmaker: Optional[sessionmaker[Session]] = None
def insert(
self, transactions: Sequence[Transaction], session: Optional[Session] = None
) -> None:
if not session:
new = deepcopy(transactions)
with self.session as session_:
try:
session_.add_all(new)
except Exception as e:
session_.rollback()
raise InsertError(e)
else:
session_.commit()
else:
try:
session.add_all(transactions)
except Exception as e:
session.rollback()
raise InsertError(e)
else:
session.commit()
def insert(self, transactions: Sequence[Transaction]) -> None:
new = deepcopy(transactions)
with self.session as session:
session.insert(new)
T = TypeVar("T")
C = TypeVar("C")
def select(self, what: Type[T], session: Optional[Session] = None) -> Sequence[T]:
stmt = select(what)
result: Sequence[what] = []
if not session:
with self.session as session_:
try:
result = session_.scalars(stmt).all()
except Exception as e:
session_.rollback()
raise SelectError(e)
else:
try:
result = session.scalars(stmt).all()
except Exception as e:
session.rollback()
raise SelectError(e)
return result
def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]:
return self.session.select(what, exists)
@property
def engine(self) -> Engine:
return self._engine
@property
def session(self) -> Session:
def session(self) -> DatabaseSession:
if not self._sessionmaker:
self._sessionmaker = sessionmaker(self._engine)
return self._sessionmaker()
return DatabaseSession(self._sessionmaker())

View File

@ -25,8 +25,8 @@ class TestDatabase:
]
with client.session as session:
client.insert(transactions, session)
assert client.select(Transaction, session) == transactions
session.insert(transactions)
assert session.select(Transaction) == transactions
def test_insert_transactions_independent_sessions(self, client: Client):
transactions = [