diff --git a/pfbudget/cli/argparser.py b/pfbudget/cli/argparser.py index 74de893..cb2ba13 100644 --- a/pfbudget/cli/argparser.py +++ b/pfbudget/cli/argparser.py @@ -67,7 +67,7 @@ def argparser() -> argparse.ArgumentParser: pimport = subparsers.add_parser("import") pimport.set_defaults(op=Operation.Import) - pimport.add_argument("file", nargs=1, type=str) + file_options(pimport) # Parse from .csv parse = subparsers.add_parser("parse") diff --git a/pfbudget/core/manager.py b/pfbudget/core/manager.py index 08ee0b1..d1d0fa3 100644 --- a/pfbudget/core/manager.py +++ b/pfbudget/core/manager.py @@ -1,10 +1,12 @@ import csv +import json from pathlib import Path import pickle +from typing import Optional import webbrowser from pfbudget.common.types import Operation -from pfbudget.db.postgresql import DbClient +from pfbudget.db.client import Client from pfbudget.db.model import ( Bank, BankTransaction, @@ -24,6 +26,7 @@ from pfbudget.db.model import ( Transaction, TransactionCategory, ) +from pfbudget.db.postgresql import DbClient from pfbudget.extract.nordigen import NordigenClient, NordigenCredentialsManager from pfbudget.extract.parsers import parse_data from pfbudget.extract.psd2 import PSD2Extractor @@ -35,6 +38,7 @@ from pfbudget.transform.tagger import Tagger class Manager: def __init__(self, db: str, verbosity: int = 0): self._db = db + self._database: Optional[Client] = None self._verbosity = verbosity def action(self, op: Operation, params=None): @@ -49,10 +53,7 @@ class Manager: pass case Operation.Transactions: - with self.db.session() as session: - transactions = session.get(Transaction) - ret = [t.format for t in transactions] - return ret + return [t.format for t in self.database.select(Transaction)] case Operation.Parse: # Adapter for the parse_data method. Can be refactored. @@ -263,8 +264,10 @@ class Manager: session.insert(transactions) case Operation.Export: - with self.db.session() as session: - self.dump(params[0], params[1], sorted(session.get(Transaction))) + with self.database.session as session: + self.dump( + params[0], params[1], self.database.select(Transaction, session) + ) case Operation.Import: transactions = [] @@ -301,8 +304,8 @@ class Manager: session.insert(transactions) case Operation.ExportBanks: - with self.db.session() as session: - self.dump(params[0], params[1], session.get(Bank)) + with self.database.session as session: + self.dump(params[0], params[1], self.database.select(Bank, session)) case Operation.ImportBanks: banks = [] @@ -317,8 +320,12 @@ class Manager: session.insert(banks) case Operation.ExportCategoryRules: - with self.db.session() as session: - self.dump(params[0], params[1], session.get(CategoryRule)) + with self.database.session as session: + self.dump( + params[0], + params[1], + self.database.select(CategoryRule, session), + ) case Operation.ImportCategoryRules: rules = [CategoryRule(**row) for row in self.load(params[0], params[1])] @@ -328,8 +335,10 @@ class Manager: session.insert(rules) case Operation.ExportTagRules: - with self.db.session() as session: - self.dump(params[0], params[1], session.get(TagRule)) + with self.database.session as session: + self.dump( + params[0], params[1], self.database.select(TagRule, session) + ) case Operation.ImportTagRules: rules = [TagRule(**row) for row in self.load(params[0], params[1])] @@ -339,8 +348,10 @@ class Manager: session.insert(rules) case Operation.ExportCategories: - with self.db.session() as session: - self.dump(params[0], params[1], session.get(Category)) + with self.database.session as session: + self.dump( + params[0], params[1], self.database.select(Category, session) + ) case Operation.ImportCategories: # rules = [Category(**row) for row in self.load(params[0])] @@ -363,8 +374,12 @@ class Manager: session.insert(categories) case Operation.ExportCategoryGroups: - with self.db.session() as session: - self.dump(params[0], params[1], session.get(CategoryGroup)) + with self.database.session as session: + self.dump( + params[0], + params[1], + self.database.select(CategoryGroup, session), + ) case Operation.ImportCategoryGroups: groups = [ @@ -397,6 +412,9 @@ class Manager: elif format == "csv": with open(fn, "w", newline="") as f: csv.writer(f).writerows([e.format.values() for e in sequence]) + elif format == "json": + with open(fn, "w", newline="") as f: + json.dump([e.format for e in sequence], f, indent=4, default=str) else: print("format not well specified") @@ -418,9 +436,15 @@ class Manager: return False @property - def db(self) -> DbClient: + def db(self) -> Client: return DbClient(self._db, self._verbosity > 2) + @property + def database(self) -> Client: + if not self._database: + self._database = Client(self._db, echo=self._verbosity > 2) + return self._database + @db.setter def db(self, url: str): self._db = url diff --git a/pfbudget/db/client.py b/pfbudget/db/client.py index a652a35..0c1a5c5 100644 --- a/pfbudget/db/client.py +++ b/pfbudget/db/client.py @@ -8,7 +8,7 @@ from pfbudget.db.model import Transaction class Client: - def __init__(self, url: str, **kwargs: dict[str, Any]) -> None: + def __init__(self, url: str, **kwargs: Any) -> None: assert url, "Database URL is empty!" self._engine = create_engine(url, **kwargs) self._sessionmaker: Optional[sessionmaker[Session]] = None diff --git a/pfbudget/db/model.py b/pfbudget/db/model.py index 6cafb49..8ae36a8 100644 --- a/pfbudget/db/model.py +++ b/pfbudget/db/model.py @@ -216,7 +216,9 @@ class TransactionCategory(Base, Export): @property def format(self): - return dict(name=self.name, selector=self.selector.format) + return dict( + name=self.name, selector=self.selector.format if self.selector else None + ) class Note(Base):