and as a result, had to fix a LOT of minor potential future issues. It also reorders imports and removes unused ones. When exporting transactions, they are sorted by date.
103 lines
3.1 KiB
Python
103 lines
3.1 KiB
Python
from dataclasses import asdict
|
|
from sqlalchemy import create_engine, delete, select, update
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
from sqlalchemy.orm import Session
|
|
from typing import Sequence, Type, TypeVar
|
|
|
|
from pfbudget.db.model import (
|
|
Category,
|
|
CategoryGroup,
|
|
CategorySchedule,
|
|
Link,
|
|
)
|
|
|
|
|
|
class DbClient:
    """
    General database client using sqlalchemy.

    Holds one SQLAlchemy engine. Units of work are performed through
    ``DbClient.ClientSession`` objects returned by :meth:`session`, which are
    meant to be used as context managers.
    """

    # NOTE(review): declared but never assigned anywhere in this class.
    __sessions: list[Session]

    def __init__(self, url: str, echo=False) -> None:
        """Create the engine for ``url``; ``echo`` is forwarded to create_engine."""
        self._engine = create_engine(url, echo=echo)

    @property
    def engine(self):
        """The SQLAlchemy engine created in ``__init__``."""
        return self._engine

    class ClientSession:
        """
        Context manager wrapping a single ORM ``Session``.

        The session is opened in ``__enter__``. On a clean exit of the
        ``with`` block the pending work is committed; if the block raised, the
        work is rolled back instead. The session is closed in both cases.
        """

        def __init__(self, engine):
            self.__engine = engine

        def __enter__(self):
            self.__session = Session(self.__engine)
            return self

        def __exit__(self, exc_type, exc_value, exc_tb):
            try:
                # Only persist the unit of work if the with-body succeeded;
                # committing after an exception would save partial changes.
                if exc_type is None:
                    self.commit()
                else:
                    self.__session.rollback()
            finally:
                # Always release the session, even if commit() itself raises.
                self.__session.close()
            # Implicitly returns None, so any exception propagates.

        def commit(self):
            """Commit the session's current transaction."""
            self.__session.commit()

        def expunge_all(self):
            """Detach every instance currently held by the session."""
            self.__session.expunge_all()

        T = TypeVar("T")

        def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
            """
            Select instances of the mapped class ``type``.

            Args:
                type: ORM mapped class to select.
                column: optional filter. When ``values`` is given, it is the
                    column tested with ``IN``; otherwise it is used directly
                    as the WHERE clause (e.g. a boolean SQL expression).
                values: optional collection for ``column.in_(values)``. An
                    empty collection is still applied as ``IN ()`` and so
                    matches no rows.

            Returns:
                All matching ORM instances.
            """
            stmt = select(type)
            if column is not None:
                # `is not None` (not truthiness): an empty `values` must not
                # silently turn `column` into a bare WHERE predicate.
                if values is not None:
                    stmt = stmt.where(column.in_(values))
                else:
                    stmt = stmt.where(column)
            return self.__session.scalars(stmt).all()

        def add(self, rows: list):
            """Add ``rows`` (ORM instances) to the session."""
            self.__session.add_all(rows)

        def remove_by_name(self, type, rows: list):
            """Delete rows of ``type`` whose ``name`` equals the name of any of ``rows``."""
            stmt = delete(type).where(type.name.in_([row.name for row in rows]))
            self.__session.execute(stmt)

        def updategroup(self, categories: list[Category], group: CategoryGroup):
            """Set ``group`` on every Category whose name appears in ``categories``."""
            stmt = (
                update(Category)
                .where(Category.name.in_([cat.name for cat in categories]))
                .values(group=group)
            )
            self.__session.execute(stmt)

        def updateschedules(self, schedules: list[CategorySchedule]):
            """
            Upsert ``schedules`` keyed on ``CategorySchedule.name``.

            Uses PostgreSQL ``INSERT ... ON CONFLICT (name) DO UPDATE``: rows
            that already exist get ``recurring``, ``period`` and
            ``period_multiplier`` overwritten. An empty list is a no-op.
            """
            if not schedules:
                # Nothing to upsert; an INSERT built from an empty values
                # list is not a valid no-op statement.
                return
            stmt = insert(CategorySchedule).values([asdict(s) for s in schedules])
            stmt = stmt.on_conflict_do_update(
                index_elements=[CategorySchedule.name],
                set_=dict(
                    recurring=stmt.excluded.recurring,
                    period=stmt.excluded.period,
                    period_multiplier=stmt.excluded.period_multiplier,
                ),
            )
            self.__session.execute(stmt)

        def remove_by_id(self, type, ids: list[int]):
            """Delete rows of ``type`` whose ``id`` is in ``ids``."""
            stmt = delete(type).where(type.id.in_(ids))
            self.__session.execute(stmt)

        def update(self, type, values: list[dict]):
            """
            Bulk-UPDATE rows of ``type`` from a list of parameter dicts.

            Executes ``update(type)`` with ``values`` as the parameter list
            (SQLAlchemy's ORM bulk UPDATE by primary key, so each dict is
            expected to carry the primary key). An empty list is a no-op.
            """
            if not values:
                return
            self.__session.execute(update(type), values)

        def remove_links(self, original: int, links: list[int]):
            """Delete the Link rows from ``original`` to any id in ``links``."""
            stmt = delete(Link).where(
                Link.original == original,
                Link.link.in_(links),
            )
            self.__session.execute(stmt)

    def session(self) -> ClientSession:
        """Return a new ClientSession bound to this client's engine (not yet entered)."""
        return self.ClientSession(self.engine)
|