budget/pfbudget/db/postgresql.py
Luís Murta e7abae0d17
[Refactor] Database client interface changed
`add` method replaced with `insert`.
`insert` and `select` implemented for new database base class.
Database unit test added.

Because of how SQLite implements primary-key autoincrement, the ID columns
in the SQLite database were changed to type Integer.
https://www.sqlite.org/autoinc.html
2023-04-29 20:20:20 +01:00

124 lines
3.9 KiB
Python

from dataclasses import asdict
from sqlalchemy import create_engine, delete, select, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import false
from typing import Sequence, Type, TypeVar
from pfbudget.db.model import (
Category,
CategoryGroup,
CategorySchedule,
Link,
Transaction,
)
class DbClient:
    """
    General database client using sqlalchemy.

    Holds a SQLAlchemy engine built from a database URL. Database work is
    done through the nested `ClientSession` context manager returned by
    `session()`.
    """

    # NOTE(review): declared but never assigned or read within this class.
    __sessions: list[Session]

    def __init__(self, url: str, echo=False) -> None:
        """Create the engine for `url`.

        Args:
            url: database URL, forwarded to `create_engine`.
            echo: forwarded to `create_engine` as its `echo` flag.
        """
        self._engine = create_engine(url, echo=echo)

    @property
    def engine(self):
        """The SQLAlchemy engine created in `__init__`."""
        return self._engine

    class ClientSession:
        """Context manager wrapping a single SQLAlchemy `Session`.

        The session is opened in `__enter__`. When the `with` body finishes
        normally, pending work is committed. If the body raised, the work is
        rolled back instead. The session is always closed on exit.
        """

        def __init__(self, engine):
            self.__engine = engine

        def __enter__(self):
            self.__session = Session(self.__engine)
            return self

        def __exit__(self, exc_type, exc_value, exc_tb):
            try:
                if exc_type is None:
                    self.commit()
                else:
                    # Do not persist partial work from a block that failed.
                    self.__session.rollback()
            finally:
                # Close even if commit()/rollback() itself raises, so the
                # session's resources are always released.
                self.__session.close()
            # Implicitly returns None, so exceptions from the body propagate.

        def commit(self):
            """Commit the session's current transaction."""
            self.__session.commit()

        def expunge_all(self):
            """Detach every instance currently held by the session."""
            self.__session.expunge_all()

        T = TypeVar("T")

        def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
            """Select all rows of `type`, optionally filtered on `column`.

            Args:
                type: mapped class to select.
                column: column expression to filter on; None means no filter.
                values: how `column` is filtered when it is given:
                    - None: `WHERE column` (the column itself is the condition,
                      e.g. a boolean column);
                    - a non-string sequence: `WHERE column IN values`;
                    - anything else (including False, 0 or a string):
                      `WHERE column == values`.

            Returns:
                All matching instances of `type`.
            """
            stmt = select(type)
            if column is not None:
                # Compare against None explicitly: falsy filter values such as
                # False or 0 are legitimate and must not fall back to
                # `WHERE column`.
                if values is None:
                    stmt = stmt.where(column)
                elif isinstance(values, Sequence) and not isinstance(
                    values, (str, bytes)
                ):
                    # str/bytes are Sequences too, but are scalar filter values.
                    stmt = stmt.where(column.in_(values))
                else:
                    stmt = stmt.where(column == values)
            return self.__session.scalars(stmt).all()

        def uncategorized(self) -> Sequence[Transaction]:
            """Selects all valid uncategorized transactions.

            At this moment that includes:
                - transactions without a category
                - AND transactions that are not split

            Returns:
                Sequence[Transaction]: transactions left uncategorized
            """
            stmt = (
                select(Transaction)
                .where(~Transaction.category.has())
                .where(Transaction.split == false())
            )
            return self.__session.scalars(stmt).all()

        def insert(self, rows: list):
            """Add `rows` to the session via `Session.add_all`."""
            self.__session.add_all(rows)

        def remove_by_name(self, type, rows: list):
            """Delete rows of `type` whose `name` matches the name of any of `rows`."""
            stmt = delete(type).where(type.name.in_([row.name for row in rows]))
            self.__session.execute(stmt)

        def updategroup(self, categories: list[Category], group: CategoryGroup):
            """Set `group` on every Category whose name is among `categories`."""
            stmt = (
                update(Category)
                .where(Category.name.in_([cat.name for cat in categories]))
                .values(group=group)
            )
            self.__session.execute(stmt)

        def updateschedules(self, schedules: list[CategorySchedule]):
            """Upsert `schedules`, keyed on `CategorySchedule.name`.

            New names are inserted. For an existing name, `recurring`, `period`
            and `period_multiplier` are overwritten. This uses the PostgreSQL
            dialect's `INSERT ... ON CONFLICT DO UPDATE`.
            """
            if not schedules:
                # An empty VALUES list would not be a meaningful upsert.
                return
            stmt = insert(CategorySchedule).values([asdict(s) for s in schedules])
            stmt = stmt.on_conflict_do_update(
                index_elements=[CategorySchedule.name],
                set_=dict(
                    recurring=stmt.excluded.recurring,
                    period=stmt.excluded.period,
                    period_multiplier=stmt.excluded.period_multiplier,
                ),
            )
            self.__session.execute(stmt)

        def remove_by_id(self, type, ids: list[int]):
            """Delete rows of `type` whose `id` is in `ids`."""
            stmt = delete(type).where(type.id.in_(ids))
            self.__session.execute(stmt)

        def update(self, type, values: list[dict]):
            """Bulk-update rows of `type` from a list of parameter dicts.

            Each dict is passed to `Session.execute(update(type), values)`.
            Presumably each one carries the primary key plus the columns to
            change; confirm against callers.
            """
            if not values:
                # Nothing to update; avoid issuing an empty bulk UPDATE.
                return
            self.__session.execute(update(type), values)

        def remove_links(self, original: int, links: list[int]):
            """Delete the Link rows from `original` to any id in `links`."""
            stmt = delete(Link).where(Link.original == original, Link.link.in_(links))
            self.__session.execute(stmt)

    def session(self) -> ClientSession:
        """Return a new `ClientSession` bound to this client's engine."""
        return self.ClientSession(self.engine)