parser command now writes do DB

`parse_data` from parser.py now takes a `DBManager` and runs
`insert_transactions` after parsing.
`DBManager` now takes `Transaction` type and returns the same on select
queries.
`DBManager` refactored with extensive type hinting.
`Transaction` updated to reflect use in `DBManager`.
This commit is contained in:
Luís Murta 2021-06-21 23:49:23 +01:00
parent fe2b347a53
commit 3e20ae97db
Signed by: satprog
GPG Key ID: DDF2EFC6179009DC
4 changed files with 75 additions and 56 deletions

View File

@ -1,15 +1,22 @@
from __future__ import annotations
import csv import csv
import datetime import datetime
import logging import logging
import logging.config import logging.config
import pathlib import pathlib
import sqlite3 import sqlite3
from decimal import Decimal
from .transactions import Transaction
if not pathlib.Path("logs").is_dir(): if not pathlib.Path("logs").is_dir():
pathlib.Path("logs").mkdir() pathlib.Path("logs").mkdir()
logging.config.fileConfig("logging.conf") logging.config.fileConfig("logging.conf")
logger = logging.getLogger("pfbudget.transactions") logger = logging.getLogger("pfbudget.transactions")
sqlite3.register_adapter(Decimal, lambda d: float(d))
__DB_NAME = "data.db" __DB_NAME = "data.db"
CREATE_TRANSACTIONS_TABLE = """ CREATE_TRANSACTIONS_TABLE = """
@ -22,13 +29,6 @@ CREATE TABLE IF NOT EXISTS transactions (
); );
""" """
CREATE_VACATIONS_TABLE = """
CREATE TABLE IF NOT EXISTS vacations (
start TEXT NOT NULL,
end TEXT NOT NULL
)
"""
CREATE_BACKUPS_TABLE = """ CREATE_BACKUPS_TABLE = """
CREATE TABLE IF NOT EXISTS backups ( CREATE TABLE IF NOT EXISTS backups (
datetime TEXT NOT NULL, datetime TEXT NOT NULL,
@ -96,10 +96,10 @@ class DBManager:
__EXPORT_DIR = "export" __EXPORT_DIR = "export"
def __init__(self, db): def __init__(self, db: str):
self.db = db self.db = db
def __execute(self, query, params=None): def __execute(self, query: str, params: tuple = None) -> list | None:
ret = None ret = None
try: try:
con = sqlite3.connect(self.db) con = sqlite3.connect(self.db)
@ -120,7 +120,7 @@ class DBManager:
return ret return ret
def __executemany(self, query, list_of_params): def __executemany(self, query: str, list_of_params: list[tuple]) -> list | None:
ret = None ret = None
try: try:
con = sqlite3.connect(self.db) con = sqlite3.connect(self.db)
@ -136,20 +136,15 @@ class DBManager:
return ret return ret
def __create_tables(self, tables): def __create_tables(self, tables: tuple[tuple]):
for table_name, query in tables: for table_name, query in tables:
logger.info(f"Creating table if it doesn't exist {table_name}") logger.info(f"Creating table if it doesn't exist {table_name}")
self.__execute(query) self.__execute(query)
def query(self, query, params=None):
logger.info(f"Executing {query} with params={params}")
return self.__execute(query, params)
def init(self): def init(self):
self.__create_tables( self.__create_tables(
( (
("transactions", CREATE_TRANSACTIONS_TABLE), ("transactions", CREATE_TRANSACTIONS_TABLE),
("vacations", CREATE_VACATIONS_TABLE),
("backups", CREATE_BACKUPS_TABLE), ("backups", CREATE_BACKUPS_TABLE),
) )
) )
@ -158,55 +153,78 @@ class DBManager:
logger.info(f"Reading all transactions from {self.db}") logger.info(f"Reading all transactions from {self.db}")
return self.__execute("SELECT * FROM transactions") return self.__execute("SELECT * FROM transactions")
def add_transaction(self, transaction): def insert_transaction(self, transaction: Transaction):
logger.info(f"Adding {transaction} into {self.db}") logger.info(f"Adding {transaction} into {self.db}")
self.__execute(ADD_TRANSACTION, transaction) self.__execute(ADD_TRANSACTION, (transaction.to_list(),))
def add_transactions(self, transactions): def insert_transactions(self, transactions: list[Transaction]):
logger.info(f"Adding {len(transactions)} into {self.db}") logger.info(f"Adding {len(transactions)} into {self.db}")
transactions = [t.to_list() for t in transactions]
self.__executemany(ADD_TRANSACTION, transactions) self.__executemany(ADD_TRANSACTION, transactions)
def update_category(self, transaction): def update_category(self, transaction: Transaction):
logger.info(f"Update {transaction} category") logger.info(f"Update {transaction} category")
self.__execute(UPDATE_CATEGORY, (transaction[4], *transaction[:4])) self.__execute(UPDATE_CATEGORY, (transaction[4], *transaction[:4]))
def update_categories(self, transactions): def update_categories(self, transactions: list[Transaction]):
logger.info(f"Update {len(transactions)} transactions' categories") logger.info(f"Update {len(transactions)} transactions' categories")
self.__executemany( self.__executemany(
UPDATE_CATEGORY, UPDATE_CATEGORY, [transaction for transaction in transactions]
[(transaction[4], *transaction[:4]) for transaction in transactions],
) )
def get_duplicated_transactions(self): def get_duplicated_transactions(self) -> list[Transaction] | None:
logger.info("Get duplicated transactions") logger.info("Get duplicated transactions")
return self.__execute(DUPLICATED_TRANSACTIONS) transactions = self.__execute(DUPLICATED_TRANSACTIONS)
if transactions:
return [Transaction(t) for t in transactions]
return None
def get_sorted_transactions(self, key): def get_sorted_transactions(self, key: str) -> list[Transaction] | None:
logger.info(f"Get transactions sorted by {key}") logger.info(f"Get transactions sorted by {key}")
return self.__execute(SORTED_TRANSACTIONS, key) transactions = self.__execute(SORTED_TRANSACTIONS, key)
if transactions:
return [Transaction(t) for t in transactions]
return None
def get_daterange(self, start, end): def get_daterange(self, start: datetime, end: datetime) -> list[Transaction] | None:
logger.info(f"Get transactions from {start} to {end}") logger.info(f"Get transactions from {start} to {end}")
return self.__execute(SELECT_TRANSACTIONS_BETWEEN_DATES, (start, end)) transactions = self.__execute(SELECT_TRANSACTIONS_BETWEEN_DATES, (start, end))
if transactions:
return [Transaction(t) for t in transactions]
return None
def get_category(self, value): def get_category(self, value: str) -> list[Transaction] | None:
logger.info(f"Get transaction where category = {value}") logger.info(f"Get transaction where category = {value}")
return self.__execute(SELECT_TRANSACTIONS_BY_CATEGORY, (value,)) transactions = self.__execute(SELECT_TRANSACTIONS_BY_CATEGORY, (value,))
if transactions:
return [Transaction(t) for t in transactions]
return None
def get_by_period(self, period): def get_by_period(self, period: str) -> list[Transaction] | None:
logger.info(f"Get transactions by {period}") logger.info(f"Get transactions by {period}")
return self.__execute(SELECT_TRANSACTION_BY_PERIOD, period) transactions = self.__execute(SELECT_TRANSACTION_BY_PERIOD, period)
if transactions:
return [Transaction(t) for t in transactions]
return None
def get_uncategorized_transactions(self): def get_uncategorized_transactions(self) -> list[Transaction] | None:
logger.info("Get uncategorized transactions") logger.info("Get uncategorized transactions")
return self.get_category(None) transactions = self.get_category(None)
if transactions:
return [Transaction(t) for t in transactions]
return None
def get_daterage_without(self, start, end, *categories): def get_daterage_without(
self, start: datetime, end: datetime, *categories: str
) -> list[Transaction] | None:
logger.info(f"Get transactions between {start} and {end} not in {categories}") logger.info(f"Get transactions between {start} and {end} not in {categories}")
query = SELECT_TRANSACTIONS_BETWEEN_DATES_WITHOUT_CATEGORIES.format( query = SELECT_TRANSACTIONS_BETWEEN_DATES_WITHOUT_CATEGORIES.format(
"(" + ", ".join("?" for _ in categories) + ")" "(" + ", ".join("?" for _ in categories) + ")"
) )
return self.__execute(query, (start, end, *categories)) transactions = self.__execute(query, (start, end, *categories))
if transactions:
return [Transaction(t) for t in transactions]
return None
def export(self): def export(self):
filename = pathlib.Path( filename = pathlib.Path(

View File

@ -1,18 +1,15 @@
from collections import namedtuple from collections import namedtuple
from decimal import Decimal from decimal import Decimal
from importlib import import_module from importlib import import_module
from typing import Final from typing import TYPE_CHECKING
import datetime as dt import datetime as dt
import yaml import yaml
from .transactions import Transaction from .transactions import Transaction
from . import utils from . import utils
if TYPE_CHECKING:
cfg: Final = yaml.safe_load(open("parsers.yaml")) from .database import DBManager
assert (
"Banks" in cfg
), "parsers.yaml is missing the Banks section with the list of available banks"
Index = namedtuple( Index = namedtuple(
"Index", ["date", "text", "value", "negate"], defaults=[-1, -1, -1, False] "Index", ["date", "text", "value", "negate"], defaults=[-1, -1, -1, False]
@ -36,25 +33,33 @@ Options = namedtuple(
) )
def parse_data(filename: str, bank=None) -> list: def parse_data(db: DBManager, filename: str, bank: list = []) -> None:
cfg: dict = yaml.safe_load(open("parsers.yaml"))
assert (
"Banks" in cfg
), "parsers.yaml is missing the Banks section with the list of available banks"
if not bank: if not bank:
bank, creditcard = utils.find_credit_institution( bank, creditcard = utils.find_credit_institution(
filename, cfg.get("Banks"), cfg.get("CreditCards") filename, cfg.get("Banks"), cfg.get("CreditCards")
) )
else:
bank = bank[0]
creditcard = None
if creditcard: if creditcard:
options = cfg[bank][creditcard] options: dict = cfg[bank][creditcard]
bank += creditcard bank += creditcard
else: else:
options = cfg[bank] options: dict = cfg[bank]
if options.get("additional_parser", False): if options.get("additional_parser"):
parser = getattr(import_module("pfbudget.parsers"), bank) parser = getattr(import_module("pfbudget.parsers"), bank)
transactions = parser(filename, bank, options).parse() transactions = parser(filename, bank, options).parse()
else: else:
transactions = Parser(filename, bank, options).parse() transactions = Parser(filename, bank, options).parse()
return transactions db.insert_transactions(transactions)
def transaction(line: str, bank: str, options: Options, func) -> Transaction: def transaction(line: str, bank: str, options: Options, func) -> Transaction:

View File

@ -100,12 +100,11 @@ def parse(args, db):
for path in args.path: for path in args.path:
if (dir := Path(path)).is_dir(): if (dir := Path(path)).is_dir():
for file in dir.iterdir(): for file in dir.iterdir():
parse_data(file, args.bank) parse_data(db, file, args.bank)
elif Path(path).is_file(): elif Path(path).is_file():
trs = parse_data(path, args.bank) parse_data(db, path, args.bank)
else: else:
raise FileNotFoundError raise FileNotFoundError
print("\n".join([t.desc() for t in trs]))
def categorize(args, db): def categorize(args, db):

View File

@ -1,11 +1,8 @@
from csv import reader, writer from csv import reader, writer
from datetime import date from datetime import date
from dateutil.rrule import rrule, MONTHLY, YEARLY
from decimal import Decimal, InvalidOperation from decimal import Decimal, InvalidOperation
from pathlib import Path from pathlib import Path
from .categories import get_categories
COMMENT_TOKEN = "#" COMMENT_TOKEN = "#"
@ -19,7 +16,7 @@ class Transaction:
self.description = "" self.description = ""
self.bank = "" self.bank = ""
self.value = 0 self.value = 0
self.category = "" self.category = None
arg = args[0] if len(args) == 1 else list(args) arg = args[0] if len(args) == 1 else list(args)
try: try:
@ -46,7 +43,7 @@ class Transaction:
self.modified = False self.modified = False
def to_csv(self): def to_list(self):
return [self.date, self.description, self.bank, self.value, self.category] return [self.date, self.description, self.bank, self.value, self.category]
@property @property