Compare commits

..

No commits in common. "29c5206638d13ef42e70d9bbbfff0435624101a8" and "bdd7cac4be597cba3232ffc60c9e539bb13fcbd5" have entirely different histories.

32 changed files with 1357 additions and 2291 deletions

3
.gitignore vendored
View File

@ -174,6 +174,3 @@ poetry.toml
pyrightconfig.json pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python # End of https://www.toptal.com/developers/gitignore/api/python
# Project specific ignores
database.db

View File

@ -1,35 +0,0 @@
"""nordigen tokens
Revision ID: 325b901ac712
Revises: 60469d5dd2b0
Create Date: 2023-05-25 19:10:10.374008+00:00
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "325b901ac712"
down_revision = "60469d5dd2b0"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"nordigen",
sa.Column("type", sa.String(), nullable=False),
sa.Column("token", sa.String(), nullable=False),
sa.Column("expires", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("type", name=op.f("pk_nordigen")),
schema="pfbudget",
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("nordigen", schema="pfbudget")
# ### end Alembic commands ###

View File

@ -1,48 +0,0 @@
"""Drop SQLAlchemy enum
Revision ID: 60469d5dd2b0
Revises: b599dafcf468
Create Date: 2023-05-15 19:24:07.911352+00:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "60469d5dd2b0"
down_revision = "b599dafcf468"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
CREATE TYPE pfbudget.scheduleperiod
AS ENUM ('daily', 'weekly', 'monthly', 'yearly')
"""
)
op.execute(
"""ALTER TABLE pfbudget.category_schedules
ALTER COLUMN period TYPE pfbudget.scheduleperiod
USING period::text::pfbudget.scheduleperiod
"""
)
op.execute("DROP TYPE pfbudget.period")
def downgrade() -> None:
op.execute(
"""
CREATE TYPE pfbudget.period
AS ENUM ('daily', 'weekly', 'monthly', 'yearly')
"""
)
op.execute(
"""ALTER TABLE pfbudget.category_schedules
ALTER COLUMN period TYPE pfbudget.period
USING period::text::pfbudget.period
"""
)
op.execute("DROP TYPE pfbudget.scheduleperiod")

View File

@ -1,74 +0,0 @@
"""Compact category selector
Revision ID: 8623e709e111
Revises: ce68ee15e5d2
Create Date: 2023-05-08 19:00:51.063240+00:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "8623e709e111"
down_revision = "ce68ee15e5d2"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("category_selectors", schema="pfbudget")
op.add_column(
"transactions_categorized",
sa.Column(
"selector",
sa.Enum(
"unknown",
"nullifier",
"vacations",
"rules",
"algorithm",
"manual",
name="selector_t",
schema="pfbudget",
inherit_schema=True,
),
nullable=False,
),
schema="pfbudget",
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transactions_categorized", "selector", schema="pfbudget")
op.create_table(
"category_selectors",
sa.Column("id", sa.BIGINT(), autoincrement=False, nullable=False),
sa.Column(
"selector",
postgresql.ENUM(
"unknown",
"nullifier",
"vacations",
"rules",
"algorithm",
"manual",
name="selector_t",
schema="pfbudget",
),
autoincrement=False,
nullable=False,
),
sa.ForeignKeyConstraint(
["id"],
["pfbudget.transactions_categorized.id"],
name="fk_category_selectors_id_transactions_categorized",
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name="pk_category_selectors"),
schema="pfbudget",
)
# ### end Alembic commands ###

View File

@ -1,46 +0,0 @@
"""Selector type name change
Revision ID: b599dafcf468
Revises: 8623e709e111
Create Date: 2023-05-08 19:46:20.661214+00:00
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b599dafcf468"
down_revision = "8623e709e111"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
CREATE TYPE pfbudget.categoryselector
AS ENUM ('unknown', 'nullifier', 'vacations', 'rules', 'algorithm', 'manual')
"""
)
op.execute(
"""ALTER TABLE pfbudget.transactions_categorized
ALTER COLUMN selector TYPE pfbudget.categoryselector
USING selector::text::pfbudget.categoryselector
"""
)
op.execute("DROP TYPE pfbudget.selector_t")
def downgrade() -> None:
op.execute(
"""
CREATE TYPE pfbudget.selector_t
AS ENUM ('unknown', 'nullifier', 'vacations', 'rules', 'algorithm', 'manual')
"""
)
op.execute(
"""ALTER TABLE pfbudget.transactions_categorized
ALTER COLUMN selector TYPE pfbudget.selector_t
USING selector::text::pfbudget.selector_t
"""
)
op.execute("DROP TYPE pfbudget.categoryselector")

View File

@ -38,10 +38,10 @@ if __name__ == "__main__":
params = [args["path"], args["bank"], args["creditcard"]] params = [args["path"], args["bank"], args["creditcard"]]
case Operation.RequisitionId: case Operation.RequisitionId:
keys = {"bank"} keys = {"name", "country"}
assert args.keys() >= keys, f"missing {args.keys() - keys}" assert args.keys() >= keys, f"missing {args.keys() - keys}"
params = [args["bank"][0]] params = [args["name"][0], args["country"][0]]
case Operation.Download: case Operation.Download:
keys = {"all", "banks", "interval", "start", "end", "year", "dry_run"} keys = {"all", "banks", "interval", "start", "end", "year", "dry_run"}
@ -163,14 +163,14 @@ if __name__ == "__main__":
params = [ params = [
type.CategoryRule( type.CategoryRule(
args["start"][0] if args["start"] else None,
args["end"][0] if args["end"] else None,
args["description"][0] if args["description"] else None,
args["regex"][0] if args["regex"] else None,
args["bank"][0] if args["bank"] else None,
args["min"][0] if args["min"] else None,
args["max"][0] if args["max"] else None,
cat, cat,
start=args["start"][0] if args["start"] else None,
end=args["end"][0] if args["end"] else None,
description=args["description"][0] if args["description"] else None,
regex=args["regex"][0] if args["regex"] else None,
bank=args["bank"][0] if args["bank"] else None,
min=args["min"][0] if args["min"] else None,
max=args["max"][0] if args["max"] else None,
) )
for cat in args["category"] for cat in args["category"]
] ]
@ -215,14 +215,14 @@ if __name__ == "__main__":
params = [ params = [
type.TagRule( type.TagRule(
args["start"][0] if args["start"] else None,
args["end"][0] if args["end"] else None,
args["description"][0] if args["description"] else None,
args["regex"][0] if args["regex"] else None,
args["bank"][0] if args["bank"] else None,
args["min"][0] if args["min"] else None,
args["max"][0] if args["max"] else None,
tag, tag,
start=args["start"][0] if args["start"] else None,
end=args["end"][0] if args["end"] else None,
description=args["description"][0] if args["description"] else None,
regex=args["regex"][0] if args["regex"] else None,
bank=args["bank"][0] if args["bank"] else None,
min=args["min"][0] if args["min"] else None,
max=args["max"][0] if args["max"] else None,
) )
for tag in args["tag"] for tag in args["tag"]
] ]

View File

@ -6,7 +6,7 @@ import os
import re import re
from pfbudget.common.types import Operation from pfbudget.common.types import Operation
from pfbudget.db.model import AccountType, SchedulePeriod from pfbudget.db.model import AccountType, Period
from pfbudget.db.sqlite import DatabaseClient from pfbudget.db.sqlite import DatabaseClient
import pfbudget.reporting.graph import pfbudget.reporting.graph
@ -60,12 +60,11 @@ def argparser() -> argparse.ArgumentParser:
# init = subparsers.add_parser("init") # init = subparsers.add_parser("init")
# init.set_defaults(op=Operation.Init) # init.set_defaults(op=Operation.Init)
# Exports transactions to specified format and file # Exports transactions to .csv file
export = subparsers.add_parser("export") export = subparsers.add_parser("export")
export.set_defaults(op=Operation.Export) export.set_defaults(op=Operation.Export)
file_options(export) file_options(export)
# Imports transactions from specified format and file
pimport = subparsers.add_parser("import") pimport = subparsers.add_parser("import")
pimport.set_defaults(op=Operation.Import) pimport.set_defaults(op=Operation.Import)
file_options(pimport) file_options(pimport)
@ -133,7 +132,8 @@ def argparser() -> argparse.ArgumentParser:
# PSD2 requisition id # PSD2 requisition id
requisition = subparsers.add_parser("eua") requisition = subparsers.add_parser("eua")
requisition.set_defaults(op=Operation.RequisitionId) requisition.set_defaults(op=Operation.RequisitionId)
requisition.add_argument("bank", nargs=1, type=str) requisition.add_argument("id", nargs=1, type=str)
requisition.add_argument("country", nargs=1, type=str)
# Download through the PSD2 API # Download through the PSD2 API
download = subparsers.add_parser("download", parents=[period]) download = subparsers.add_parser("download", parents=[period])
@ -268,7 +268,7 @@ def category(parser: argparse.ArgumentParser):
schedule = commands.add_parser("schedule") schedule = commands.add_parser("schedule")
schedule.set_defaults(op=Operation.CategorySchedule) schedule.set_defaults(op=Operation.CategorySchedule)
schedule.add_argument("category", nargs="+", type=str) schedule.add_argument("category", nargs="+", type=str)
schedule.add_argument("period", nargs=1, choices=[e.value for e in SchedulePeriod]) schedule.add_argument("period", nargs=1, choices=[e.value for e in Period])
schedule.add_argument("--frequency", nargs=1, default=[1], type=int) schedule.add_argument("--frequency", nargs=1, default=[1], type=int)
rule = commands.add_parser("rule") rule = commands.add_parser("rule")

View File

@ -3,8 +3,9 @@ import decimal
from ..core.manager import Manager from ..core.manager import Manager
from ..db.model import ( from ..db.model import (
Category, Category,
Note,
CategorySelector, CategorySelector,
Note,
Selector_T,
SplitTransaction, SplitTransaction,
Tag, Tag,
Transaction, Transaction,
@ -15,13 +16,15 @@ from ..db.model import (
class Interactive: class Interactive:
help = "category(:tag)/split/note:/skip/quit" help = "category(:tag)/split/note:/skip/quit"
selector = CategorySelector.manual selector = Selector_T.manual
def __init__(self, manager: Manager) -> None: def __init__(self, manager: Manager) -> None:
self.manager = manager self.manager = manager
self.categories = self.manager.database.select(Category) with self.manager.db.session() as session:
self.tags = self.manager.database.select(Tag) self.categories = session.get(Category)
self.tags = session.get(Tag)
session.expunge_all()
def intro(self) -> None: def intro(self) -> None:
print( print(
@ -32,34 +35,28 @@ class Interactive:
def start(self) -> None: def start(self) -> None:
self.intro() self.intro()
with self.manager.database.session as session: with self.manager.db.session() as session:
uncategorized = session.select( uncategorized = session.uncategorized()
Transaction, lambda: ~Transaction.category.has()
)
uncategorized.sort()
n = len(uncategorized) n = len(uncategorized)
print(f"{n} left to categorize") print(f"{n} left to categorize")
i = 0 i = 0
new = [] new = []
next = uncategorized[i]
while (command := input("$ ")) != "quit" and i < len(uncategorized): print(next)
current = uncategorized[i] if len(new) == 0 else new.pop() while (command := input("$ ")) != "quit":
print(current)
match command: match command:
case "help": case "help":
print(self.help) print(self.help)
case "skip": case "skip":
if len(uncategorized) == 0:
i += 1 i += 1
case "quit": case "quit":
break break
case "split": case "split":
new = self.split(current) new = self.split(next)
session.insert(new) session.insert(new)
case other: case other:
@ -70,32 +67,35 @@ class Interactive:
if other.startswith("note:"): if other.startswith("note:"):
# TODO adding notes to a splitted transaction won't allow # TODO adding notes to a splitted transaction won't allow
# categorization # categorization
current.note = Note(other[len("note:") :].strip()) next.note = Note(other[len("note:") :].strip())
else: else:
ct = other.split(":") ct = other.split(":")
if (category := ct[0]) not in [ if (category := ct[0]) not in [
c.name for c in self.categories c.name for c in self.categories
]: ]:
print(self.help, self.categories) print(self.help, self.categories)
continue
tags = [] tags = []
if len(ct) > 1: if len(ct) > 1:
tags = ct[1:] tags = ct[1:]
current.category = TransactionCategory( next.category = TransactionCategory(
category, self.selector category, CategorySelector(self.selector)
) )
for tag in tags: for tag in tags:
if tag not in [t.name for t in self.tags]: if tag not in [t.name for t in self.tags]:
session.insert([Tag(tag)]) session.insert([Tag(tag)])
self.tags = session.get(Tag) self.tags = session.get(Tag)
current.tags.add(TransactionTag(tag)) next.tags.add(TransactionTag(tag))
if len(new) == 0:
i += 1 i += 1
session.commit()
next = uncategorized[i] if len(new) == 0 else new.pop()
print(next)
def split(self, original: Transaction) -> list[SplitTransaction]: def split(self, original: Transaction) -> list[SplitTransaction]:
total = original.amount total = original.amount
new = [] new = []

View File

@ -51,11 +51,6 @@ class Operation(Enum):
ImportCategoryGroups = auto() ImportCategoryGroups = auto()
class ExportFormat(Enum):
JSON = auto()
pickle = auto()
class TransactionError(Exception): class TransactionError(Exception):
pass pass

View File

@ -1,128 +0,0 @@
from abc import ABC, abstractmethod
import json
from pathlib import Path
import pickle
from typing import Type
from pfbudget.common.types import ExportFormat
from pfbudget.db.client import Client
from pfbudget.db.model import (
Bank,
Category,
CategoryGroup,
Serializable,
Tag,
Transaction,
)
# required for the backup import
import pfbudget.db.model
class Command(ABC):
@abstractmethod
def execute(self) -> None:
raise NotImplementedError
def undo(self) -> None:
raise NotImplementedError
class ExportCommand(Command):
def __init__(
self, client: Client, what: Type[Serializable], fn: Path, format: ExportFormat
):
self.__client = client
self.what = what
self.fn = fn
self.format = format
def execute(self) -> None:
values = self.__client.select(self.what)
match self.format:
case ExportFormat.JSON:
with open(self.fn, "w", newline="") as f:
json.dump([e.serialize() for e in values], f, indent=4)
case ExportFormat.pickle:
raise AttributeError("pickle export not working at the moment!")
with open(self.fn, "wb") as f:
pickle.dump(values, f)
class ImportCommand(Command):
def __init__(
self, client: Client, what: Type[Serializable], fn: Path, format: ExportFormat
):
self.__client = client
self.what = what
self.fn = fn
self.format = format
def execute(self) -> None:
match self.format:
case ExportFormat.JSON:
with open(self.fn, "r") as f:
try:
values = json.load(f)
values = [self.what.deserialize(v) for v in values]
except json.JSONDecodeError as e:
raise ImportFailedError(e)
case ExportFormat.pickle:
raise AttributeError("pickle import not working at the moment!")
with open(self.fn, "rb") as f:
values = pickle.load(f)
self.__client.insert(values)
class ImportFailedError(Exception):
pass
class BackupCommand(Command):
def __init__(self, client: Client, fn: Path, format: ExportFormat) -> None:
self.__client = client
self.fn = fn
self.format = format
def execute(self) -> None:
banks = self.__client.select(Bank)
groups = self.__client.select(CategoryGroup)
categories = self.__client.select(Category)
tags = self.__client.select(Tag)
transactions = self.__client.select(Transaction)
values = [*banks, *groups, *categories, *tags, *transactions]
match self.format:
case ExportFormat.JSON:
with open(self.fn, "w", newline="") as f:
json.dump([e.serialize() for e in values], f, indent=4)
case ExportFormat.pickle:
raise AttributeError("pickle export not working at the moment!")
class ImportBackupCommand(Command):
def __init__(self, client: Client, fn: Path, format: ExportFormat) -> None:
self.__client = client
self.fn = fn
self.format = format
def execute(self) -> None:
match self.format:
case ExportFormat.JSON:
with open(self.fn, "r") as f:
try:
values = json.load(f)
values = [
getattr(pfbudget.db.model, v["class_"]).deserialize(v)
for v in values
]
except json.JSONDecodeError as e:
raise ImportFailedError(e)
case ExportFormat.pickle:
raise AttributeError("pickle import not working at the moment!")
self.__client.insert(values)

View File

@ -1,3 +1,4 @@
import csv
import json import json
from pathlib import Path from pathlib import Path
import pickle import pickle
@ -13,11 +14,12 @@ from pfbudget.db.model import (
CategoryGroup, CategoryGroup,
CategoryRule, CategoryRule,
CategorySchedule, CategorySchedule,
CategorySelector,
Link, Link,
MoneyTransaction, MoneyTransaction,
NordigenBank, Nordigen,
Rule, Rule,
CategorySelector, Selector_T,
SplitTransaction, SplitTransaction,
Tag, Tag,
TagRule, TagRule,
@ -79,7 +81,7 @@ class Manager:
else: else:
banks = self.database.select(Bank, Bank.nordigen) banks = self.database.select(Bank, Bank.nordigen)
extractor = PSD2Extractor(self.nordigen_client()) extractor = PSD2Extractor(Manager.nordigen_client())
transactions = [] transactions = []
for bank in banks: for bank in banks:
@ -101,20 +103,10 @@ class Manager:
categories = session.select(Category) categories = session.select(Category)
tags = session.select(Tag) tags = session.select(Tag)
rules = [ rules = [cat.rules for cat in categories if cat.name == "null"]
rule
for cat in categories
if cat.name == "null"
for rule in cat.rules
]
Nullifier(rules).transform_inplace(uncategorized) Nullifier(rules).transform_inplace(uncategorized)
rules = [ rules = [rule for cat in categories for rule in cat.rules]
rule
for cat in categories
if cat.name != "null"
for rule in cat.rules
]
Categorizer(rules).transform_inplace(uncategorized) Categorizer(rules).transform_inplace(uncategorized)
rules = [rule for tag in tags for rule in tag.rules] rules = [rule for tag in tags for rule in tag.rules]
@ -124,34 +116,24 @@ class Manager:
self.database.update(Bank, params) self.database.update(Bank, params)
case Operation.PSD2Mod: case Operation.PSD2Mod:
self.database.update(NordigenBank, params) self.database.update(Nordigen, params)
case Operation.BankDel: case Operation.BankDel:
self.database.delete(Bank, Bank.name, params) self.database.delete(Bank, Bank.name, params)
case Operation.PSD2Del: case Operation.PSD2Del:
self.database.delete(NordigenBank, NordigenBank.name, params) self.database.delete(Nordigen, Nordigen.name, params)
case Operation.Token:
Manager.nordigen_client().generate_token()
case Operation.RequisitionId: case Operation.RequisitionId:
bank_name = params[0] link, _ = Manager.nordigen_client().requisition(params[0], params[1])
bank = self.database.select(Bank, (lambda: Bank.name == bank_name))[0] print(f"Opening {link} to request access to {params[0]}")
if not bank.nordigen or not bank.nordigen.bank_id:
raise ValueError(f"{bank} doesn't have a Nordigen ID")
link, req_id = self.nordigen_client().new_requisition(
bank.nordigen.bank_id
)
self.database.update(
NordigenBank,
[{"name": bank.nordigen.name, "requisition_id": req_id}],
)
webbrowser.open(link) webbrowser.open(link)
case Operation.PSD2CountryBanks: case Operation.PSD2CountryBanks:
banks = self.nordigen_client().country_banks(params[0]) banks = Manager.nordigen_client().country_banks(params[0])
print(banks) print(banks)
case ( case (
@ -263,7 +245,10 @@ class Manager:
session.insert(transactions) session.insert(transactions)
case Operation.Export: case Operation.Export:
self.dump(params[0], params[1], self.database.select(Transaction)) with self.database.session as session:
self.dump(
params[0], params[1], self.database.select(Transaction, session)
)
case Operation.Import: case Operation.Import:
transactions = [] transactions = []
@ -289,7 +274,8 @@ class Manager:
if category := row.pop("category", None): if category := row.pop("category", None):
transaction.category = TransactionCategory( transaction.category = TransactionCategory(
category["name"], category["selector"]["selector"] category["name"],
CategorySelector(category["selector"]["selector"]),
) )
transactions.append(transaction) transactions.append(transaction)
@ -298,21 +284,27 @@ class Manager:
self.database.insert(transactions) self.database.insert(transactions)
case Operation.ExportBanks: case Operation.ExportBanks:
self.dump(params[0], params[1], self.database.select(Bank)) with self.database.session as session:
self.dump(params[0], params[1], self.database.select(Bank, session))
case Operation.ImportBanks: case Operation.ImportBanks:
banks = [] banks = []
for row in self.load(params[0], params[1]): for row in self.load(params[0], params[1]):
bank = Bank(row["name"], row["BIC"], row["type"]) bank = Bank(row["name"], row["BIC"], row["type"])
if row["nordigen"]: if row["nordigen"]:
bank.nordigen = NordigenBank(**row["nordigen"]) bank.nordigen = Nordigen(**row["nordigen"])
banks.append(bank) banks.append(bank)
if self.certify(banks): if self.certify(banks):
self.database.insert(banks) self.database.insert(banks)
case Operation.ExportCategoryRules: case Operation.ExportCategoryRules:
self.dump(params[0], params[1], self.database.select(CategoryRule)) with self.database.session as session:
self.dump(
params[0],
params[1],
self.database.select(CategoryRule, session),
)
case Operation.ImportCategoryRules: case Operation.ImportCategoryRules:
rules = [CategoryRule(**row) for row in self.load(params[0], params[1])] rules = [CategoryRule(**row) for row in self.load(params[0], params[1])]
@ -321,7 +313,10 @@ class Manager:
self.database.insert(rules) self.database.insert(rules)
case Operation.ExportTagRules: case Operation.ExportTagRules:
self.dump(params[0], params[1], self.database.select(TagRule)) with self.database.session as session:
self.dump(
params[0], params[1], self.database.select(TagRule, session)
)
case Operation.ImportTagRules: case Operation.ImportTagRules:
rules = [TagRule(**row) for row in self.load(params[0], params[1])] rules = [TagRule(**row) for row in self.load(params[0], params[1])]
@ -330,7 +325,10 @@ class Manager:
self.database.insert(rules) self.database.insert(rules)
case Operation.ExportCategories: case Operation.ExportCategories:
self.dump(params[0], params[1], self.database.select(Category)) with self.database.session as session:
self.dump(
params[0], params[1], self.database.select(Category, session)
)
case Operation.ImportCategories: case Operation.ImportCategories:
# rules = [Category(**row) for row in self.load(params[0])] # rules = [Category(**row) for row in self.load(params[0])]
@ -343,7 +341,7 @@ class Manager:
for rule in rules: for rule in rules:
del rule["type"] del rule["type"]
category.rules = [CategoryRule(**rule) for rule in rules] category.rules = set(CategoryRule(**rule) for rule in rules)
if row["schedule"]: if row["schedule"]:
category.schedule = CategorySchedule(**row["schedule"]) category.schedule = CategorySchedule(**row["schedule"])
categories.append(category) categories.append(category)
@ -352,7 +350,12 @@ class Manager:
self.database.insert(categories) self.database.insert(categories)
case Operation.ExportCategoryGroups: case Operation.ExportCategoryGroups:
self.dump(params[0], params[1], self.database.select(CategoryGroup)) with self.database.session as session:
self.dump(
params[0],
params[1],
self.database.select(CategoryGroup, session),
)
case Operation.ImportCategoryGroups: case Operation.ImportCategoryGroups:
groups = [ groups = [
@ -366,7 +369,7 @@ class Manager:
return parse_data(filename, args) return parse_data(filename, args)
def askcategory(self, transaction: Transaction): def askcategory(self, transaction: Transaction):
selector = CategorySelector.manual selector = CategorySelector(Selector_T.manual)
categories = self.database.select(Category) categories = self.database.select(Category)
@ -380,6 +383,9 @@ class Manager:
if format == "pickle": if format == "pickle":
with open(fn, "wb") as f: with open(fn, "wb") as f:
pickle.dump([e.format for e in sequence], f) pickle.dump([e.format for e in sequence], f)
elif format == "csv":
with open(fn, "w", newline="") as f:
csv.writer(f).writerows([e.format.values() for e in sequence])
elif format == "json": elif format == "json":
with open(fn, "w", newline="") as f: with open(fn, "w", newline="") as f:
json.dump([e.format for e in sequence], f, indent=4, default=str) json.dump([e.format for e in sequence], f, indent=4, default=str)
@ -391,6 +397,8 @@ class Manager:
if format == "pickle": if format == "pickle":
with open(fn, "rb") as f: with open(fn, "rb") as f:
return pickle.load(f) return pickle.load(f)
elif format == "csv":
raise Exception("CSV import not supported")
else: else:
print("format not well specified") print("format not well specified")
return [] return []
@ -407,5 +415,6 @@ class Manager:
self._database = Client(self._db, echo=self._verbosity > 2) self._database = Client(self._db, echo=self._verbosity > 2)
return self._database return self._database
def nordigen_client(self) -> NordigenClient: @staticmethod
return NordigenClient(NordigenCredentialsManager.default, self.database) def nordigen_client() -> NordigenClient:
return NordigenClient(NordigenCredentialsManager.default)

View File

@ -1,11 +1,10 @@
from collections.abc import Sequence from collections.abc import Sequence
from copy import deepcopy from copy import deepcopy
from sqlalchemy import Engine, create_engine, delete, select, update from sqlalchemy import Engine, create_engine, delete, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from typing import Any, Mapping, Optional, Type, TypeVar from typing import Any, Mapping, Optional, Type, TypeVar
from pfbudget.db.exceptions import InsertError # from pfbudget.db.exceptions import InsertError, SelectError
class DatabaseSession: class DatabaseSession:
@ -17,17 +16,10 @@ class DatabaseSession:
return self return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
try:
if exc_type: if exc_type:
self.__session.rollback() self.__session.rollback()
else: else:
self.__session.commit() self.__session.commit()
except IntegrityError as e:
raise InsertError() from e
finally:
self.__session.close()
def close(self):
self.__session.close() self.__session.close()
def insert(self, sequence: Sequence[Any]) -> None: def insert(self, sequence: Sequence[Any]) -> None:
@ -41,10 +33,7 @@ class DatabaseSession:
else: else:
stmt = select(what) stmt = select(what)
return self.__session.scalars(stmt).unique().all() return self.__session.scalars(stmt).all()
def delete(self, obj: Any) -> None:
self.__session.delete(obj)
class Client: class Client:
@ -61,16 +50,13 @@ class Client:
T = TypeVar("T") T = TypeVar("T")
def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]: def select(self, what: Type[T], exists: Optional[Any] = None) -> Sequence[T]:
session = self.session return self.session.select(what, exists)
result = session.select(what, exists)
session.close()
return result
def update(self, what: Type[Any], values: Sequence[Mapping[str, Any]]) -> None: def update(self, what: Type[Any], values: Sequence[Mapping[str, Any]]) -> None:
with self._sessionmaker() as session, session.begin(): with self._sessionmaker() as session, session.begin():
session.execute(update(what), values) session.execute(update(what), values)
def delete(self, what: Type[Any], column: Any, values: Sequence[Any]) -> None: def delete(self, what: Type[Any], column: Any, values: Sequence[str]) -> None:
with self._sessionmaker() as session, session.begin(): with self._sessionmaker() as session, session.begin():
session.execute(delete(what).where(column.in_(values))) session.execute(delete(what).where(column.in_(values)))

View File

@ -1,11 +1,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping, MutableMapping, Sequence
from dataclasses import dataclass
import datetime as dt import datetime as dt
import decimal import decimal
import enum import enum
import re import re
from typing import Annotated, Any, Callable, Optional, Self, cast from typing import Annotated, Any, Optional
from sqlalchemy import ( from sqlalchemy import (
BigInteger, BigInteger,
@ -38,20 +36,6 @@ class Base(MappedAsDataclass, DeclarativeBase):
}, },
) )
type_annotation_map = {
enum.Enum: Enum(enum.Enum, create_constraint=True, inherit_schema=True),
}
@dataclass
class Serializable:
def serialize(self) -> Mapping[str, Any]:
return dict(class_=type(self).__name__)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
raise NotImplementedError
class AccountType(enum.Enum): class AccountType(enum.Enum):
checking = enum.auto() checking = enum.auto()
@ -62,38 +46,36 @@ class AccountType(enum.Enum):
MASTERCARD = enum.auto() MASTERCARD = enum.auto()
class Bank(Base, Serializable): accounttype = Annotated[
AccountType,
mapped_column(Enum(AccountType, inherit_schema=True)),
]
class Export:
@property
def format(self) -> dict[str, Any]:
raise NotImplementedError
class Bank(Base, Export):
__tablename__ = "banks" __tablename__ = "banks"
name: Mapped[str] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(primary_key=True)
BIC: Mapped[str] = mapped_column(String(8)) BIC: Mapped[str] = mapped_column(String(8))
type: Mapped[AccountType] type: Mapped[accounttype]
nordigen: Mapped[Optional[NordigenBank]] = relationship(default=None, lazy="joined") nordigen: Mapped[Optional[Nordigen]] = relationship(lazy="joined", init=False)
def serialize(self) -> Mapping[str, Any]: @property
nordigen = None def format(self) -> dict[str, Any]:
if self.nordigen: return dict(
nordigen = {
"bank_id": self.nordigen.bank_id,
"requisition_id": self.nordigen.requisition_id,
"invert": self.nordigen.invert,
}
return super().serialize() | dict(
name=self.name, name=self.name,
BIC=self.BIC, BIC=self.BIC,
type=self.type.name, type=self.type,
nordigen=nordigen, nordigen=self.nordigen.format if self.nordigen else None,
) )
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
bank = cls(map["name"], map["BIC"], map["type"])
if map["nordigen"]:
bank.nordigen = NordigenBank(**map["nordigen"])
return bank
bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))] bankfk = Annotated[str, mapped_column(Text, ForeignKey(Bank.name))]
@ -108,7 +90,7 @@ idpk = Annotated[
money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))] money = Annotated[decimal.Decimal, mapped_column(Numeric(16, 2))]
class Transaction(Base, Serializable): class Transaction(Base, Export):
__tablename__ = "transactions" __tablename__ = "transactions"
id: Mapped[idpk] = mapped_column(init=False) id: Mapped[idpk] = mapped_column(init=False)
@ -116,83 +98,32 @@ class Transaction(Base, Serializable):
description: Mapped[Optional[str]] description: Mapped[Optional[str]]
amount: Mapped[money] amount: Mapped[money]
split: Mapped[bool] = mapped_column(default=False) split: Mapped[bool] = mapped_column(init=False, default=False)
category: Mapped[Optional[TransactionCategory]] = relationship(
back_populates="transaction", default=None, lazy="joined"
)
tags: Mapped[set[TransactionTag]] = relationship(default_factory=set, lazy="joined")
note: Mapped[Optional[Note]] = relationship(
cascade="all, delete-orphan", passive_deletes=True, default=None, lazy="joined"
)
type: Mapped[str] = mapped_column(init=False) type: Mapped[str] = mapped_column(init=False)
category: Mapped[Optional[TransactionCategory]] = relationship(init=False)
note: Mapped[Optional[Note]] = relationship(
cascade="all, delete-orphan", init=False, passive_deletes=True
)
tags: Mapped[set[TransactionTag]] = relationship(init=False)
__mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "transaction"} __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "transaction"}
def serialize(self) -> Mapping[str, Any]: @property
category = None def format(self) -> dict[str, Any]:
if self.category: return dict(
category = {
"name": self.category.name,
"selector": self.category.selector.name,
}
return super().serialize() | dict(
id=self.id, id=self.id,
date=self.date.isoformat(), date=self.date,
description=self.description, description=self.description,
amount=str(self.amount), amount=self.amount,
split=self.split, split=self.split,
category=category if category else None,
tags=[{"tag": tag.tag} for tag in self.tags],
note={"note": self.note.note} if self.note else None,
type=self.type, type=self.type,
category=self.category.format if self.category else None,
# TODO note
tags=[tag.format for tag in self.tags] if self.tags else None,
) )
@classmethod
def deserialize(
cls, map: Mapping[str, Any]
) -> Transaction | BankTransaction | MoneyTransaction | SplitTransaction:
match map["type"]:
case "bank":
return BankTransaction.deserialize(map)
case "money":
return MoneyTransaction.deserialize(map)
case "split":
return SplitTransaction.deserialize(map)
case _:
return cls._deserialize(map)
@classmethod
def _deserialize(cls, map: Mapping[str, Any]) -> Self:
category = None
if map["category"]:
category = TransactionCategory(map["category"]["name"])
if map["category"]["selector"]:
category.selector = map["category"]["selector"]
tags: set[TransactionTag] = set()
if map["tags"]:
tags = set(TransactionTag(t["tag"]) for t in map["tags"])
note = None
if map["note"]:
note = Note(map["note"]["note"])
result = cls(
dt.date.fromisoformat(map["date"]),
map["description"],
map["amount"],
map["split"],
category,
tags,
note,
)
if map["id"]:
result.id = map["id"]
return result
def __lt__(self, other: Transaction): def __lt__(self, other: Transaction):
return self.date < other.date return self.date < other.date
@ -203,64 +134,40 @@ idfk = Annotated[
class BankTransaction(Transaction): class BankTransaction(Transaction):
bank: Mapped[Optional[bankfk]] = mapped_column(default=None) bank: Mapped[bankfk] = mapped_column(nullable=True)
__mapper_args__ = {"polymorphic_identity": "bank", "polymorphic_load": "inline"} __mapper_args__ = {"polymorphic_identity": "bank", "polymorphic_load": "inline"}
def serialize(self) -> Mapping[str, Any]: @property
map = cast(MutableMapping[str, Any], super().serialize()) def format(self) -> dict[str, Any]:
map["bank"] = self.bank return super().format | dict(bank=self.bank)
return map
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
transaction = cls._deserialize(map)
transaction.bank = map["bank"]
return transaction
class MoneyTransaction(Transaction): class MoneyTransaction(Transaction):
__mapper_args__ = {"polymorphic_identity": "money"} __mapper_args__ = {"polymorphic_identity": "money"}
def serialize(self) -> Mapping[str, Any]:
return super().serialize()
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls._deserialize(map)
class SplitTransaction(Transaction): class SplitTransaction(Transaction):
original: Mapped[Optional[idfk]] = mapped_column(default=None) original: Mapped[idfk] = mapped_column(nullable=True)
__mapper_args__ = {"polymorphic_identity": "split", "polymorphic_load": "inline"} __mapper_args__ = {"polymorphic_identity": "split", "polymorphic_load": "inline"}
def serialize(self) -> Mapping[str, Any]: @property
map = cast(MutableMapping[str, Any], super().serialize()) def format(self) -> dict[str, Any]:
map["original"] = self.original return super().format | dict(original=self.original)
return map
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
transaction = cls._deserialize(map)
transaction.original = map["original"]
return transaction
class CategoryGroup(Base, Serializable): class CategoryGroup(Base, Export):
__tablename__ = "category_groups" __tablename__ = "category_groups"
name: Mapped[str] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(primary_key=True)
def serialize(self) -> Mapping[str, Any]: @property
return super().serialize() | dict(name=self.name) def format(self) -> dict[str, Any]:
return dict(name=self.name)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
return cls(map["name"])
class Category(Base, Serializable, repr=False): class Category(Base, Export):
__tablename__ = "categories" __tablename__ = "categories"
name: Mapped[str] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(primary_key=True)
@ -268,67 +175,11 @@ class Category(Base, Serializable, repr=False):
ForeignKey(CategoryGroup.name), default=None ForeignKey(CategoryGroup.name), default=None
) )
rules: Mapped[list[CategoryRule]] = relationship( rules: Mapped[set[CategoryRule]] = relationship(
cascade="all, delete-orphan", cascade="all, delete-orphan", passive_deletes=True, default_factory=set
passive_deletes=True,
default_factory=list,
lazy="joined",
) )
schedule: Mapped[Optional[CategorySchedule]] = relationship( schedule: Mapped[Optional[CategorySchedule]] = relationship(
cascade="all, delete-orphan", passive_deletes=True, default=None, lazy="joined" cascade="all, delete-orphan", passive_deletes=True, default=None
)
def serialize(self) -> Mapping[str, Any]:
rules: Sequence[Mapping[str, Any]] = []
for rule in self.rules:
rules.append(
{
"start": rule.start.isoformat() if rule.start else None,
"end": rule.end.isoformat() if rule.end else None,
"description": rule.description,
"regex": rule.regex,
"bank": rule.bank,
"min": str(rule.min) if rule.min is not None else None,
"max": str(rule.max) if rule.max is not None else None,
}
)
schedule = None
if self.schedule:
schedule = {
"period": self.schedule.period.name if self.schedule.period else None,
"period_multiplier": self.schedule.period_multiplier,
"amount": self.schedule.amount,
}
return super().serialize() | dict(
name=self.name,
group=self.group,
rules=rules,
schedule=schedule,
)
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
rules: list[CategoryRule] = []
for rule in map["rules"]:
rules.append(
CategoryRule(
dt.date.fromisoformat(rule["start"]) if rule["start"] else None,
dt.date.fromisoformat(rule["end"]) if rule["end"] else None,
rule["description"],
rule["regex"],
rule["bank"],
rule["min"],
rule["max"],
)
)
return cls(
map["name"],
map["group"],
rules,
CategorySchedule(**map["schedule"]) if map["schedule"] else None,
) )
def __repr__(self) -> str: def __repr__(self) -> str:
@ -337,6 +188,15 @@ class Category(Base, Serializable, repr=False):
f" schedule={self.schedule})" f" schedule={self.schedule})"
) )
@property
def format(self) -> dict[str, Any]:
return dict(
name=self.name,
group=self.group if self.group else None,
rules=[rule.format for rule in self.rules],
schedule=self.schedule.format if self.schedule else None,
)
catfk = Annotated[ catfk = Annotated[
str, str,
@ -344,25 +204,20 @@ catfk = Annotated[
] ]
class CategorySelector(enum.Enum): class TransactionCategory(Base, Export):
unknown = enum.auto()
nullifier = enum.auto()
vacations = enum.auto()
rules = enum.auto()
algorithm = enum.auto()
manual = enum.auto()
class TransactionCategory(Base):
__tablename__ = "transactions_categorized" __tablename__ = "transactions_categorized"
id: Mapped[idfk] = mapped_column(primary_key=True, init=False) id: Mapped[idfk] = mapped_column(primary_key=True, init=False)
name: Mapped[catfk] name: Mapped[catfk]
selector: Mapped[CategorySelector] = mapped_column(default=CategorySelector.unknown) selector: Mapped[CategorySelector] = relationship(
cascade="all, delete-orphan", lazy="joined"
)
transaction: Mapped[Transaction] = relationship( @property
back_populates="category", init=False, compare=False def format(self):
return dict(
name=self.name, selector=self.selector.format if self.selector else None
) )
@ -373,85 +228,106 @@ class Note(Base):
note: Mapped[str] note: Mapped[str]
class NordigenBank(Base): class Nordigen(Base, Export):
__tablename__ = "banks_nordigen" __tablename__ = "banks_nordigen"
name: Mapped[bankfk] = mapped_column(primary_key=True, init=False) name: Mapped[bankfk] = mapped_column(primary_key=True)
bank_id: Mapped[Optional[str]] bank_id: Mapped[Optional[str]]
requisition_id: Mapped[Optional[str]] requisition_id: Mapped[Optional[str]]
invert: Mapped[Optional[bool]] = mapped_column(default=None) invert: Mapped[Optional[bool]]
@property
def format(self) -> dict[str, Any]:
return dict(
name=self.name,
bank_id=self.bank_id,
requisition_id=self.requisition_id,
invert=self.invert,
)
class Tag(Base, Serializable): class Tag(Base):
__tablename__ = "tags" __tablename__ = "tags"
name: Mapped[str] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(primary_key=True)
rules: Mapped[list[TagRule]] = relationship( rules: Mapped[set[TagRule]] = relationship(
cascade="all, delete-orphan", cascade="all, delete-orphan", passive_deletes=True, default_factory=set
passive_deletes=True,
default_factory=list,
lazy="joined",
) )
def serialize(self) -> Mapping[str, Any]:
rules: Sequence[Mapping[str, Any]] = []
for rule in self.rules:
rules.append(
{
"start": rule.start,
"end": rule.end,
"description": rule.description,
"regex": rule.regex,
"bank": rule.bank,
"min": str(rule.min) if rule.min is not None else None,
"max": str(rule.max) if rule.max is not None else None,
}
)
return super().serialize() | dict(name=self.name, rules=rules) class TransactionTag(Base, Export):
@classmethod
def deserialize(cls, map: Mapping[str, Any]) -> Self:
rules: list[TagRule] = []
for rule in map["rules"]:
rules.append(
TagRule(
dt.date.fromisoformat(rule["start"]) if rule["start"] else None,
dt.date.fromisoformat(rule["end"]) if rule["end"] else None,
rule["description"],
rule["regex"],
rule["bank"],
rule["min"],
rule["max"],
)
)
return cls(map["name"], rules)
class TransactionTag(Base, unsafe_hash=True):
__tablename__ = "transactions_tagged" __tablename__ = "transactions_tagged"
id: Mapped[idfk] = mapped_column(primary_key=True, init=False) id: Mapped[idfk] = mapped_column(primary_key=True, init=False)
tag: Mapped[str] = mapped_column(ForeignKey(Tag.name), primary_key=True) tag: Mapped[str] = mapped_column(ForeignKey(Tag.name), primary_key=True)
@property
def format(self):
return dict(tag=self.tag)
class SchedulePeriod(enum.Enum): def __hash__(self):
daily = enum.auto() return hash(self.id)
weekly = enum.auto()
monthly = enum.auto()
yearly = enum.auto()
class CategorySchedule(Base): class Selector_T(enum.Enum):
unknown = enum.auto()
nullifier = enum.auto()
vacations = enum.auto()
rules = enum.auto()
algorithm = enum.auto()
manual = enum.auto()
categoryselector = Annotated[
Selector_T,
mapped_column(Enum(Selector_T, inherit_schema=True), default=Selector_T.unknown),
]
class CategorySelector(Base, Export):
__tablename__ = "category_selectors"
id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey(TransactionCategory.id, ondelete="CASCADE"),
primary_key=True,
init=False,
)
selector: Mapped[categoryselector]
@property
def format(self):
return dict(selector=self.selector)
class Period(enum.Enum):
daily = "daily"
weekly = "weekly"
monthly = "monthly"
yearly = "yearly"
scheduleperiod = Annotated[Selector_T, mapped_column(Enum(Period, inherit_schema=True))]
class CategorySchedule(Base, Export):
__tablename__ = "category_schedules" __tablename__ = "category_schedules"
name: Mapped[catfk] = mapped_column(primary_key=True, init=False) name: Mapped[catfk] = mapped_column(primary_key=True)
period: Mapped[Optional[SchedulePeriod]] period: Mapped[Optional[scheduleperiod]]
period_multiplier: Mapped[Optional[int]] period_multiplier: Mapped[Optional[int]]
amount: Mapped[Optional[int]] amount: Mapped[Optional[int]]
@property
def format(self) -> dict[str, Any]:
return dict(
name=self.name,
period=self.period,
period_multiplier=self.period_multiplier,
amount=self.amount,
)
class Link(Base): class Link(Base):
__tablename__ = "links" __tablename__ = "links"
@ -460,17 +336,17 @@ class Link(Base):
link: Mapped[idfk] = mapped_column(primary_key=True) link: Mapped[idfk] = mapped_column(primary_key=True)
class Rule(Base): class Rule(Base, Export):
__tablename__ = "rules" __tablename__ = "rules"
id: Mapped[idpk] = mapped_column(init=False) id: Mapped[idpk] = mapped_column(init=False)
start: Mapped[Optional[dt.date]] = mapped_column(default=None) start: Mapped[Optional[dt.date]]
end: Mapped[Optional[dt.date]] = mapped_column(default=None) end: Mapped[Optional[dt.date]]
description: Mapped[Optional[str]] = mapped_column(default=None) description: Mapped[Optional[str]]
regex: Mapped[Optional[str]] = mapped_column(default=None) regex: Mapped[Optional[str]]
bank: Mapped[Optional[str]] = mapped_column(default=None) bank: Mapped[Optional[str]]
min: Mapped[Optional[money]] = mapped_column(default=None) min: Mapped[Optional[money]]
max: Mapped[Optional[money]] = mapped_column(default=None) max: Mapped[Optional[money]]
type: Mapped[str] = mapped_column(init=False) type: Mapped[str] = mapped_column(init=False)
@ -485,16 +361,16 @@ class Rule(Base):
valid = re.compile(self.regex, re.IGNORECASE) valid = re.compile(self.regex, re.IGNORECASE)
ops = ( ops = (
Rule.exists(self.start, lambda r: t.date >= r), Rule.exists(self.start, lambda r: r < t.date),
Rule.exists(self.end, lambda r: t.date <= r), Rule.exists(self.end, lambda r: r > t.date),
Rule.exists(self.description, lambda r: r == t.description), Rule.exists(self.description, lambda r: r == t.description),
Rule.exists( Rule.exists(
valid, valid,
lambda r: r.search(t.description) if t.description else False, lambda r: r.search(t.description) if t.description else False,
), ),
Rule.exists(self.bank, lambda r: r == t.bank), Rule.exists(self.bank, lambda r: r == t.bank),
Rule.exists(self.min, lambda r: t.amount >= r), Rule.exists(self.min, lambda r: r < t.amount),
Rule.exists(self.max, lambda r: t.amount <= r), Rule.exists(self.max, lambda r: r > t.amount),
) )
if all(ops): if all(ops):
@ -502,8 +378,21 @@ class Rule(Base):
return False return False
@property
def format(self) -> dict[str, Any]:
return dict(
start=self.start,
end=self.end,
description=self.description,
regex=self.regex,
bank=self.bank,
min=self.min,
max=self.max,
type=self.type,
)
@staticmethod @staticmethod
def exists(r: Optional[Any], op: Callable[[Any], bool]) -> bool: def exists(r, op) -> bool:
return op(r) if r is not None else True return op(r) if r is not None else True
@ -516,13 +405,19 @@ class CategoryRule(Rule):
primary_key=True, primary_key=True,
init=False, init=False,
) )
name: Mapped[catfk] = mapped_column(init=False) name: Mapped[catfk]
__mapper_args__ = { __mapper_args__ = {
"polymorphic_identity": "category_rule", "polymorphic_identity": "category_rule",
"polymorphic_load": "selectin",
} }
@property
def format(self) -> dict[str, Any]:
return super().format | dict(name=self.name)
def __hash__(self):
return hash(self.id)
class TagRule(Rule): class TagRule(Rule):
__tablename__ = "tag_rules" __tablename__ = "tag_rules"
@ -533,19 +428,15 @@ class TagRule(Rule):
primary_key=True, primary_key=True,
init=False, init=False,
) )
tag: Mapped[str] = mapped_column( tag: Mapped[str] = mapped_column(ForeignKey(Tag.name, ondelete="CASCADE"))
ForeignKey(Tag.name, ondelete="CASCADE"), init=False
)
__mapper_args__ = { __mapper_args__ = {
"polymorphic_identity": "tag_rule", "polymorphic_identity": "tag_rule",
"polymorphic_load": "selectin",
} }
@property
def format(self) -> dict[str, Any]:
return super().format | dict(tag=self.tag)
class Nordigen(Base): def __hash__(self):
__tablename__ = "nordigen" return hash(self.id)
type: Mapped[str] = mapped_column(primary_key=True)
token: Mapped[str]
expires: Mapped[dt.datetime]

View File

@ -1,16 +1,12 @@
from dataclasses import dataclass from dataclasses import dataclass
import datetime as dt
import dotenv import dotenv
import json
import nordigen import nordigen
import os import os
import requests import requests
import time import time
from typing import Any, Optional, Sequence, Tuple
import uuid import uuid
from pfbudget.db.client import Client
from pfbudget.db.model import Nordigen
from .exceptions import CredentialsError, DownloadError from .exceptions import CredentialsError, DownloadError
dotenv.load_dotenv() dotenv.load_dotenv()
@ -20,38 +16,40 @@ dotenv.load_dotenv()
class NordigenCredentials: class NordigenCredentials:
id: str id: str
key: str key: str
token: str = ""
def valid(self) -> bool: def valid(self) -> bool:
return len(self.id) != 0 and len(self.key) != 0 return self.id and self.key
class NordigenClient: class NordigenClient:
redirect_url = "https://murta.dev" redirect_url = "https://murta.dev"
def __init__(self, credentials: NordigenCredentials, client: Client): def __init__(self, credentials: NordigenCredentials):
super().__init__()
if not credentials.valid(): if not credentials.valid():
raise CredentialsError raise CredentialsError
self.__client = nordigen.NordigenClient( self._client = nordigen.NordigenClient(
secret_key=credentials.key, secret_id=credentials.id, timeout=5 secret_key=credentials.key, secret_id=credentials.id, timeout=5
) )
self.__client.token = self.__token(client)
def download(self, requisition_id) -> Sequence[dict[str, Any]]: if credentials.token:
self._client.token = credentials.token
def download(self, requisition_id):
try: try:
requisition = self.__client.requisition.get_requisition_by_id( requisition = self._client.requisition.get_requisition_by_id(requisition_id)
requisition_id
)
print(requisition) print(requisition)
except requests.HTTPError as e: except requests.HTTPError as e:
raise DownloadError(e) raise DownloadError(e)
transactions = [] transactions = {}
for acc in requisition["accounts"]: for acc in requisition["accounts"]:
account = self.__client.account_api(acc) account = self._client.account_api(acc)
retries = 0 retries = 0
downloaded = None
while retries < 3: while retries < 3:
try: try:
downloaded = account.get_transactions() downloaded = account.get_transactions()
@ -62,93 +60,55 @@ class NordigenClient:
time.sleep(1) time.sleep(1)
if not downloaded: if not downloaded:
print(f"Couldn't download transactions for {account.get_metadata()}") print(f"Couldn't download transactions for {account}")
continue continue
if ( transactions.update(downloaded)
"transactions" not in downloaded
or "booked" not in downloaded["transactions"]
):
print(f"{account} doesn't have transactions")
continue
transactions.extend(downloaded["transactions"]["booked"])
return transactions return transactions
def dump(self, bank, downloaded): def dump(self, bank, downloaded):
# @TODO log received JSON with open("json/" + bank.name + ".json", "w") as f:
pass json.dump(downloaded, f)
def new_requisition( def generate_token(self):
self, self.token = self._client.generate_token()
institution_id: str, print(f"New access token: {self.token}")
max_historical_days: Optional[int] = None, return self.token
access_valid_for_days: Optional[int] = None,
) -> Tuple[str, str]:
kwargs = {
"max_historical_days": max_historical_days,
"access_valid_for_days": access_valid_for_days,
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
req = self.__client.initialize_session( def requisition(self, id: str, country: str = "PT"):
self.redirect_url, institution_id, str(uuid.uuid4()), **kwargs requisition = self._client.initialize_session(
redirect_uri=self.redirect_url,
institution_id=id,
reference_id=str(uuid.uuid4()),
) )
return req.link, req.requisition_id return requisition.link, requisition.requisition_id
def country_banks(self, country: str): def country_banks(self, country: str):
return self.__client.institution.get_institutions(country) return self._client.institution.get_institutions(country)
def __token(self, client: Client) -> str: # def __token(self):
with client.session as session: # if token := os.environ.get("TOKEN"):
token = session.select(Nordigen) # return token
# else:
# token = self._client.generate_token()
# print(f"New access token: {token}")
# return token["access"]
def datetime(seconds: int) -> dt.datetime: @property
return dt.datetime.now() + dt.timedelta(seconds=seconds) def token(self):
return self._token
if not len(token): @token.setter
print("First time nordigen token setup") def token(self, value):
new = self.__client.generate_token() if self._token:
session.insert( print("Replacing existing token with {value}")
[ self._token = value
Nordigen(
"access",
new["access"],
datetime(new["access_expires"]),
),
Nordigen(
"refresh",
new["refresh"],
datetime(new["refresh_expires"]),
),
]
)
return new["access"]
else:
access = next(t for t in token if t.type == "access")
refresh = next(t for t in token if t.type == "refresh")
if access.expires > dt.datetime.now():
pass
elif refresh.expires > dt.datetime.now():
new = self.__client.exchange_token(refresh.token)
access.token = new["access"]
access.expires = datetime(new["access_expires"])
else:
new = self.__client.generate_token()
access.token = new["access"]
access.expires = datetime(new["access_expires"])
refresh.token = new["refresh"]
refresh.expires = datetime(new["refresh_expires"])
return access.token
class NordigenCredentialsManager: class NordigenCredentialsManager:
default = NordigenCredentials( default = NordigenCredentials(
os.environ.get("SECRET_ID", ""), os.environ.get("SECRET_ID"),
os.environ.get("SECRET_KEY", ""), os.environ.get("SECRET_KEY"),
os.environ.get("TOKEN"),
) )

View File

@ -1,45 +1,58 @@
from __future__ import annotations from collections import namedtuple
from decimal import Decimal from decimal import Decimal
from importlib import import_module from importlib import import_module
from pathlib import Path from pathlib import Path
import datetime as dt import datetime as dt
from typing import Any, Callable, NamedTuple, Optional
import yaml import yaml
from pfbudget.common.types import NoBankSelected from pfbudget.common.types import NoBankSelected
from pfbudget.db.model import BankTransaction from pfbudget.db.model import Transaction
from pfbudget.utils import utils from pfbudget.utils import utils
Index = namedtuple(
class Index(NamedTuple): "Index", ["date", "text", "value", "negate"], defaults=[-1, -1, -1, False]
date: int = -1 )
text: int = -1 Options = namedtuple(
value: int = -1 "Options",
negate: bool = False [
"encoding",
"separator",
"date_fmt",
"start",
"end",
"debit",
"credit",
"additional_parser",
"category",
"VISA",
"MasterCard",
"AmericanExpress",
],
defaults=[
"",
"",
"",
1,
None,
Index(),
Index(),
False,
None,
None,
None,
None,
],
)
class Options(NamedTuple): def parse_data(filename: Path, args: dict) -> list[Transaction]:
encoding: str cfg: dict = yaml.safe_load(open("parsers.yaml"))
separator: str
date_fmt: str
start: int = 1
end: Optional[int] = None
debit: Index = Index()
credit: Index = Index()
additional_parser: bool = False
VISA: Optional[Options] = None
MasterCard: Optional[Options] = None
AmericanExpress: Optional[Options] = None
def parse_data(filename: Path, args: dict[str, Any]) -> list[BankTransaction]:
cfg: dict[str, Any] = yaml.safe_load(open("parsers.yaml"))
assert ( assert (
"Banks" in cfg "Banks" in cfg
), "parsers.yaml is missing the Banks section with the list of available banks" ), "parsers.yaml is missing the Banks section with the list of available banks"
if not args["bank"]: if not args["bank"]:
bank, creditcard = utils.find_credit_institution( # type: ignore bank, creditcard = utils.find_credit_institution(
filename, cfg.get("Banks"), cfg.get("CreditCards") filename, cfg.get("Banks"), cfg.get("CreditCards")
) )
else: else:
@ -47,7 +60,7 @@ def parse_data(filename: Path, args: dict[str, Any]) -> list[BankTransaction]:
creditcard = None if not args["creditcard"] else args["creditcard"][0] creditcard = None if not args["creditcard"] else args["creditcard"][0]
try: try:
options: dict[str, Any] = cfg[bank] options: dict = cfg[bank]
except KeyError as e: except KeyError as e:
banks = cfg["Banks"] banks = cfg["Banks"]
raise NoBankSelected(f"{e} not a valid bank, try one of {banks}") raise NoBankSelected(f"{e} not a valid bank, try one of {banks}")
@ -60,6 +73,9 @@ def parse_data(filename: Path, args: dict[str, Any]) -> list[BankTransaction]:
raise NoBankSelected(f"{e} not a valid bank, try one of {creditcards}") raise NoBankSelected(f"{e} not a valid bank, try one of {creditcards}")
bank += creditcard bank += creditcard
if args["category"]:
options["category"] = args["category"][0]
if options.get("additional_parser"): if options.get("additional_parser"):
parser = getattr(import_module("pfbudget.extract.parsers"), bank) parser = getattr(import_module("pfbudget.extract.parsers"), bank)
transactions = parser(filename, bank, options).parse() transactions = parser(filename, bank, options).parse()
@ -70,7 +86,7 @@ def parse_data(filename: Path, args: dict[str, Any]) -> list[BankTransaction]:
class Parser: class Parser:
def __init__(self, filename: Path, bank: str, options: dict[str, Any]): def __init__(self, filename: Path, bank: str, options: dict):
self.filename = filename self.filename = filename
self.bank = bank self.bank = bank
@ -81,10 +97,10 @@ class Parser:
self.options = Options(**options) self.options = Options(**options)
def func(self, transaction: BankTransaction): def func(self, transaction: Transaction):
pass pass
def parse(self) -> list[BankTransaction]: def parse(self) -> list[Transaction]:
transactions = [ transactions = [
Parser.transaction(line, self.bank, self.options, self.func) Parser.transaction(line, self.bank, self.options, self.func)
for line in list(open(self.filename, encoding=self.options.encoding))[ for line in list(open(self.filename, encoding=self.options.encoding))[
@ -95,8 +111,7 @@ class Parser:
return transactions return transactions
@staticmethod @staticmethod
def index(line: list[str], options: Options) -> Index: def index(line: list, options: Options) -> Index:
index = None
if options.debit.date != -1 and options.credit.date != -1: if options.debit.date != -1 and options.credit.date != -1:
if options.debit.value != options.credit.value: if options.debit.value != options.credit.value:
if line[options.debit.value]: if line[options.debit.value]:
@ -123,57 +138,49 @@ class Parser:
else: else:
raise IndexError("No debit not credit indexes available") raise IndexError("No debit not credit indexes available")
return index if index else Index() return index
@staticmethod @staticmethod
def transaction( def transaction(line: str, bank: str, options: Options, func) -> Transaction:
line_: str, bank: str, options: Options, func: Callable[[BankTransaction], None] line = line.rstrip().split(options.separator)
) -> BankTransaction:
line = line_.rstrip().split(options.separator)
index = Parser.index(line, options) index = Parser.index(line, options)
try: date = (
date_str = line[index.date].strip() dt.datetime.strptime(line[index.date].strip(), options.date_fmt)
date = dt.datetime.strptime(date_str, options.date_fmt).date() .date()
.isoformat()
)
text = line[index.text] text = line[index.text]
value = utils.parse_decimal(line[index.value]) value = utils.parse_decimal(line[index.value])
if index.negate: if index.negate:
value = -value value = -value
transaction = BankTransaction(date, text, value, bank=bank) if options.category:
category = line[options.category]
transaction = Transaction(date, text, bank, value, category)
else:
transaction = Transaction(date, text, bank, value)
if options.additional_parser: if options.additional_parser:
func(transaction) func(transaction)
return transaction return transaction
except IndexError:
raise IndexError(line_)
class Bank1(Parser): class Bank1(Parser):
def __init__(self, filename: Path, bank: str, options: dict[str, Any]): def __init__(self, filename: str, bank: str, options: dict):
super().__init__(filename, bank, options) super().__init__(filename, bank, options)
self.transfers: list[dt.date] = [] self.transfers = []
self.transaction_cost = -Decimal("1") self.transaction_cost = -Decimal("1")
def func(self, transaction: BankTransaction): def func(self, transaction: Transaction):
if ( if "transf" in transaction.description.lower() and transaction.value < 0:
transaction.description transaction.value -= self.transaction_cost
and "transf" in transaction.description.lower()
and transaction.amount < 0
):
transaction.amount -= self.transaction_cost
self.transfers.append(transaction.date) self.transfers.append(transaction.date)
def parse(self) -> list[BankTransaction]: def parse(self) -> list:
transactions = super().parse() transactions = super().parse()
for date in self.transfers: for date in self.transfers:
transactions.append( transactions.append(
BankTransaction( Transaction(date, "Transaction cost", self.bank, self.transaction_cost)
date, "Transaction cost", self.transaction_cost, bank=self.bank
)
) )
return transactions return transactions

View File

@ -35,4 +35,4 @@ class PSD2Extractor(Extractor):
] ]
def convert(self, bank, downloaded, start, end): def convert(self, bank, downloaded, start, end):
return [convert(t, bank) for t in downloaded] return [convert(t, bank) for t in downloaded["transactions"]["booked"]]

View File

@ -4,10 +4,11 @@ from typing import Iterable, Sequence
from pfbudget.db.model import ( from pfbudget.db.model import (
CategoryRule, CategoryRule,
CategorySelector, CategorySelector,
Selector_T,
Transaction, Transaction,
TransactionCategory, TransactionCategory,
TransactionTag,
) )
from .exceptions import TransactionCategorizedError
from .transform import Transformer from .transform import Transformer
@ -24,15 +25,12 @@ class Categorizer(Transformer):
def transform_inplace(self, transactions: Sequence[Transaction]) -> None: def transform_inplace(self, transactions: Sequence[Transaction]) -> None:
for rule in self.rules: for rule in self.rules:
for transaction in transactions: for transaction in transactions:
if transaction.category:
raise TransactionCategorizedError(transaction)
if not rule.matches(transaction): if not rule.matches(transaction):
continue continue
if not transaction.category:
transaction.category = TransactionCategory( transaction.category = TransactionCategory(
rule.name, CategorySelector.rules rule.name, CategorySelector(Selector_T.rules)
) )
else:
if not transaction.tags:
transaction.tags = {TransactionTag(rule.name)}
else:
transaction.tags.add(TransactionTag(rule.name))

View File

@ -1,2 +1,6 @@
class MoreThanOneMatchError(Exception): class MoreThanOneMatchError(Exception):
pass pass
class TransactionCategorizedError(Exception):
pass

View File

@ -6,6 +6,7 @@ from .exceptions import MoreThanOneMatchError
from .transform import Transformer from .transform import Transformer
from pfbudget.db.model import ( from pfbudget.db.model import (
CategorySelector, CategorySelector,
Selector_T,
Transaction, Transaction,
TransactionCategory, TransactionCategory,
) )
@ -15,7 +16,7 @@ class Nullifier(Transformer):
NULL_DAYS = 4 NULL_DAYS = 4
def __init__(self, rules=None): def __init__(self, rules=None):
self.rules = rules if rules else [] self.rules = rules
def transform(self, transactions: Sequence[Transaction]) -> Sequence[Transaction]: def transform(self, transactions: Sequence[Transaction]) -> Sequence[Transaction]:
"""transform """transform
@ -88,6 +89,6 @@ class Nullifier(Transformer):
def _nullify(self, transaction: Transaction) -> Transaction: def _nullify(self, transaction: Transaction) -> Transaction:
transaction.category = TransactionCategory( transaction.category = TransactionCategory(
"null", selector=CategorySelector.nullifier "null", selector=CategorySelector(Selector_T.nullifier)
) )
return transaction return transaction

1776
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -8,16 +8,17 @@ readme = "README.md"
packages = [{include = "pfbudget"}] packages = [{include = "pfbudget"}]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.11" python = "^3.10"
codetiming = "^1.4.0" codetiming = "^1.4.0"
matplotlib = "^3.7.1" matplotlib = "^3.7.1"
nordigen = "^1.3.1" nordigen = "^1.3.1"
psycopg2 = "^2.9.6" psycopg2 = {extras = ["binary"], version = "^2.9.6"}
python-dateutil = "^2.8.2" python-dateutil = "^2.8.2"
python-dotenv = "^1.0.0" python-dotenv = "^1.0.0"
pyyaml = "^6.0" pyyaml = "^6.0"
sqlalchemy = "^2.0.9" sqlalchemy = "^2.0.9"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
alembic = "^1.10.3" alembic = "^1.10.3"
black = "^23.3.0" black = "^23.3.0"
@ -27,15 +28,11 @@ pytest = "^7.3.0"
pytest-cov = "^4.0.0" pytest-cov = "^4.0.0"
pytest-mock = "^3.10.0" pytest-mock = "^3.10.0"
sqlalchemy = {extras = ["mypy"], version = "^2.0.9"} sqlalchemy = {extras = ["mypy"], version = "^2.0.9"}
ruff = "^0.0.267"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
pythonpath = ". tests"
[pytest] [pytest]
mock_use_standalone_module = true mock_use_standalone_module = true

View File

@ -1,8 +0,0 @@
from pfbudget.db.model import AccountType, Bank, NordigenBank
checking = Bank(
"bank", "BANK", AccountType.checking, NordigenBank("bank_id", "requisition_id")
)
cc = Bank("cc", "CC", AccountType.MASTERCARD)

View File

@ -1,21 +1,15 @@
from decimal import Decimal from decimal import Decimal
from pfbudget.db.model import Category, CategoryGroup, CategoryRule, Tag, TagRule from pfbudget.db.model import Category, CategoryRule, Tag, TagRule
category_null = Category("null") category_null = Category("null", None, set())
categorygroup1 = CategoryGroup("group#1")
category1 = Category( category1 = Category(
"cat#1", "cat#1",
"group#1", None,
rules=[CategoryRule(description="desc#1", max=Decimal(0))], {CategoryRule(None, None, "desc#1", None, None, None, Decimal(0), "cat#1")},
) )
category2 = Category( tag_1 = Tag(
"cat#2", "tag#1", {TagRule(None, None, "desc#1", None, None, None, Decimal(0), "tag#1")}
"group#1",
rules=[CategoryRule(description="desc#1", max=Decimal(0))],
) )
tag_1 = Tag("tag#1", rules=[TagRule(description="desc#1", max=Decimal(0))])

View File

@ -1,22 +0,0 @@
import datetime as dt
from pfbudget.db.client import Client
from pfbudget.db.model import Base, Nordigen
class MockClient(Client):
now = dt.datetime.now()
def __init__(self):
url = "sqlite://"
super().__init__(
url, execution_options={"schema_translate_map": {"pfbudget": None}}
)
Base.metadata.create_all(self.engine)
self.insert(
[
Nordigen("access", "token#1", self.now + dt.timedelta(days=1)),
Nordigen("refresh", "token#2", self.now + dt.timedelta(days=30)),
]
)

View File

@ -1,11 +1,3 @@
from typing import Any, Dict, List, Optional
import nordigen
from nordigen.types.http_enums import HTTPMethod
from nordigen.types.types import RequisitionDto, TokenType
from pfbudget.extract.nordigen import NordigenCredentials
id = "3fa85f64-5717-4562-b3fc-2c963f66afa6" id = "3fa85f64-5717-4562-b3fc-2c963f66afa6"
accounts_id = { accounts_id = {
@ -18,7 +10,6 @@ accounts_id = {
"owner_name": "string", "owner_name": "string",
} }
# The downloaded transactions match the simple and simple_transformed mocks
accounts_id_transactions = { accounts_id_transactions = {
"transactions": { "transactions": {
"booked": [ "booked": [
@ -89,58 +80,3 @@ requisitions_id = {
"account_selection": False, "account_selection": False,
"redirect_immediate": False, "redirect_immediate": False,
} }
credentials = NordigenCredentials("ID", "KEY")
class MockNordigenClient(nordigen.NordigenClient):
def __init__(
self,
secret_key: str = "ID",
secret_id: str = "KEY",
timeout: int = 10,
base_url: str = "https://ob.nordigen.com/api/v2",
) -> None:
super().__init__(secret_key, secret_id, timeout, base_url)
def generate_token(self) -> TokenType:
return {
"access": "access_token",
"refresh": "refresh_token",
"access_expires": 86400,
"refresh_expires": 2592000,
}
def exchange_token(self, refresh_token: str) -> TokenType:
assert len(refresh_token) > 0, "invalid refresh token"
return {
"access": "access_token",
"refresh": "refresh_token",
"access_expires": 86400,
"refresh_expires": 2592000,
}
def request(
self,
method: HTTPMethod,
endpoint: str,
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
) -> Any:
if endpoint == "requisitions/" + "requisition_id" + "/":
return requisitions_id
elif endpoint == "accounts/" + id + "/transactions/":
return accounts_id_transactions
else:
raise NotImplementedError(endpoint)
def initialize_session(
self,
redirect_uri: str,
institution_id: str,
reference_id: str,
max_historical_days: int = 90,
access_valid_for_days: int = 90,
access_scope: List[str] | None = None,
) -> RequisitionDto:
return RequisitionDto("http://random", "requisition_id")

View File

@ -1,73 +0,0 @@
from datetime import date
from decimal import Decimal
from pfbudget.db.model import (
BankTransaction,
CategorySelector,
MoneyTransaction,
Note,
SplitTransaction,
Transaction,
TransactionCategory,
TransactionTag,
)
# The simple and simple_transformed match the nordigen mocks
simple = [
BankTransaction(date(2023, 1, 14), "string", Decimal("328.18"), bank="bank"),
BankTransaction(date(2023, 2, 14), "string", Decimal("947.26"), bank="bank"),
]
simple_transformed = [
BankTransaction(
date(2023, 1, 14),
"",
Decimal("328.18"),
bank="bank",
category=TransactionCategory("category#1", CategorySelector.algorithm),
),
BankTransaction(
date(2023, 2, 14),
"",
Decimal("947.26"),
bank="bank",
category=TransactionCategory("category#2", CategorySelector.algorithm),
),
]
bank = [
BankTransaction(date(2023, 1, 1), "", Decimal("-10"), bank="bank#1"),
BankTransaction(date(2023, 1, 1), "", Decimal("-10"), bank="bank#2"),
]
money = [
MoneyTransaction(date(2023, 1, 1), "", Decimal("-10")),
MoneyTransaction(date(2023, 1, 1), "", Decimal("-10")),
]
__original = Transaction(date(2023, 1, 1), "", Decimal("-10"), split=True)
__original.id = 9000
split = [
__original,
SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id),
SplitTransaction(date(2023, 1, 1), "", Decimal("-5"), original=__original.id),
]
tagged = [
Transaction(
date(2023, 1, 1),
"",
Decimal("-10"),
tags={TransactionTag("tag#1"), TransactionTag("tag#1")},
)
]
noted = [
Transaction(
date(2023, 1, 1),
"",
Decimal("-10"),
note=Note("note#1"),
)
]

View File

@ -1,144 +0,0 @@
from pathlib import Path
from typing import Any, Sequence, Type
import pytest
from mocks import banks, categories, transactions
from mocks.client import MockClient
from pfbudget.common.types import ExportFormat
from pfbudget.core.command import (
BackupCommand,
ExportCommand,
ImportBackupCommand,
ImportCommand,
ImportFailedError,
)
from pfbudget.db.client import Client
from pfbudget.db.model import (
Bank,
BankTransaction,
Base,
Category,
CategoryGroup,
MoneyTransaction,
Note,
SplitTransaction,
Tag,
Transaction,
TransactionCategory,
TransactionTag,
)
@pytest.fixture
def client() -> Client:
return MockClient()
params = [
(transactions.simple, Transaction),
(transactions.simple_transformed, Transaction),
(transactions.bank, Transaction),
(transactions.bank, BankTransaction),
(transactions.money, Transaction),
(transactions.money, MoneyTransaction),
(transactions.split, SplitTransaction),
([banks.checking, banks.cc], Bank),
([categories.category_null, categories.category1, categories.category2], Category),
(
[
categories.categorygroup1,
categories.category_null,
categories.category1,
categories.category2,
],
CategoryGroup,
),
([categories.tag_1], Tag),
]
not_serializable = [
(transactions.simple_transformed, TransactionCategory),
(transactions.tagged, TransactionTag),
(transactions.noted, Note),
]
class TestBackup:
@pytest.mark.parametrize("input, what", params)
def test_import(self, tmp_path: Path, input: Sequence[Any], what: Type[Any]):
file = tmp_path / "test.json"
client = MockClient()
client.insert(input)
originals = client.select(what)
assert originals
command = ExportCommand(client, what, file, ExportFormat.JSON)
command.execute()
other = MockClient()
command = ImportCommand(other, what, file, ExportFormat.JSON)
command.execute()
imported = other.select(what)
assert originals == imported
command = ExportCommand(client, what, file, ExportFormat.pickle)
with pytest.raises(AttributeError):
command.execute()
command = ImportCommand(other, what, file, ExportFormat.pickle)
with pytest.raises(AttributeError):
command.execute()
@pytest.mark.parametrize("input, what", not_serializable)
def test_try_backup_not_serializable(
self, tmp_path: Path, input: Sequence[Any], what: Type[Any]
):
file = tmp_path / "test.json"
client = MockClient()
client.insert(input)
originals = client.select(what)
assert originals
command = ExportCommand(client, what, file, ExportFormat.JSON)
with pytest.raises(AttributeError):
command.execute()
other = MockClient()
command = ImportCommand(other, what, file, ExportFormat.JSON)
with pytest.raises(ImportFailedError):
command.execute()
imported = other.select(what)
assert not imported
def test_full_backup(self, tmp_path: Path):
file = tmp_path / "test.json"
client = MockClient()
client.insert([e for t in params for e in t[0]])
command = BackupCommand(client, file, ExportFormat.JSON)
command.execute()
other = MockClient()
command = ImportBackupCommand(other, file, ExportFormat.JSON)
command.execute()
def subclasses(cls: Type[Any]) -> set[Type[Any]]:
return set(cls.__subclasses__()) | {
s for c in cls.__subclasses__() for s in subclasses(c)
}
for t in [cls for cls in subclasses(Base)]:
originals = client.select(t)
imported = other.select(t)
assert originals == imported, f"{t}"

View File

@ -1,54 +0,0 @@
import json
from pathlib import Path
import pytest
from mocks.client import MockClient
import mocks.transactions
from pfbudget.common.types import ExportFormat
from pfbudget.core.command import ExportCommand, ImportCommand
from pfbudget.db.client import Client
from pfbudget.db.exceptions import InsertError
from pfbudget.db.model import Transaction
@pytest.fixture
def client() -> Client:
return MockClient()
class TestCommand:
def test_export_json(self, tmp_path: Path, client: Client):
file = tmp_path / "test.json"
client.insert(mocks.transactions.simple)
command = ExportCommand(client, Transaction, file, ExportFormat.JSON)
command.execute()
with open(file, newline="") as f:
result = json.load(f)
assert result == [t.serialize() for t in client.select(Transaction)]
def test_export_pickle(self, tmp_path: Path, client: Client):
file = tmp_path / "test.pickle"
command = ExportCommand(client, Transaction, file, ExportFormat.pickle)
with pytest.raises(AttributeError):
command.execute()
def test_import_json(self, tmp_path: Path, client: Client):
file = tmp_path / "test"
client.insert(mocks.transactions.simple)
command = ExportCommand(client, Transaction, file, ExportFormat.JSON)
command.execute()
# Since the transactions are already in the DB, we expect an insert error
with pytest.raises(InsertError):
command = ImportCommand(client, Transaction, file, ExportFormat.JSON)
command.execute()
def test_import_pickle(self, tmp_path: Path, client: Client):
file = tmp_path / "test"
command = ExportCommand(client, Transaction, file, ExportFormat.pickle)
with pytest.raises(AttributeError):
command.execute()

View File

@ -2,14 +2,14 @@ from datetime import date
from decimal import Decimal from decimal import Decimal
import pytest import pytest
from mocks.client import MockClient
from pfbudget.db.client import Client from pfbudget.db.client import Client
from pfbudget.db.model import ( from pfbudget.db.model import (
AccountType, AccountType,
Bank, Bank,
NordigenBank, Base,
CategorySelector, CategorySelector,
Nordigen,
Selector_T,
Transaction, Transaction,
TransactionCategory, TransactionCategory,
) )
@ -17,21 +17,20 @@ from pfbudget.db.model import (
@pytest.fixture @pytest.fixture
def client() -> Client: def client() -> Client:
return MockClient() url = "sqlite://"
client = Client(url, execution_options={"schema_translate_map": {"pfbudget": None}})
Base.metadata.create_all(client.engine)
return client
@pytest.fixture @pytest.fixture
def banks(client: Client) -> list[Bank]: def banks(client: Client) -> list[Bank]:
banks = [ banks = [
Bank("bank", "BANK", AccountType.checking, NordigenBank(None, "req", None)), Bank("bank", "BANK", AccountType.checking),
Bank("broker", "BROKER", AccountType.investment), Bank("broker", "BROKER", AccountType.investment),
Bank("creditcard", "CC", AccountType.MASTERCARD), Bank("creditcard", "CC", AccountType.MASTERCARD),
] ]
banks[0].nordigen = Nordigen("bank", None, "req", None)
# fix nordigen bank names which would be generated post DB insert
for bank in banks:
if bank.nordigen:
bank.nordigen.name = bank.name
client.insert(banks) client.insert(banks)
return banks return banks
@ -40,22 +39,19 @@ def banks(client: Client) -> list[Bank]:
@pytest.fixture @pytest.fixture
def transactions(client: Client) -> list[Transaction]: def transactions(client: Client) -> list[Transaction]:
transactions = [ transactions = [
Transaction( Transaction(date(2023, 1, 1), "", Decimal("-10")),
date(2023, 1, 1),
"",
Decimal("-10"),
category=TransactionCategory("category", CategorySelector.algorithm),
),
Transaction(date(2023, 1, 2), "", Decimal("-50")), Transaction(date(2023, 1, 2), "", Decimal("-50")),
] ]
transactions[0].category = TransactionCategory(
"name", CategorySelector(Selector_T.algorithm)
)
client.insert(transactions) client.insert(transactions)
# fix ids which would be generated post DB insert
for i, transaction in enumerate(transactions): for i, transaction in enumerate(transactions):
transaction.id = i + 1 transaction.id = i + 1
if transaction.category: transaction.split = False # default
transaction.category.id = 1 transactions[0].category.id = 1
transactions[0].category.selector.id = 1
return transactions return transactions
@ -125,13 +121,13 @@ class TestDatabase:
def test_update_nordigen(self, client: Client, banks: list[Bank]): def test_update_nordigen(self, client: Client, banks: list[Bank]):
name = banks[0].name name = banks[0].name
result = client.select(NordigenBank, lambda: NordigenBank.name == name) result = client.select(Nordigen, lambda: Nordigen.name == name)
assert result[0].requisition_id == "req" assert result[0].requisition_id == "req"
update = {"name": name, "requisition_id": "anotherreq"} update = {"name": name, "requisition_id": "anotherreq"}
client.update(NordigenBank, [update]) client.update(Nordigen, [update])
result = client.select(NordigenBank, lambda: NordigenBank.name == name) result = client.select(Nordigen, lambda: Nordigen.name == name)
assert result[0].requisition_id == "anotherreq" assert result[0].requisition_id == "anotherreq"
result = client.select(Bank, lambda: Bank.name == name) result = client.select(Bank, lambda: Bank.name == name)

View File

@ -31,8 +31,8 @@ class TestDatabaseLoad:
def test_insert(self, loader: Loader): def test_insert(self, loader: Loader):
transactions = [ transactions = [
BankTransaction(date(2023, 1, 1), "", Decimal("-500"), bank="Bank#1"), BankTransaction(date(2023, 1, 1), "", Decimal("-500"), "Bank#1"),
BankTransaction(date(2023, 1, 2), "", Decimal("500"), bank="Bank#2"), BankTransaction(date(2023, 1, 2), "", Decimal("500"), "Bank#2"),
] ]
loader.load(transactions) loader.load(transactions)

View File

@ -4,10 +4,9 @@ from typing import Any, Optional
import pytest import pytest
import requests import requests
from mocks.client import MockClient
import mocks.nordigen as mock import mocks.nordigen as mock
from pfbudget.db.model import AccountType, Bank, BankTransaction, NordigenBank from pfbudget.db.model import AccountType, Bank, BankTransaction, Nordigen
from pfbudget.extract.exceptions import BankError, CredentialsError from pfbudget.extract.exceptions import BankError, CredentialsError
from pfbudget.extract.extract import Extractor from pfbudget.extract.extract import Extractor
from pfbudget.extract.nordigen import NordigenClient, NordigenCredentials from pfbudget.extract.nordigen import NordigenClient, NordigenCredentials
@ -59,13 +58,14 @@ def mock_requests(monkeypatch: pytest.MonkeyPatch):
@pytest.fixture @pytest.fixture
def extractor() -> Extractor: def extractor() -> Extractor:
credentials = NordigenCredentials("ID", "KEY") credentials = NordigenCredentials("ID", "KEY", "TOKEN")
return PSD2Extractor(NordigenClient(credentials, MockClient())) return PSD2Extractor(NordigenClient(credentials))
@pytest.fixture @pytest.fixture
def bank() -> Bank: def bank() -> Bank:
bank = Bank("Bank#1", "", AccountType.checking, NordigenBank("", mock.id, False)) bank = Bank("Bank#1", "", AccountType.checking)
bank.nordigen = Nordigen("", "", mock.id, False)
return bank return bank
@ -73,7 +73,7 @@ class TestExtractPSD2:
def test_empty_credentials(self): def test_empty_credentials(self):
cred = NordigenCredentials("", "") cred = NordigenCredentials("", "")
with pytest.raises(CredentialsError): with pytest.raises(CredentialsError):
NordigenClient(cred, MockClient()) NordigenClient(cred)
def test_no_psd2_bank(self, extractor: Extractor): def test_no_psd2_bank(self, extractor: Extractor):
with pytest.raises(BankError): with pytest.raises(BankError):
@ -88,17 +88,12 @@ class TestExtractPSD2:
with pytest.raises(requests.Timeout): with pytest.raises(requests.Timeout):
extractor.extract(bank) extractor.extract(bank)
def test_extract( def test_extract(self, extractor: Extractor, bank: Bank):
self, monkeypatch: pytest.MonkeyPatch, extractor: Extractor, bank: Bank
):
monkeypatch.setattr(
"pfbudget.extract.nordigen.NordigenClient.dump", lambda *args: None
)
assert extractor.extract(bank) == [ assert extractor.extract(bank) == [
BankTransaction( BankTransaction(
dt.date(2023, 1, 14), "string", Decimal("328.18"), bank="Bank#1" dt.date(2023, 1, 14), "string", Decimal("328.18"), "Bank#1"
), ),
BankTransaction( BankTransaction(
dt.date(2023, 2, 14), "string", Decimal("947.26"), bank="Bank#1" dt.date(2023, 2, 14), "string", Decimal("947.26"), "Bank#1"
), ),
] ]

View File

@ -5,9 +5,9 @@ import mocks.categories as mock
from pfbudget.db.model import ( from pfbudget.db.model import (
BankTransaction, BankTransaction,
Category,
CategoryRule, CategoryRule,
CategorySelector, CategorySelector,
Selector_T,
TransactionCategory, TransactionCategory,
TransactionTag, TransactionTag,
) )
@ -20,8 +20,8 @@ from pfbudget.transform.transform import Transformer
class TestTransform: class TestTransform:
def test_nullifier(self): def test_nullifier(self):
transactions = [ transactions = [
BankTransaction(date(2023, 1, 1), "", Decimal("-500"), bank="Bank#1"), BankTransaction(date(2023, 1, 1), "", Decimal("-500"), "Bank#1"),
BankTransaction(date(2023, 1, 2), "", Decimal("500"), bank="Bank#2"), BankTransaction(date(2023, 1, 2), "", Decimal("500"), "Bank#2"),
] ]
for t in transactions: for t in transactions:
@ -31,12 +31,14 @@ class TestTransform:
transactions = categorizer.transform(transactions) transactions = categorizer.transform(transactions)
for t in transactions: for t in transactions:
assert t.category == TransactionCategory("null", CategorySelector.nullifier) assert t.category == TransactionCategory(
"null", CategorySelector(Selector_T.nullifier)
)
def test_nullifier_inplace(self): def test_nullifier_inplace(self):
transactions = [ transactions = [
BankTransaction(date(2023, 1, 1), "", Decimal("-500"), bank="Bank#1"), BankTransaction(date(2023, 1, 1), "", Decimal("-500"), "Bank#1"),
BankTransaction(date(2023, 1, 2), "", Decimal("500"), bank="Bank#2"), BankTransaction(date(2023, 1, 2), "", Decimal("500"), "Bank#2"),
] ]
for t in transactions: for t in transactions:
@ -46,20 +48,20 @@ class TestTransform:
categorizer.transform_inplace(transactions) categorizer.transform_inplace(transactions)
for t in transactions: for t in transactions:
assert t.category == TransactionCategory("null", CategorySelector.nullifier) assert t.category == TransactionCategory(
"null", CategorySelector(Selector_T.nullifier)
)
def test_nullifier_with_rules(self): def test_nullifier_with_rules(self):
transactions = [ transactions = [
BankTransaction(date(2023, 1, 1), "", Decimal("-500"), bank="Bank#1"), BankTransaction(date(2023, 1, 1), "", Decimal("-500"), "Bank#1"),
BankTransaction(date(2023, 1, 2), "", Decimal("500"), bank="Bank#2"), BankTransaction(date(2023, 1, 2), "", Decimal("500"), "Bank#2"),
] ]
for t in transactions: for t in transactions:
assert not t.category assert not t.category
rule = CategoryRule(bank="Bank#1") rules = [CategoryRule(None, None, None, None, "Bank#1", None, None, "null")]
rule.name = "null"
rules = [rule]
categorizer: Transformer = Nullifier(rules) categorizer: Transformer = Nullifier(rules)
transactions = categorizer.transform(transactions) transactions = categorizer.transform(transactions)
@ -67,28 +69,24 @@ class TestTransform:
for t in transactions: for t in transactions:
assert not t.category assert not t.category
rule = CategoryRule(bank="Bank#2") rules.append(CategoryRule(None, None, None, None, "Bank#2", None, None, "null"))
rule.name = "null"
rules.append(rule)
categorizer = Nullifier(rules) categorizer = Nullifier(rules)
transactions = categorizer.transform(transactions) transactions = categorizer.transform(transactions)
for t in transactions: for t in transactions:
assert t.category == TransactionCategory("null", CategorySelector.nullifier) assert t.category == TransactionCategory(
"null", CategorySelector(Selector_T.nullifier)
)
def test_tagger(self): def test_tagger(self):
transactions = [ transactions = [
BankTransaction(date(2023, 1, 1), "desc#1", Decimal("-10"), bank="Bank#1") BankTransaction(date(2023, 1, 1), "desc#1", Decimal("-10"), "Bank#1")
] ]
for t in transactions: for t in transactions:
assert not t.category assert not t.category
rules = mock.tag_1.rules categorizer: Transformer = Tagger(mock.tag_1.rules)
for rule in rules:
rule.tag = mock.tag_1.name
categorizer: Transformer = Tagger(rules)
transactions = categorizer.transform(transactions) transactions = categorizer.transform(transactions)
for t in transactions: for t in transactions:
@ -96,32 +94,16 @@ class TestTransform:
def test_categorize(self): def test_categorize(self):
transactions = [ transactions = [
BankTransaction(date(2023, 1, 1), "desc#1", Decimal("-10"), bank="Bank#1") BankTransaction(date(2023, 1, 1), "desc#1", Decimal("-10"), "Bank#1")
] ]
for t in transactions: for t in transactions:
assert not t.category assert not t.category
rules = mock.category1.rules categorizer: Transformer = Categorizer(mock.category1.rules)
for rule in rules:
rule.name = mock.category1.name
categorizer: Transformer = Categorizer(rules)
transactions = categorizer.transform(transactions) transactions = categorizer.transform(transactions)
for t in transactions: for t in transactions:
assert t.category == TransactionCategory("cat#1", CategorySelector.rules) assert t.category == TransactionCategory(
"cat#1", CategorySelector(Selector_T.rules)
def test_rule_limits(self): )
transactions = [
BankTransaction(date.today(), "", Decimal("-60"), bank="Bank#1"),
BankTransaction(date.today(), "", Decimal("-120"), bank="Bank#1"),
]
cat = Category("cat")
cat.rules = [CategoryRule(min=-120, max=-60)]
for r in cat.rules:
r.name = cat.name
transactions = Categorizer(cat.rules).transform(transactions)
assert all(t.category.name == cat.name for t in transactions)