diff --git a/ihatemoney/models.py b/ihatemoney/models.py index d072c555..8385ae74 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -250,13 +250,37 @@ class Project(db.Model): def get_bills(self): """Return the list of bills related to this project""" + return self.order_bills(self.get_bills_unordered()) + + @staticmethod + def order_bills(query): return ( - self.get_bills_unordered() - .order_by(Bill.date.desc()) + query.order_by(Bill.date.desc()) .order_by(Bill.creation_date.desc()) .order_by(Bill.id.desc()) ) + def get_bill_weights(self): + """ + Return all bills for this project, along with the sum of weight for each bill. + Each line is a (float, Bill) tuple. + + Result is unordered. + """ + return ( + db.session.query(func.sum(Person.weight), Bill) + .options(orm.subqueryload(Bill.owers)) + .select_from(Person) + .join(billowers, Bill, Project) + .filter(Person.project_id == Project.id) + .filter(Project.id == self.id) + .group_by(Bill.id) + ) + + def get_bill_weights_ordered(self): + """Ordered version of get_bill_weights""" + return self.order_bills(self.get_bill_weights()) + def get_member_bills(self, member_id): """Return the list of bills related to a specific member""" return ( diff --git a/ihatemoney/templates/list_bills.html b/ihatemoney/templates/list_bills.html index dda16f2e..44aa1bb5 100644 --- a/ihatemoney/templates/list_bills.html +++ b/ihatemoney/templates/list_bills.html @@ -1,7 +1,7 @@ {% extends "sidebar_table_layout.html" %} -{%- macro bill_amount(bill, currency=bill.original_currency, amount=bill.amount) %} - {{ amount|currency(currency) }} ({{ _("%(amount)s each", amount=bill.pay_each_default(amount)|currency(currency)) }}) +{%- macro weighted_bill_amount(bill, weights, currency=bill.original_currency, amount=bill.amount) %} + {{ amount|currency(currency) }} ({{ _("%(amount)s each", amount=(amount / weights)|currency(currency)) }}) {% endmacro -%} {% block title %} - {{ g.project.name }}{% endblock %} @@ -109,7 +109,7 @@ {{ _("Actions") }} - {% for bill in bills.items %} + {% for (weights, bill) in bills.items %} - {{ bill_amount(bill) }} + title="{{ weighted_bill_amount(bill, weights, g.project.default_currency, bill.converted_amount) if bill.original_currency != g.project.default_currency else '' }}"> + {{ weighted_bill_amount(bill, weights) }} diff --git a/ihatemoney/tests/main_test.py b/ihatemoney/tests/main_test.py index fac655d3..678f976a 100644 --- a/ihatemoney/tests/main_test.py +++ b/ihatemoney/tests/main_test.py @@ -108,6 +108,62 @@ class CommandTestCase(BaseTestCase): class ModelsTestCase(IhatemoneyTestCase): + def test_weighted_bills(self): + """Test the SQL request that fetch all bills and weights""" + self.post_project("raclette") + + # add members + self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + self.client.post("/raclette/members/add", data={"name": "tata"}) + # Add a member with a balance=0 : + self.client.post("/raclette/members/add", data={"name": "pépé"}) + + # create bills + self.client.post( + "/raclette/add", + data={ + "date": "2011-08-10", + "what": "fromage à raclette", + "payer": 1, + "payed_for": [1, 2, 3], + "amount": "10.0", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2011-08-10", + "what": "red wine", + "payer": 2, + "payed_for": [1], + "amount": "20", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2011-08-10", + "what": "delicatessen", + "payer": 1, + "payed_for": [1, 2], + "amount": "10", + }, + ) + project = models.Project.query.get_by_name(name="raclette") + for (weight, bill) in project.get_bill_weights().all(): + if bill.what == "red wine": + pay_each_expected = 20 / 2 + self.assertEqual(bill.amount / weight, pay_each_expected) + if bill.what == "fromage à raclette": + pay_each_expected = 10 / 4 + self.assertEqual(bill.amount / weight, pay_each_expected) + if bill.what == "delicatessen": + pay_each_expected = 10 / 3 + self.assertEqual(bill.amount / weight, pay_each_expected) + def test_bill_pay_each(self): self.post_project("raclette") diff --git a/ihatemoney/web.py b/ihatemoney/web.py index 43276041..bffbfc4e 100644 --- a/ihatemoney/web.py +++ b/ihatemoney/web.py @@ -28,7 +28,6 @@ from flask import ( ) from flask_babel import gettext as _ from flask_mail import Message -from sqlalchemy import orm from sqlalchemy_continuum import Operation from werkzeug.exceptions import NotFound from werkzeug.security import check_password_hash, generate_password_hash @@ -609,16 +608,17 @@ def list_bills(): # set the last selected payer as default choice if exists if "last_selected_payer" in session: bill_form.payer.data = session["last_selected_payer"] - # Preload the "owers" relationship for all bills - bills = ( - g.project.get_bills() - .options(orm.subqueryload(Bill.owers)) - .paginate(per_page=100, error_out=True) + + # Each item will be a (weight_sum, Bill) tuple. + # TODO: improve this awkward result using column_property: + # https://docs.sqlalchemy.org/en/14/orm/mapped_sql_expr.html. + weighted_bills = g.project.get_bill_weights_ordered().paginate( + per_page=100, error_out=True ) return render_template( "list_bills.html", - bills=bills, + bills=weighted_bills, member_form=MemberForm(g.project), bill_form=bill_form, csrf_form=csrf_form,