[Fix] Fix sessionless database client

The database `Client` wasn't working correcly when no session was
passed, as the inserted transactions would still be bound to the newly
created session inside.
Creates a copy of the input transactions to insert on the DB.
This commit is contained in:
Luís Murta 2023-04-27 18:15:08 +01:00
parent e7abae0d17
commit ad62317e56
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
2 changed files with 18 additions and 7 deletions

View File

@ -1,8 +1,9 @@
from sqlalchemy import Engine, create_engine, select, text
from copy import deepcopy
from sqlalchemy import Engine, create_engine, select
from sqlalchemy.orm import Session, sessionmaker
from typing import Any, Optional, Sequence, Type, TypeVar
from pfbudget.db.exceptions import InsertError, SelectError
from pfbudget.db.exceptions import InsertError, SelectError
from pfbudget.db.model import Transaction
@ -16,9 +17,10 @@ class Client:
self, transactions: Sequence[Transaction], session: Optional[Session] = None
) -> None:
if not session:
new = deepcopy(transactions)
with self.session as session_:
try:
session_.add_all(transactions)
session_.add_all(new)
except Exception as e:
session_.rollback()
raise InsertError(e)
@ -46,16 +48,12 @@ class Client:
except Exception as e:
session_.rollback()
raise SelectError(e)
else:
session_.commit()
else:
try:
result = session.scalars(stmt).all()
except Exception as e:
session.rollback()
raise SelectError(e)
else:
session.commit()
return result

View File

@ -27,3 +27,16 @@ class TestDatabase:
with client.session as session:
client.insert(transactions, session)
assert client.select(Transaction, session) == transactions
def test_insert_transactions_independent_sessions(self, client: Client):
transactions = [
Transaction(date(2023, 1, 1), "", Decimal("-500")),
Transaction(date(2023, 1, 2), "", Decimal("500")),
]
client.insert(transactions)
result = client.select(Transaction)
for i, transaction in enumerate(result):
assert transactions[i].date == transaction.date
assert transactions[i].description == transaction.description
assert transactions[i].amount == transaction.amount