The `Client` was never closing the `Session` in `DatabaseSession`, which meant that tests that compared transactions and banks would work since attributes loaded only on select would still work. However, this was not the intended behavior. Calling select from the client with using sessions should detach objects completely from a DB session. Therefore, attributes on Transaction, TransactionCategory and Bank are now loaded on a join (lazy parameter).
76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
from collections.abc import Sequence
|
|
from copy import deepcopy
|
|
from sqlalchemy import Engine, create_engine, delete, select, update
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
from typing import Any, Mapping, Optional, Type, TypeVar
|
|
|
|
# from pfbudget.db.exceptions import InsertError, SelectError
|
|
|
|
|
|
class DatabaseSession:
|
|
def __init__(self, session: Session):
|
|
self.__session = session
|
|
|
|
def __enter__(self):
|
|
self.__session.begin()
|
|
return self
|
|
|
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
|
|
if exc_type:
|
|
self.__session.rollback()
|
|
else:
|
|
self.__session.commit()
|
|
self.__session.close()
|
|
|
|
def close(self):
|
|
self.__session.close()
|
|
|
|
def insert(self, sequence: Sequence[Any]) -> None:
|
|
self.__session.add_all(sequence)
|
|
|
|
T = TypeVar("T")
|
|
|
|
def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]:
|
|
if exists:
|
|
stmt = select(what).filter(exists)
|
|
else:
|
|
stmt = select(what)
|
|
|
|
return self.__session.scalars(stmt).unique().all()
|
|
|
|
|
|
class Client:
|
|
def __init__(self, url: str, **kwargs: Any):
|
|
assert url, "Database URL is empty!"
|
|
self._engine = create_engine(url, **kwargs)
|
|
self._sessionmaker = sessionmaker(self._engine)
|
|
|
|
def insert(self, sequence: Sequence[Any]) -> None:
|
|
new = deepcopy(sequence)
|
|
with self.session as session:
|
|
session.insert(new)
|
|
|
|
T = TypeVar("T")
|
|
|
|
def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]:
|
|
session = self.session
|
|
result = session.select(what, exists)
|
|
session.close()
|
|
return result
|
|
|
|
def update(self, what: Type[Any], values: Sequence[Mapping[str, Any]]) -> None:
|
|
with self._sessionmaker() as session, session.begin():
|
|
session.execute(update(what), values)
|
|
|
|
def delete(self, what: Type[Any], column: Any, values: Sequence[str]) -> None:
|
|
with self._sessionmaker() as session, session.begin():
|
|
session.execute(delete(what).where(column.in_(values)))
|
|
|
|
@property
|
|
def engine(self) -> Engine:
|
|
return self._engine
|
|
|
|
@property
|
|
def session(self) -> DatabaseSession:
|
|
return DatabaseSession(self._sessionmaker())
|