diff --git a/pfbudget/graph.py b/pfbudget/graph.py index b134478..59a2820 100644 --- a/pfbudget/graph.py +++ b/pfbudget/graph.py @@ -12,7 +12,9 @@ if TYPE_CHECKING: from pfbudget.database import DBManager -def monthly(db: DBManager, start: dt.date = dt.date.min, end: dt.date = dt.date.max): +def monthly( + db: DBManager, args: dict, start: dt.date = dt.date.min, end: dt.date = dt.date.max +): transactions = db.get_daterange(start, end) start, end = transactions[0].date, transactions[-1].date monthly_transactions = tuple( @@ -59,10 +61,15 @@ def monthly(db: DBManager, start: dt.date = dt.date.min, end: dt.date = dt.date. ) plt.legend(loc="upper left") plt.tight_layout() - plt.savefig("graph.png") + if args["save"]: + plt.savefig("graph.png") + else: + plt.show() -def discrete(db: DBManager, start: dt.date = dt.date.min, end: dt.date = dt.date.max): +def discrete( + db: DBManager, args: dict, start: dt.date = dt.date.min, end: dt.date = dt.date.max +): transactions = db.get_daterange(start, end) start, end = transactions[0].date, transactions[-1].date monthly_transactions = tuple( @@ -118,4 +125,7 @@ def discrete(db: DBManager, start: dt.date = dt.date.min, end: dt.date = dt.date ) plt.legend(loc="upper left") plt.tight_layout() - plt.savefig("graph.png") + if args["save"]: + plt.savefig("graph.png") + else: + plt.show() diff --git a/pfbudget/runnable.py b/pfbudget/runnable.py index 541a179..5e00372 100644 --- a/pfbudget/runnable.py +++ b/pfbudget/runnable.py @@ -89,6 +89,7 @@ def argparser() -> argparse.ArgumentParser: default="monthly", help="graph option help", ) + p_graph.add_argument("--save", action="store_true") p_graph.set_defaults(func=graph) """ @@ -126,9 +127,9 @@ def graph(args): """ start, end = pfbudget.utils.parse_args_period(args) if args.option == "monthly": - pfbudget.graph.monthly(DBManager(args.database), start, end) + pfbudget.graph.monthly(DBManager(args.database), vars(args), start, end) elif args.option == "discrete": - pfbudget.graph.discrete(DBManager(args.database), start, end) + pfbudget.graph.discrete(DBManager(args.database), vars(args), start, end) def report(args):