diff --git a/ihatemoney/tests/budget_test.py b/ihatemoney/tests/budget_test.py index b26e9572..1537233f 100644 --- a/ihatemoney/tests/budget_test.py +++ b/ihatemoney/tests/budget_test.py @@ -2100,3 +2100,30 @@ class TestBudget(IhatemoneyTestCase): session["last_selected_payer_per_project"]["tartiflette"] == members_ids_tartif[2] ) + + def test_remember_payed_for(self): + """ + Tests that the last ower is remembered + """ + self.post_project("raclette") + self.client.post("/raclette/members/add", data={"name": "zorglub"}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + self.client.post("/raclette/members/add", data={"name": "pipistrelle"}) + members_ids = [m.id for m in self.get_project("raclette").members] + # create a bill + self.client.post( + "/raclette/add", + data={ + "date": "2011-08-10", + "what": "fromage à raclette", + "payer": members_ids[1], + "payed_for": members_ids[1:], + "amount": "25", + }, + ) + + with self.client as c: + c.post("/authenticate", data={"id": "raclette", "password": "raclette"}) + assert isinstance(session["last_selected_payed_for"], dict) + assert "raclette" in session["last_selected_payed_for"] + assert session["last_selected_payed_for"]["raclette"] == members_ids[1:] diff --git a/ihatemoney/web.py b/ihatemoney/web.py index 2afcb312..bbb19d3d 100644 --- a/ihatemoney/web.py +++ b/ihatemoney/web.py @@ -645,7 +645,7 @@ def list_bills(): bill_form = get_billform_for(g.project) # Used for CSRF validation csrf_form = EmptyForm() - # set the last selected payer as default choice if exists + # set the last selected payer and last selected owers as default choice if they exist if "last_selected_payer_per_project" in session: if g.project.id in session["last_selected_payer_per_project"]: bill_form.payer.data = session["last_selected_payer_per_project"][ @@ -655,6 +655,11 @@ def list_bills(): else: if "last_selected_payer" in session: bill_form.payer.data = session["last_selected_payer"] + if ( + "last_selected_payed_for" in session + and g.project.id in session["last_selected_payed_for"] + ): + bill_form.payed_for.data = session["last_selected_payed_for"][g.project.id] # Each item will be a (weight_sum, Bill) tuple. # TODO: improve this awkward result using column_property: @@ -758,10 +763,13 @@ def add_bill(): form = get_billform_for(g.project) if request.method == "POST": if form.validate(): - # save last selected payer in session + # save last selected payer and last selected owers in session if "last_selected_payer_per_project" not in session: session["last_selected_payer_per_project"] = {} session["last_selected_payer_per_project"][g.project.id] = form.payer.data + if "last_selected_payed_for" not in session: + session["last_selected_payed_for"] = {} + session["last_selected_payed_for"][g.project.id] = form.payed_for.data session.update() db.session.add(form.export(g.project))