Adds an `uncategorized` method to the DB client that retrieves transactions that have no category AND are not split.
121 lines
3.7 KiB
Python
121 lines
3.7 KiB
Python
from dataclasses import asdict
|
|
from sqlalchemy import create_engine, delete, select, update
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.sql.expression import false
|
|
from typing import Sequence, Type, TypeVar
|
|
|
|
from pfbudget.db.model import (
|
|
Category,
|
|
CategoryGroup,
|
|
CategorySchedule,
|
|
Link,
|
|
Transaction,
|
|
)
|
|
|
|
|
|
class DbClient:
    """
    General database client using sqlalchemy.

    Holds one engine built from a database URL; queries and writes are done
    through ``ClientSession`` objects obtained from :meth:`session`.
    """

    def __init__(self, url: str, echo=False) -> None:
        """Create the underlying sqlalchemy engine.

        Args:
            url: database URL passed to ``create_engine``.
            echo: forwarded to ``create_engine`` (SQL echo/logging flag).
        """
        self._engine = create_engine(url, echo=echo)

    @property
    def engine(self):
        """The sqlalchemy engine created in ``__init__``."""
        return self._engine

    class ClientSession:
        """Context manager wrapping a single ``sqlalchemy.orm.Session``.

        The ORM session is opened in ``__enter__``. When the ``with`` body
        finishes normally the work is committed; when it raises, the work is
        rolled back instead. The session is closed in either case and any
        exception from the body propagates to the caller.
        """

        def __init__(self, engine):
            self.__engine = engine

        def __enter__(self):
            self.__session = Session(self.__engine)
            return self

        def __exit__(self, exc_type, exc_value, exc_tb):
            try:
                # Only persist on a clean exit; committing after an exception
                # would write a partially applied unit of work.
                if exc_type is None:
                    self.commit()
                else:
                    self.__session.rollback()
            finally:
                # Release the connection even if commit/rollback itself raised.
                self.__session.close()

        def commit(self):
            """Commit the current transaction of the wrapped session."""
            self.__session.commit()

        def expunge_all(self):
            """Detach all instances currently held by the wrapped session."""
            self.__session.expunge_all()

        T = TypeVar("T")

        def get(self, type: Type[T], column=None, values=None) -> Sequence[T]:
            """Select instances of the mapped class ``type``.

            Args:
                type: mapped class to select.
                column: optional filter. Combined with ``values`` it becomes
                    ``column IN values``; on its own it is used directly as
                    the WHERE clause, so it should then be a boolean
                    expression.
                values: optional collection for the IN filter. An empty
                    collection produces an empty IN and therefore no rows.

            Returns:
                All matching instances.
            """
            stmt = select(type)
            if column is not None:
                # Compare against None rather than truthiness: an empty list
                # must mean "match nothing", not fall back to WHERE <column>.
                if values is not None:
                    stmt = stmt.where(column.in_(values))
                else:
                    stmt = stmt.where(column)
            return self.__session.scalars(stmt).all()

        def uncategorized(self) -> Sequence[Transaction]:
            """Selects all valid uncategorized transactions

            At this moment that includes:
            - Transactions w/o category
            - AND non-split transactions

            Returns:
                Sequence[Transaction]: transactions left uncategorized
            """
            stmt = (
                select(Transaction)
                .where(~Transaction.category.has())
                .where(Transaction.split == false())
            )
            return self.__session.scalars(stmt).all()

        def add(self, rows: list):
            """Add ``rows`` to the session; they are written on commit."""
            self.__session.add_all(rows)

        def remove_by_name(self, type, rows: list):
            """Delete rows of ``type`` whose ``name`` matches any of ``rows``."""
            stmt = delete(type).where(type.name.in_([row.name for row in rows]))
            self.__session.execute(stmt)

        def updategroup(self, categories: list[Category], group: CategoryGroup):
            """Set ``group`` on every category whose name is in ``categories``."""
            stmt = (
                update(Category)
                .where(Category.name.in_([cat.name for cat in categories]))
                .values(group=group)
            )
            self.__session.execute(stmt)

        def updateschedules(self, schedules: list[CategorySchedule]):
            """Insert or update category schedules (PostgreSQL upsert).

            Rows conflicting on ``CategorySchedule.name`` get their
            ``recurring``, ``period`` and ``period_multiplier`` overwritten
            with the new values.
            """
            if not schedules:
                # Nothing to upsert; an INSERT with an empty VALUES list is invalid.
                return
            stmt = insert(CategorySchedule).values([asdict(s) for s in schedules])
            stmt = stmt.on_conflict_do_update(
                index_elements=[CategorySchedule.name],
                set_=dict(
                    recurring=stmt.excluded.recurring,
                    period=stmt.excluded.period,
                    period_multiplier=stmt.excluded.period_multiplier,
                ),
            )
            self.__session.execute(stmt)

        def remove_by_id(self, type, ids: list[int]):
            """Delete rows of ``type`` whose ``id`` is in ``ids``."""
            stmt = delete(type).where(type.id.in_(ids))
            self.__session.execute(stmt)

        def update(self, type, values: list[dict]):
            """Run a bulk UPDATE of ``type`` with one parameter dict per row.

            NOTE(review): this is SQLAlchemy's ORM "bulk UPDATE by primary key"
            form, so each dict presumably must include the primary key -- confirm.
            """
            if not values:
                # Never execute a bare UPDATE when there is nothing to apply.
                return
            self.__session.execute(update(type), values)

        def remove_links(self, original: int, links: list[int]):
            """Delete the ``Link`` rows from ``original`` to any id in ``links``."""
            stmt = delete(Link).where(
                Link.original == original, Link.link.in_(links)
            )
            self.__session.execute(stmt)

    def session(self) -> ClientSession:
        """Return a new ``ClientSession`` bound to this client's engine.

        Intended use: ``with client.session() as s: ...``.
        """
        return self.ClientSession(self.engine)
|