[Fix] Fix sessionless database client

The database `Client` wasn't working correcly when no session was
passed, as the inserted transactions would still be bound to the newly
created session inside.
Creates a copy of the input transactions to insert on the DB.
This commit is contained in:
Luís Murta 2023-04-27 18:15:08 +01:00
parent e7abae0d17
commit ad62317e56
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
2 changed files with 18 additions and 7 deletions

View File

@ -1,8 +1,9 @@
from sqlalchemy import Engine, create_engine, select, text from copy import deepcopy
from sqlalchemy import Engine, create_engine, select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from typing import Any, Optional, Sequence, Type, TypeVar from typing import Any, Optional, Sequence, Type, TypeVar
from pfbudget.db.exceptions import InsertError, SelectError
from pfbudget.db.exceptions import InsertError, SelectError
from pfbudget.db.model import Transaction from pfbudget.db.model import Transaction
@ -16,9 +17,10 @@ class Client:
self, transactions: Sequence[Transaction], session: Optional[Session] = None self, transactions: Sequence[Transaction], session: Optional[Session] = None
) -> None: ) -> None:
if not session: if not session:
new = deepcopy(transactions)
with self.session as session_: with self.session as session_:
try: try:
session_.add_all(transactions) session_.add_all(new)
except Exception as e: except Exception as e:
session_.rollback() session_.rollback()
raise InsertError(e) raise InsertError(e)
@ -46,16 +48,12 @@ class Client:
except Exception as e: except Exception as e:
session_.rollback() session_.rollback()
raise SelectError(e) raise SelectError(e)
else:
session_.commit()
else: else:
try: try:
result = session.scalars(stmt).all() result = session.scalars(stmt).all()
except Exception as e: except Exception as e:
session.rollback() session.rollback()
raise SelectError(e) raise SelectError(e)
else:
session.commit()
return result return result

View File

@ -27,3 +27,16 @@ class TestDatabase:
with client.session as session: with client.session as session:
client.insert(transactions, session) client.insert(transactions, session)
assert client.select(Transaction, session) == transactions assert client.select(Transaction, session) == transactions
def test_insert_transactions_independent_sessions(self, client: Client):
transactions = [
Transaction(date(2023, 1, 1), "", Decimal("-500")),
Transaction(date(2023, 1, 2), "", Decimal("500")),
]
client.insert(transactions)
result = client.select(Transaction)
for i, transaction in enumerate(result):
assert transactions[i].date == transaction.date
assert transactions[i].description == transaction.description
assert transactions[i].amount == transaction.amount