The `add` method was replaced with `insert`. `insert` and `select` were implemented for the new database base class, and a database unit test was added. Because of how SQLite implements primary-key autoincrement, the ID columns in the SQLite database were changed to type Integer (see https://www.sqlite.org/autoinc.html).
124 lines
3.9 KiB
Python
124 lines
3.9 KiB
Python
from dataclasses import asdict
|
|
from sqlalchemy import create_engine, delete, select, update
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.sql.expression import false
|
|
from typing import Sequence, Type, TypeVar
|
|
|
|
from pfbudget.db.model import (
|
|
Category,
|
|
CategoryGroup,
|
|
CategorySchedule,
|
|
Link,
|
|
Transaction,
|
|
)
|
|
|
|
|
|
class DbClient:
    """
    General database client using sqlalchemy.

    Owns a SQLAlchemy engine and hands out `ClientSession` context managers,
    each wrapping one ORM `Session` for a unit of work.
    """

    def __init__(self, url: str, echo=False) -> None:
        """Create the engine for `url`.

        Args:
            url: database URL passed to `create_engine`.
            echo: forwarded to `create_engine` (SQL statement logging).
        """
        self._engine = create_engine(url, echo=echo)

    @property
    def engine(self):
        """The SQLAlchemy engine created in `__init__`."""
        return self._engine

    class ClientSession:
        """Context manager around a single ORM `Session`.

        The session is opened in `__enter__`. On a clean exit from the `with`
        block the pending work is committed; if the block raised, the work is
        rolled back instead. The session is closed in both cases.
        """

        def __init__(self, engine):
            self.__engine = engine

        def __enter__(self):
            self.__session = Session(self.__engine)
            return self

        def __exit__(self, exc_type, exc_value, exc_tb):
            try:
                if exc_type is None:
                    self.commit()
                else:
                    # An exception escaped the `with` body: do not persist a
                    # half-finished unit of work.
                    self.__session.rollback()
            finally:
                # Always release the session, even if commit/rollback raised.
                self.__session.close()

        def commit(self):
            """Commit the session's current transaction."""
            self.__session.commit()

        def expunge_all(self):
            """Detach every instance currently held by the session."""
            self.__session.expunge_all()

        T = TypeVar("T")

        def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
            """Select instances of `type`, optionally filtered on `column`.

            Args:
                type: mapped class to select.
                column: column/expression to filter on. If given while
                    `values` is None, it is used directly as the WHERE
                    condition.
                values: a single value (matched with `==`) or a non-string
                    sequence of values (matched with `IN`).

            Returns:
                All matching instances.
            """
            stmt = select(type)
            if column is not None:
                if values is None:
                    stmt = stmt.where(column)
                elif isinstance(values, Sequence) and not isinstance(
                    values, (str, bytes)
                ):
                    # str/bytes are Sequences too, but denote a single value.
                    stmt = stmt.where(column.in_(values))
                else:
                    stmt = stmt.where(column == values)
            return self.__session.scalars(stmt).all()

        def uncategorized(self) -> Sequence[Transaction]:
            """Select all valid uncategorized transactions.

            At the moment that means transactions that:
              - have no category, AND
              - are not split.

            Returns:
                Sequence[Transaction]: transactions left uncategorized
            """
            stmt = (
                select(Transaction)
                .where(~Transaction.category.has())
                .where(Transaction.split == false())
            )
            return self.__session.scalars(stmt).all()

        def insert(self, rows: list):
            """Add the ORM instances in `rows` to the session.

            They are written to the database on the next flush/commit.
            """
            self.__session.add_all(rows)

        def remove_by_name(self, type, rows: list):
            """Delete rows of `type` whose `name` matches a name in `rows`."""
            names = [row.name for row in rows]
            self.__session.execute(delete(type).where(type.name.in_(names)))

        def updategroup(self, categories: list[Category], group: CategoryGroup):
            """Set `group` on every Category whose name appears in `categories`."""
            stmt = (
                update(Category)
                .where(Category.name.in_([cat.name for cat in categories]))
                .values(group=group)
            )
            self.__session.execute(stmt)

        def updateschedules(self, schedules: list[CategorySchedule]):
            """Upsert `schedules`, keyed on `CategorySchedule.name`.

            Existing rows get `recurring`, `period` and `period_multiplier`
            overwritten with the new values. Does nothing for an empty list.

            NOTE(review): the statement is built with the PostgreSQL-dialect
            `insert` (`on_conflict_do_update`); confirm it compiles on any
            other backend (e.g. SQLite) this client is used with.
            """
            if not schedules:
                return
            stmt = insert(CategorySchedule).values([asdict(s) for s in schedules])
            stmt = stmt.on_conflict_do_update(
                index_elements=[CategorySchedule.name],
                set_=dict(
                    recurring=stmt.excluded.recurring,
                    period=stmt.excluded.period,
                    period_multiplier=stmt.excluded.period_multiplier,
                ),
            )
            self.__session.execute(stmt)

        def remove_by_id(self, type, ids: list[int]):
            """Delete rows of `type` whose `id` is in `ids`."""
            self.__session.execute(delete(type).where(type.id.in_(ids)))

        def update(self, type, values: list[dict]):
            """Bulk-update rows of `type` from a list of parameter dicts.

            `values` is passed as the parameter list of an ORM bulk UPDATE by
            primary key, so each dict is expected to carry the primary key
            plus the attributes to change. Does nothing for an empty list.
            """
            if not values:
                # An empty parameter list would not be a bulk update at all.
                return
            self.__session.execute(update(type), values)

        def remove_links(self, original: int, links: list[int]):
            """Delete the `Link` rows from `original` to each id in `links`."""
            stmt = delete(Link).where(
                Link.original == original, Link.link.in_(links)
            )
            self.__session.execute(stmt)

    def session(self) -> ClientSession:
        """Return a new, not-yet-entered `ClientSession` on this engine."""
        return self.ClientSession(self.engine)
|