[Fix] Add typing information to tests

And fix caught errors.
This commit is contained in:
Luís Murta 2023-04-23 00:51:22 +01:00
parent 541295ef05
commit 761720b712
Signed by: satprog
GPG Key ID: 169EF1BBD7049F94
7 changed files with 38 additions and 32 deletions

View File

@ -1,9 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import date
from typing import Sequence
from pfbudget.db.model import Transaction from pfbudget.db.model import Bank, Transaction
class Extract(ABC): class Extractor(ABC):
@abstractmethod @abstractmethod
def extract(self) -> list[Transaction]: def extract(
return NotImplementedError self, bank: Bank, start: date = date.min, end: date = date.max
) -> Sequence[Transaction]:
raise NotImplementedError

View File

@ -1,21 +1,20 @@
import datetime as dt from datetime import date
from typing import Sequence from typing import Sequence
from pfbudget.db.model import Bank, BankTransaction from pfbudget.db.model import Bank, BankTransaction
from pfbudget.utils.converters import convert from pfbudget.utils.converters import convert
from .exceptions import BankError, DownloadError, ExtractError from .exceptions import BankError, DownloadError, ExtractError
from .extract import Extract from .extract import Extractor
from .nordigen import NordigenClient from .nordigen import NordigenClient
class PSD2Extractor(Extract): class PSD2Extractor(Extractor):
def __init__(self, client: NordigenClient): def __init__(self, client: NordigenClient):
self.__client = client self.__client = client
def extract( def extract(
self, bank: Bank, start=dt.date.min, end=dt.date.max self, bank: Bank, start: date = date.min, end: date = date.max
) -> Sequence[BankTransaction]: ) -> Sequence[BankTransaction]:
if not bank.nordigen: if not bank.nordigen:
raise BankError("Bank doesn't have Nordigen info") raise BankError("Bank doesn't have Nordigen info")

View File

@ -1,5 +1,5 @@
from copy import deepcopy from copy import deepcopy
from typing import Sequence from typing import Iterable, Sequence
from pfbudget.db.model import ( from pfbudget.db.model import (
CategoryRule, CategoryRule,
@ -13,7 +13,7 @@ from .transform import Transformer
class Categorizer(Transformer): class Categorizer(Transformer):
def __init__(self, rules: Sequence[CategoryRule]): def __init__(self, rules: Iterable[CategoryRule]):
self.rules = rules self.rules = rules
def transform(self, transactions: Sequence[Transaction]) -> Sequence[Transaction]: def transform(self, transactions: Sequence[Transaction]) -> Sequence[Transaction]:

View File

@ -1,12 +1,12 @@
from copy import deepcopy from copy import deepcopy
from typing import Sequence from typing import Iterable, Sequence
from pfbudget.db.model import TagRule, Transaction, TransactionTag from pfbudget.db.model import TagRule, Transaction, TransactionTag
from .transform import Transformer from .transform import Transformer
class Tagger(Transformer): class Tagger(Transformer):
def __init__(self, rules: Sequence[TagRule]): def __init__(self, rules: Iterable[TagRule]):
self.rules = rules self.rules = rules
def transform(self, transactions: Sequence[Transaction]) -> Sequence[Transaction]: def transform(self, transactions: Sequence[Transaction]) -> Sequence[Transaction]:
@ -18,7 +18,7 @@ class Tagger(Transformer):
def transform_inplace(self, transactions: Sequence[Transaction]) -> None: def transform_inplace(self, transactions: Sequence[Transaction]) -> None:
for rule in self.rules: for rule in self.rules:
for transaction in transactions: for transaction in transactions:
if rule.tag in transaction.tags: if rule.tag in [tag.tag for tag in transaction.tags]:
continue continue
if not rule.matches(transaction): if not rule.matches(transaction):

View File

@ -6,9 +6,9 @@ from pfbudget.db.model import Transaction
class Transformer(ABC): class Transformer(ABC):
@abstractmethod @abstractmethod
def transform(self, _: Sequence[Transaction]) -> Sequence[Transaction]: def transform(self, transactions: Sequence[Transaction]) -> Sequence[Transaction]:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def transform_inplace(self, _: Sequence[Transaction]) -> None: def transform_inplace(self, transactions: Sequence[Transaction]) -> None:
raise NotImplementedError raise NotImplementedError

View File

@ -1,30 +1,32 @@
import datetime as dt import datetime as dt
from decimal import Decimal from decimal import Decimal
from typing import Any, Optional
import pytest import pytest
import requests import requests
import mocks.nordigen as mock import mocks.nordigen as mock
from pfbudget.db.model import Bank, BankTransaction, Nordigen from pfbudget.db.model import AccountType, Bank, BankTransaction, Nordigen
from pfbudget.extract.exceptions import BankError, CredentialsError from pfbudget.extract.exceptions import BankError, CredentialsError
from pfbudget.extract.extract import Extractor
from pfbudget.extract.nordigen import NordigenClient, NordigenCredentials from pfbudget.extract.nordigen import NordigenClient, NordigenCredentials
from pfbudget.extract.psd2 import PSD2Extractor from pfbudget.extract.psd2 import PSD2Extractor
class MockGet: class MockGet:
def __init__(self, mock_exception=None): def __init__(self, mock_exception: Optional[Exception] = None):
self._status_code = 200 self._status_code = 200
self._mock_exception = mock_exception self._mock_exception = mock_exception
def __call__(self, *args, **kwargs): def __call__(self, *args: Any, **kwargs: Any):
if self._mock_exception: if self._mock_exception:
raise self._mock_exception raise self._mock_exception
self._headers = kwargs["headers"] self._headers: dict[str, str] = kwargs["headers"]
if "Authorization" not in self._headers or not self._headers["Authorization"]: if "Authorization" not in self._headers or not self._headers["Authorization"]:
self._status_code = 401 self._status_code = 401
self.url = kwargs["url"] self.url: str = kwargs["url"]
return self return self
@property @property
@ -47,7 +49,7 @@ class MockGet:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_requests(monkeypatch): def mock_requests(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr("requests.get", MockGet()) monkeypatch.setattr("requests.get", MockGet())
monkeypatch.delattr("requests.post") monkeypatch.delattr("requests.post")
monkeypatch.delattr("requests.put") monkeypatch.delattr("requests.put")
@ -55,14 +57,14 @@ def mock_requests(monkeypatch):
@pytest.fixture @pytest.fixture
def extractor() -> NordigenClient: def extractor() -> Extractor:
credentials = NordigenCredentials("ID", "KEY", "TOKEN") credentials = NordigenCredentials("ID", "KEY", "TOKEN")
return PSD2Extractor(NordigenClient(credentials)) return PSD2Extractor(NordigenClient(credentials))
@pytest.fixture @pytest.fixture
def bank() -> list[Bank]: def bank() -> Bank:
bank = Bank("Bank#1", "", "") bank = Bank("Bank#1", "", AccountType.checking)
bank.nordigen = Nordigen("", "", mock.id, False) bank.nordigen = Nordigen("", "", mock.id, False)
return bank return bank
@ -73,18 +75,20 @@ class TestExtractPSD2:
with pytest.raises(CredentialsError): with pytest.raises(CredentialsError):
NordigenClient(cred) NordigenClient(cred)
def test_no_psd2_bank(self, extractor): def test_no_psd2_bank(self, extractor: Extractor):
with pytest.raises(BankError): with pytest.raises(BankError):
extractor.extract(Bank("", "", "")) extractor.extract(Bank("", "", AccountType.checking))
def test_timeout(self, monkeypatch, extractor, bank): def test_timeout(
self, monkeypatch: pytest.MonkeyPatch, extractor: Extractor, bank: Bank
):
monkeypatch.setattr( monkeypatch.setattr(
"requests.get", MockGet(mock_exception=requests.ReadTimeout) "requests.get", MockGet(mock_exception=requests.ReadTimeout())
) )
with pytest.raises(requests.Timeout): with pytest.raises(requests.Timeout):
extractor.extract(bank) extractor.extract(bank)
def test_extract(self, extractor, bank): def test_extract(self, extractor: Extractor, bank: Bank):
assert extractor.extract(bank) == [ assert extractor.extract(bank) == [
BankTransaction( BankTransaction(
dt.date(2023, 1, 14), "string", Decimal("328.18"), "Bank#1" dt.date(2023, 1, 14), "string", Decimal("328.18"), "Bank#1"

View File

@ -4,7 +4,6 @@ from decimal import Decimal
import mocks.categories as mock import mocks.categories as mock
from pfbudget.db.model import ( from pfbudget.db.model import (
Bank,
BankTransaction, BankTransaction,
CategoryRule, CategoryRule,
CategorySelector, CategorySelector,
@ -102,7 +101,7 @@ class TestTransform:
assert not t.category assert not t.category
categorizer: Transformer = Categorizer(mock.category1.rules) categorizer: Transformer = Categorizer(mock.category1.rules)
transactions: Transformer = categorizer.transform(transactions) transactions = categorizer.transform(transactions)
for t in transactions: for t in transactions:
assert t.category == TransactionCategory( assert t.category == TransactionCategory(