refactoring

This commit is contained in:
Youe Graillot 2021-11-28 15:58:06 +01:00
parent 12f59f8288
commit ef3b9aac71
4 changed files with 125 additions and 106 deletions

View file

@ -151,8 +151,7 @@ class BillsHandler(Resource):
def post(self, project): def post(self, project):
form = get_billform_for(project, True, meta={"csrf": False}) form = get_billform_for(project, True, meta={"csrf": False})
if form.validate(): if form.validate():
bill = Bill() bill = form.export(project)
form.save(bill, project)
db.session.add(bill) db.session.add(bill)
db.session.commit() db.session.commit()
return bill.id, 201 return bill.id, 201

View file

@ -23,7 +23,7 @@ from wtforms.validators import (
) )
from ihatemoney.currency_convertor import CurrencyConverter from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.models import LoggingMode, Person, Project from ihatemoney.models import Bill, LoggingMode, Person, Project
from ihatemoney.utils import ( from ihatemoney.utils import (
eval_arithmetic_expression, eval_arithmetic_expression,
render_localized_currency, render_localized_currency,
@ -321,33 +321,31 @@ class BillForm(FlaskForm):
submit = SubmitField(_("Submit")) submit = SubmitField(_("Submit"))
submit2 = SubmitField(_("Submit and add a new one")) submit2 = SubmitField(_("Submit and add a new one"))
def export(self, project):
return Bill(
self.amount.data,
self.date.data,
self.external_link.data,
self.original_currency.data,
Person.query.get_by_ids(project, self.payed_for.data),
self.payer.data,
project.default_currency,
self.what.data,
)
def save(self, bill, project): def save(self, bill, project):
bill.payer_id = self.payer.data bill.payer_id = self.payer.data
bill.amount = self.amount.data bill.amount = self.amount.data
bill.what = self.what.data bill.what = self.what.data
bill.external_link = self.external_link.data bill.external_link = self.external_link.data
bill.date = self.date.data bill.date = self.date.data
bill.owers = [Person.query.get(ower, project) for ower in self.payed_for.data] bill.owers = Person.query.get_by_ids(project, self.payed_for.data)
bill.original_currency = self.original_currency.data bill.original_currency = self.original_currency.data
bill.converted_amount = self.currency_helper.exchange_currency( bill.converted_amount = self.currency_helper.exchange_currency(
bill.amount, bill.original_currency, project.default_currency bill.amount, bill.original_currency, project.default_currency
) )
return bill return bill
def fake_form(self, bill, project):
bill.payer_id = self.payer
bill.amount = self.amount
bill.what = self.what
bill.external_link = ""
bill.date = self.date
bill.owers = [Person.query.get(ower, project) for ower in self.payed_for]
bill.original_currency = self.original_currency
bill.converted_amount = self.currency_helper.exchange_currency(
bill.amount, bill.original_currency, project.default_currency
)
return bill
def fill(self, bill, project): def fill(self, bill, project):
self.payer.data = bill.payer_id self.payer.data = bill.payer_id
self.amount.data = bill.amount self.amount.data = bill.amount

View file

@ -1,6 +1,7 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from dateutil.parser import parse
from debts import settle from debts import settle
from flask import current_app, g from flask import current_app, g
from flask_sqlalchemy import BaseQuery, SQLAlchemy from flask_sqlalchemy import BaseQuery, SQLAlchemy
@ -19,6 +20,7 @@ from werkzeug.security import generate_password_hash
from ihatemoney.currency_convertor import CurrencyConverter from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder
from ihatemoney.utils import get_members, same_bill
from ihatemoney.versioning import ( from ihatemoney.versioning import (
ConditionalVersioningManager, ConditionalVersioningManager,
LoggingMode, LoggingMode,
@ -320,6 +322,42 @@ class Project(db.Model):
db.session.add(self) db.session.add(self)
db.session.commit() db.session.commit()
def import_bills(self, bills: list[dict]):
"""Import bills from a list of dictionaries"""
# Add members not already in the project
members_project = [str(m) for m in self.members]
members_new = [
m for m in get_members(bills) if str(m[0]) not in members_project
]
for m in members_new:
Person(name=m[0], project=self, weight=m[1])
db.session.commit()
# Import bills not already in the project
bills_project = self.get_pretty_bills()
id_dict = {m.name: m.id for m in self.members}
for b in bills:
same = False
for b_p in bills_project:
if same_bill(b_p, b):
same = True
break
if not same:
# Create bills
db.session.add(
Bill(
b["amount"],
parse(b["date"]),
"",
b["currency"],
Person.query.get_by_names(b["owers"], self),
id_dict[b["payer_name"]],
self.default_currency,
b["what"],
)
)
db.session.commit()
def remove_member(self, member_id): def remove_member(self, member_id):
"""Remove a member from the project. """Remove a member from the project.
@ -435,16 +473,19 @@ class Project(db.Model):
("Alice", 20, ("Amina", "Alice"), "Beer !"), ("Alice", 20, ("Amina", "Alice"), "Beer !"),
("Amina", 50, ("Amina", "Alice", "Georg"), "AMAP"), ("Amina", 50, ("Amina", "Alice", "Georg"), "AMAP"),
) )
for (payer, amount, owers, subject) in operations: for (payer, amount, owers, what) in operations:
bill = Bill() db.session.add(
bill.payer_id = members[payer].id Bill(
bill.what = subject amount,
bill.owers = [members[name] for name in owers] None,
bill.amount = amount None,
bill.original_currency = "XXX" "XXX",
bill.converted_amount = amount [members[name] for name in owers],
members[payer].id,
db.session.add(bill) project.default_currency,
what,
)
)
db.session.commit() db.session.commit()
return project return project
@ -459,6 +500,13 @@ class Person(db.Model):
.one_or_none() .one_or_none()
) )
def get_by_names(self, names, project):
return (
Person.query.filter(Person.name.in_(names))
.filter(Person.project_id == project.id)
.all()
)
def get(self, id, project=None): def get(self, id, project=None):
if not project: if not project:
project = g.project project = g.project
@ -468,6 +516,15 @@ class Person(db.Model):
.one_or_none() .one_or_none()
) )
def get_by_ids(self, ids, project=None):
if not project:
project = g.project
return (
Person.query.filter(Person.id.in_(ids))
.filter(Person.project_id == project.id)
.all()
)
query_class = PersonQuery query_class = PersonQuery
# Direct SQLAlchemy-Continuum to track changes to this model # Direct SQLAlchemy-Continuum to track changes to this model
@ -561,6 +618,31 @@ class Bill(db.Model):
archive = db.Column(db.Integer, db.ForeignKey("archive.id")) archive = db.Column(db.Integer, db.ForeignKey("archive.id"))
currency_helper = CurrencyConverter()
def __init__(
self,
amount,
date,
external_link,
original_currency,
owers,
payer_id,
project_default_currency,
what,
) -> None:
super().__init__()
self.amount = amount
self.date = date
self.external_link = external_link
self.original_currency = original_currency
self.owers = owers
self.payer_id = payer_id
self.what = what
self.converted_amount = self.currency_helper.exchange_currency(
self.amount, self.original_currency, project_default_currency
)
@property @property
def _to_serialize(self): def _to_serialize(self):
return { return {

View file

@ -60,11 +60,9 @@ from ihatemoney.utils import (
Redirect303, Redirect303,
csv2list_of_dicts, csv2list_of_dicts,
format_form_errors, format_form_errors,
get_members,
list_of_dicts2csv, list_of_dicts2csv,
list_of_dicts2json, list_of_dicts2json,
render_localized_template, render_localized_template,
same_bill,
send_email, send_email,
) )
@ -451,37 +449,33 @@ def import_project():
try: try:
data = form.file.data data = form.file.data
if data.mimetype == "application/json": if data.mimetype == "application/json":
json_file = json.load(data.stream) bills = json.load(data.stream)
elif data.mimetype == "text/csv": elif data.mimetype == "text/csv":
try: try:
json_file = csv2list_of_dicts(data) bills = csv2list_of_dicts(data)
except Exception as e: except Exception as b:
raise ValueError(_("Unable to parse CSV")) raise ValueError(_("Unable to parse CSV"))
else: else:
raise ValueError("Unsupported file type") raise ValueError("Unsupported file type")
# Check if JSON is correct # Check data
attr = [ attr = [
"what",
"payer_name",
"payer_weight",
"amount", "amount",
"currency", "currency",
"date", "date",
"owers", "owers",
"payer_name",
"payer_weight",
"what",
] ]
attr.sort()
currencies = set() currencies = set()
for e in json_file: for b in bills:
# If currency is absent, empty, or explicitly set to XXX if b.get("currency", "") in ["", "XXX"]:
# set it to project default. b["currency"] = g.project.default_currency
if e.get("currency", "") in ["", "XXX"]:
e["currency"] = g.project.default_currency
for a in attr: for a in attr:
if a not in e: if a not in b:
raise ValueError(_("Missing attribute {}").format(a)) raise ValueError(_("Missing attribute {}").format(a))
# Keep track of currencies currencies.add(b["currency"])
currencies.add(e["currency"])
# Additional checks if project has no default currency # Additional checks if project has no default currency
if g.project.default_currency == CurrencyConverter.no_currency: if g.project.default_currency == CurrencyConverter.no_currency:
@ -493,68 +487,15 @@ def import_project():
) )
) )
# Strip currency from bills (since it's the same for every bill) # Strip currency from bills (since it's the same for every bill)
for e in json_file: for b in bills:
e["currency"] = CurrencyConverter.no_currency b["currency"] = CurrencyConverter.no_currency
# From json : export list of members g.project.import_bills(bills)
members_json = get_members(json_file)
members = g.project.members
members_already_here = list()
for m in members:
members_already_here.append(str(m))
# List all members not in the project and weight associated
# List of tuples (name,weight)
members_to_add = list()
for i in members_json:
if str(i[0]) not in members_already_here:
members_to_add.append(i)
# List bills not in the project
# Same format than JSON element
project_bills = g.project.get_pretty_bills()
bill_to_add = list()
for j in json_file:
same = False
for p in project_bills:
if same_bill(p, j):
same = True
break
if not same:
bill_to_add.append(j)
# Add users to DB
for m in members_to_add:
Person(name=m[0], project=g.project, weight=m[1])
db.session.commit()
id_dict = {}
for i in g.project.members:
id_dict[i.name] = i.id
# Create bills
for b in bill_to_add:
owers_id = list()
for ower in b["owers"]:
owers_id.append(id_dict[ower])
bill = Bill()
form = get_billform_for(g.project)
form.what = b["what"]
form.amount = b["amount"]
form.original_currency = b["currency"]
form.date = parse(b["date"])
form.payer = id_dict[b["payer_name"]]
form.payed_for = owers_id
db.session.add(form.fake_form(bill, g.project))
# Add bills to DB
db.session.commit()
flash(_("Project successfully uploaded")) flash(_("Project successfully uploaded"))
return redirect(url_for("main.list_bills")) return redirect(url_for("main.list_bills"))
except ValueError as e: except ValueError as b:
flash(e.args[0], category="danger") flash(b.args[0], category="danger")
return redirect(url_for(".edit_project")) return redirect(url_for(".edit_project"))
else: else:
for component, errors in form.errors.items(): for component, errors in form.errors.items():
@ -777,8 +718,7 @@ def add_bill():
session["last_selected_payer"] = form.payer.data session["last_selected_payer"] = form.payer.data
session.update() session.update()
bill = Bill() db.session.add(form.export(g.project))
db.session.add(form.save(bill, g.project))
db.session.commit() db.session.commit()
flash(_("The bill has been added")) flash(_("The bill has been added"))