From ad62317e56aadfa3ba1547566890dede2cf03a38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Murta?= Date: Thu, 27 Apr 2023 18:15:08 +0100 Subject: [PATCH] [Fix] Fix sessionless database client The database `Client` wasn't working correcly when no session was passed, as the inserted transactions would still be bound to the newly created session inside. Creates a copy of the input transactions to insert on the DB. --- pfbudget/db/client.py | 12 +++++------- tests/test_database.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/pfbudget/db/client.py b/pfbudget/db/client.py index 82b06e7..a652a35 100644 --- a/pfbudget/db/client.py +++ b/pfbudget/db/client.py @@ -1,8 +1,9 @@ -from sqlalchemy import Engine, create_engine, select, text +from copy import deepcopy +from sqlalchemy import Engine, create_engine, select from sqlalchemy.orm import Session, sessionmaker from typing import Any, Optional, Sequence, Type, TypeVar -from pfbudget.db.exceptions import InsertError, SelectError +from pfbudget.db.exceptions import InsertError, SelectError from pfbudget.db.model import Transaction @@ -16,9 +17,10 @@ class Client: self, transactions: Sequence[Transaction], session: Optional[Session] = None ) -> None: if not session: + new = deepcopy(transactions) with self.session as session_: try: - session_.add_all(transactions) + session_.add_all(new) except Exception as e: session_.rollback() raise InsertError(e) @@ -46,16 +48,12 @@ class Client: except Exception as e: session_.rollback() raise SelectError(e) - else: - session_.commit() else: try: result = session.scalars(stmt).all() except Exception as e: session.rollback() raise SelectError(e) - else: - session.commit() return result diff --git a/tests/test_database.py b/tests/test_database.py index e28dee3..a8741bd 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -27,3 +27,16 @@ class TestDatabase: with client.session as session: client.insert(transactions, session) assert client.select(Transaction, session) == transactions + + def test_insert_transactions_independent_sessions(self, client: Client): + transactions = [ + Transaction(date(2023, 1, 1), "", Decimal("-500")), + Transaction(date(2023, 1, 2), "", Decimal("500")), + ] + + client.insert(transactions) + result = client.select(Transaction) + for i, transaction in enumerate(result): + assert transactions[i].date == transaction.date + assert transactions[i].description == transaction.description + assert transactions[i].amount == transaction.amount