diff --git a/pfbudget/cli/runnable.py b/pfbudget/cli/runnable.py index 6778110..b980a1c 100644 --- a/pfbudget/cli/runnable.py +++ b/pfbudget/cli/runnable.py @@ -112,7 +112,7 @@ def argparser(manager: Manager) -> argparse.ArgumentParser: formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) p_categorize.set_defaults( - func=lambda args: categorize_data(DatabaseClient(args.database)) + func=lambda args: manager.categorize(vars(args)) ) """ diff --git a/pfbudget/core/manager.py b/pfbudget/core/manager.py index b04e410..071324a 100644 --- a/pfbudget/core/manager.py +++ b/pfbudget/core/manager.py @@ -1,6 +1,7 @@ from pfbudget.input.input import Input from pfbudget.input.parsers import parse_data from pfbudget.db.client import DbClient +from pfbudget.core.categorizer import Categorizer from pfbudget.utils import convert @@ -38,6 +39,12 @@ class Manager: session.add(transactions) session.commit() + def categorize(self, args: dict): + with self.db.session() as session: + uncategorized = session.uncategorized() + Categorizer().categorize(uncategorized) + session.commit() + # def get_bank_by(self, key: str, value: str) -> Bank: # client = DatabaseClient(self.__db) # bank = client.get_bank(key, value) diff --git a/pfbudget/db/client.py b/pfbudget/db/client.py index 3da414c..0b4d1c5 100644 --- a/pfbudget/db/client.py +++ b/pfbudget/db/client.py @@ -1,7 +1,7 @@ from sqlalchemy import create_engine, select from sqlalchemy.orm import Session, joinedload, selectinload -from pfbudget.db.model import Bank, Transaction +from pfbudget.db.model import Bank, Category, Transaction # import logging @@ -55,3 +55,30 @@ class DbClient: @property def engine(self): return self._engine + + class ClientSession: + def __init__(self, engine): + self.__engine = engine + + def __enter__(self): + self.__session = Session(self.__engine) + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.__session.close() + + def commit(self): + self.__session.commit() + + def add(self, transactions: list[Transaction]): + self.__session.add_all(transactions) + + def addcategory(self, category: Category): + self.__session.add(category) + + def uncategorized(self) -> list[Transaction]: + stmt = select(Transaction).where(~Transaction.category.has()) + return self.__session.scalars(stmt).all() + + def session(self): + return self.ClientSession(self.engine) diff --git a/pfbudget/db/model.py b/pfbudget/db/model.py index f985e8a..fb900ce 100644 --- a/pfbudget/db/model.py +++ b/pfbudget/db/model.py @@ -105,9 +105,8 @@ class Category(Base): group: Mapped[Optional[str]] = mapped_column(ForeignKey(CategoryGroup.name)) rules: Mapped[Optional[set[CategoryRule]]] = relationship( - back_populates="category", cascade="all, delete-orphan", passive_deletes=True + cascade="all, delete-orphan", passive_deletes=True ) - categorygroup: Mapped[Optional[CategoryGroup]] = relationship() class TransactionCategory(Base): @@ -117,7 +116,6 @@ class TransactionCategory(Base): name: Mapped[str] = mapped_column(ForeignKey(Category.name)) original: Mapped[Transaction] = relationship(back_populates="category") - category: Mapped[Category] = relationship() def __repr__(self) -> str: return f"Category({self.name})" @@ -162,5 +160,3 @@ class CategoryRule(Base): ForeignKey(Category.name, ondelete="CASCADE"), primary_key=True ) rule: Mapped[str] = mapped_column(primary_key=True) - - category: Mapped[Category] = relationship(back_populates="rules")