DatabaseSession wrapper for orm.Session
For DB connections that want to keep a session alive, there's a new `DatabaseSession` class that holds a SQLAlchemy session inside and offers methods similar to the `Database` class. The `Database` moves to use the `DatabaseSession` to remove duplicated code.
This commit is contained in:
parent
9f39836083
commit
78ff6faa12
@ -3,66 +3,64 @@ from sqlalchemy import Engine, create_engine, select
|
|||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
from typing import Any, Optional, Sequence, Type, TypeVar
|
from typing import Any, Optional, Sequence, Type, TypeVar
|
||||||
|
|
||||||
from pfbudget.db.exceptions import InsertError, SelectError
|
# from pfbudget.db.exceptions import InsertError, SelectError
|
||||||
from pfbudget.db.model import Transaction
|
from pfbudget.db.model import Transaction
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseSession:
|
||||||
|
def __init__(self, session: Session):
|
||||||
|
self.__session = session
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.__session.begin()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
|
||||||
|
if exc_type:
|
||||||
|
self.__session.rollback()
|
||||||
|
else:
|
||||||
|
self.__session.commit()
|
||||||
|
self.__session.close()
|
||||||
|
|
||||||
|
def insert(self, transactions: Sequence[Transaction]) -> None:
|
||||||
|
self.__session.add_all(transactions)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
C = TypeVar("C")
|
||||||
|
|
||||||
|
def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]:
|
||||||
|
if exists:
|
||||||
|
stmt = select(what).where(exists)
|
||||||
|
else:
|
||||||
|
stmt = select(what)
|
||||||
|
|
||||||
|
return self.__session.scalars(stmt).all()
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
def __init__(self, url: str, **kwargs: Any) -> None:
|
def __init__(self, url: str, **kwargs: Any):
|
||||||
assert url, "Database URL is empty!"
|
assert url, "Database URL is empty!"
|
||||||
self._engine = create_engine(url, **kwargs)
|
self._engine = create_engine(url, **kwargs)
|
||||||
self._sessionmaker: Optional[sessionmaker[Session]] = None
|
self._sessionmaker: Optional[sessionmaker[Session]] = None
|
||||||
|
|
||||||
def insert(
|
def insert(self, transactions: Sequence[Transaction]) -> None:
|
||||||
self, transactions: Sequence[Transaction], session: Optional[Session] = None
|
|
||||||
) -> None:
|
|
||||||
if not session:
|
|
||||||
new = deepcopy(transactions)
|
new = deepcopy(transactions)
|
||||||
with self.session as session_:
|
with self.session as session:
|
||||||
try:
|
session.insert(new)
|
||||||
session_.add_all(new)
|
|
||||||
except Exception as e:
|
|
||||||
session_.rollback()
|
|
||||||
raise InsertError(e)
|
|
||||||
else:
|
|
||||||
session_.commit()
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
session.add_all(transactions)
|
|
||||||
except Exception as e:
|
|
||||||
session.rollback()
|
|
||||||
raise InsertError(e)
|
|
||||||
else:
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
C = TypeVar("C")
|
||||||
|
|
||||||
def select(self, what: Type[T], session: Optional[Session] = None) -> Sequence[T]:
|
def select(self, what: Type[T], exists: Optional[Type[C]] = None) -> Sequence[T]:
|
||||||
stmt = select(what)
|
return self.session.select(what, exists)
|
||||||
result: Sequence[what] = []
|
|
||||||
|
|
||||||
if not session:
|
|
||||||
with self.session as session_:
|
|
||||||
try:
|
|
||||||
result = session_.scalars(stmt).all()
|
|
||||||
except Exception as e:
|
|
||||||
session_.rollback()
|
|
||||||
raise SelectError(e)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
result = session.scalars(stmt).all()
|
|
||||||
except Exception as e:
|
|
||||||
session.rollback()
|
|
||||||
raise SelectError(e)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def engine(self) -> Engine:
|
def engine(self) -> Engine:
|
||||||
return self._engine
|
return self._engine
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def session(self) -> Session:
|
def session(self) -> DatabaseSession:
|
||||||
if not self._sessionmaker:
|
if not self._sessionmaker:
|
||||||
self._sessionmaker = sessionmaker(self._engine)
|
self._sessionmaker = sessionmaker(self._engine)
|
||||||
return self._sessionmaker()
|
|
||||||
|
return DatabaseSession(self._sessionmaker())
|
||||||
|
|||||||
@ -25,8 +25,8 @@ class TestDatabase:
|
|||||||
]
|
]
|
||||||
|
|
||||||
with client.session as session:
|
with client.session as session:
|
||||||
client.insert(transactions, session)
|
session.insert(transactions)
|
||||||
assert client.select(Transaction, session) == transactions
|
assert session.select(Transaction) == transactions
|
||||||
|
|
||||||
def test_insert_transactions_independent_sessions(self, client: Client):
|
def test_insert_transactions_independent_sessions(self, client: Client):
|
||||||
transactions = [
|
transactions = [
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user