diff --git a/ihatemoney/api/common.py b/ihatemoney/api/common.py index cd247cdf..ede76e46 100644 --- a/ihatemoney/api/common.py +++ b/ihatemoney/api/common.py @@ -69,7 +69,7 @@ class ProjectHandler(Resource): return "DELETED" def put(self, project): - form = EditProjectForm(meta={"csrf": False}) + form = EditProjectForm(id=project.id, meta={"csrf": False}) if form.validate() and current_app.config.get("ALLOW_PUBLIC_PROJECT_CREATION"): form.update(project) db.session.commit() diff --git a/ihatemoney/forms.py b/ihatemoney/forms.py index 9f8eefa3..866e5e2e 100644 --- a/ihatemoney/forms.py +++ b/ihatemoney/forms.py @@ -1,5 +1,6 @@ from datetime import datetime from re import match +from types import SimpleNamespace import email_validator from flask import request @@ -110,6 +111,14 @@ class EditProjectForm(FlaskForm): default_currency = SelectField(_("Default Currency"), validators=[DataRequired()]) def __init__(self, *args, **kwargs): + if not hasattr(self, "id"): + # We must access the project to validate the default currency, using its id. + # In ProjectForm, 'id' is provided, but not in this base class, so it *must* + # be provided by callers. + # Since id can be defined as a WTForms.StringField, we mimics it, + # using an object that can have a 'data' attribute. + # It defaults to empty string to ensure that query run smoothly. + self.id = SimpleNamespace(data=kwargs.pop("id", "")) super().__init__(*args, **kwargs) self.default_currency.choices = [ (currency_name, render_localized_currency(currency_name, detailed=True)) @@ -142,6 +151,22 @@ class EditProjectForm(FlaskForm): ) return project + def validate_default_currency(form, field): + project = Project.query.get(form.id.data) + if ( + project is not None + and field.data == CurrencyConverter.no_currency + and project.has_multiple_currencies() + ): + raise ValidationError( + _( + ( + "This project cannot be set to 'no currency'" + " because it contains bills in multiple currencies." + ) + ) + ) + def update(self, project): """Update the project with the information from the form""" project.name = self.name.data @@ -152,7 +177,7 @@ class EditProjectForm(FlaskForm): project.contact_email = self.contact_email.data project.logging_preference = self.logging_preference - project.default_currency = self.default_currency.data + project.switch_currency(self.default_currency.data) return project diff --git a/ihatemoney/models.py b/ihatemoney/models.py index 7984ab76..0ae2ef72 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -17,6 +17,7 @@ from sqlalchemy_continuum import make_versioned, version_class from sqlalchemy_continuum.plugins import FlaskPlugin from werkzeug.security import generate_password_hash +from ihatemoney.currency_convertor import CurrencyConverter from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder from ihatemoney.versioning import ( ConditionalVersioningManager, @@ -139,7 +140,7 @@ class Project(db.Model): "spent": sum( [ bill.pay_each() * member.weight - for bill in self.get_bills().all() + for bill in self.get_bills_unordered().all() if member in bill.owers ] ), @@ -156,7 +157,7 @@ class Project(db.Model): :rtype dict: """ monthly = defaultdict(lambda: defaultdict(float)) - for bill in self.get_bills().all(): + for bill in self.get_bills_unordered().all(): monthly[bill.date.year][bill.date.month] += bill.converted_amount return monthly @@ -215,15 +216,25 @@ class Project(db.Model): def has_bills(self): """return if the project do have bills or not""" - return self.get_bills().count() > 0 + return self.get_bills_unordered().count() > 0 - def get_bills(self): - """Return the list of bills related to this project""" + def has_multiple_currencies(self): + """Return if multiple currencies are used""" + return self.get_bills_unordered().group_by(Bill.original_currency).count() > 1 + + def get_bills_unordered(self): + """Base query for bill list""" return ( Bill.query.join(Person, Project) .filter(Bill.payer_id == Person.id) .filter(Person.project_id == Project.id) .filter(Project.id == self.id) + ) + + def get_bills(self): + """Return the list of bills related to this project""" + return ( + self.get_bills_unordered() .order_by(Bill.date.desc()) .order_by(Bill.creation_date.desc()) .order_by(Bill.id.desc()) @@ -232,11 +243,8 @@ class Project(db.Model): def get_member_bills(self, member_id): """Return the list of bills related to a specific member""" return ( - Bill.query.join(Person, Project) - .filter(Bill.payer_id == Person.id) - .filter(Person.project_id == Project.id) + self.get_bills_unordered() .filter(Person.id == member_id) - .filter(Project.id == self.id) .order_by(Bill.date.desc()) .order_by(Bill.id.desc()) ) @@ -263,6 +271,41 @@ class Project(db.Model): ) return pretty_bills + def switch_currency(self, new_currency): + if new_currency == self.default_currency: + return + # Update converted currency + if new_currency == CurrencyConverter.no_currency: + if self.has_multiple_currencies(): + raise ValueError(f"Can't unset currency of project {self.id}") + + for bill in self.get_bills_unordered(): + # We are removing the currency, and we already checked that all bills + # had the same currency: it means that we can simply strip the currency + # without converting the amounts. We basically ignore the current default_currency + + # Reset converted amount in case it was different from the original amount + bill.converted_amount = bill.amount + # Strip currency + bill.original_currency = CurrencyConverter.no_currency + db.session.add(bill) + else: + for bill in self.get_bills_unordered(): + if bill.original_currency == CurrencyConverter.no_currency: + # Bills that were created without currency will be set to the new currency + bill.original_currency = new_currency + bill.converted_amount = bill.amount + else: + # Convert amount for others, without touching original_currency + bill.converted_amount = CurrencyConverter().exchange_currency( + bill.amount, bill.original_currency, new_currency + ) + db.session.add(bill) + + self.default_currency = new_currency + db.session.add(self) + db.session.commit() + def remove_member(self, member_id): """Remove a member from the project. diff --git a/ihatemoney/templates/dashboard.html b/ihatemoney/templates/dashboard.html index d9c150c4..3e26441a 100644 --- a/ihatemoney/templates/dashboard.html +++ b/ihatemoney/templates/dashboard.html @@ -5,7 +5,7 @@ {{ _("Project") }}{{ _("Number of members") }}{{ _("Number of bills") }}{{_("Newest bill")}}{{_("Oldest bill")}}{{_("Actions")}} {% for project in projects|sort(attribute='name') %} - {{ project.name }}{{ project.members | count }}{{ project.get_bills().count() }} + {{ project.name }}{{ project.members | count }}{{ project.get_bills_unordered().count() }} {% if project.has_bills() %} {{ project.get_bills().all()[0].date }} {{ project.get_bills().all()[-1].date }} diff --git a/ihatemoney/tests/budget_test.py b/ihatemoney/tests/budget_test.py index e34214e9..e3b6778f 100644 --- a/ihatemoney/tests/budget_test.py +++ b/ihatemoney/tests/budget_test.py @@ -4,11 +4,14 @@ import json import re from time import sleep import unittest +from unittest.mock import MagicMock from flask import session +import pytest from werkzeug.security import check_password_hash, generate_password_hash from ihatemoney import models +from ihatemoney.currency_convertor import CurrencyConverter from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase from ihatemoney.versioning import LoggingMode @@ -802,7 +805,8 @@ class BudgetTestCase(IhatemoneyTestCase): self.assertEqual(response.status_code, 200) def test_statistics(self): - self.post_project("raclette") + # Output is checked with the USD sign + self.post_project("raclette", default_currency="USD") # add members self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) @@ -1443,6 +1447,225 @@ class BudgetTestCase(IhatemoneyTestCase): member = models.Person.query.filter(models.Person.id == 1).one_or_none() self.assertEqual(member, None) + def test_currency_switch(self): + + mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1} + converter = CurrencyConverter() + converter.get_rates = MagicMock(return_value=mock_data) + + # A project should be editable + self.post_project("raclette") + + # add members + self.client.post("/raclette/members/add", data={"name": "zorglub"}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + self.client.post("/raclette/members/add", data={"name": "tata"}) + + # create bills + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "fromage à raclette", + "payer": 1, + "payed_for": [1, 2, 3], + "amount": "10.0", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "red wine", + "payer": 2, + "payed_for": [1, 3], + "amount": "20", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2017-01-01", + "what": "refund", + "payer": 3, + "payed_for": [2], + "amount": "13.33", + }, + ) + + project = models.Project.query.get("raclette") + + # First all converted_amount should be the same as amount, with no currency + for bill in project.get_bills(): + assert bill.original_currency == CurrencyConverter.no_currency + assert bill.amount == bill.converted_amount + + # Then, switch to EUR, all bills must have been changed to this currency + project.switch_currency("EUR") + for bill in project.get_bills(): + assert bill.original_currency == "EUR" + assert bill.amount == bill.converted_amount + + # Add a bill in EUR, the current default currency + self.client.post( + "/raclette/add", + data={ + "date": "2017-01-01", + "what": "refund from EUR", + "payer": 3, + "payed_for": [2], + "amount": "20", + "original_currency": "EUR", + }, + ) + last_bill = project.get_bills().first() + assert last_bill.converted_amount == last_bill.amount + + # Erase all currencies + project.switch_currency(CurrencyConverter.no_currency) + for bill in project.get_bills(): + assert bill.original_currency == CurrencyConverter.no_currency + assert bill.amount == bill.converted_amount + + # Let's go back to EUR to test conversion + project.switch_currency("EUR") + # This is a bill in CAD + self.client.post( + "/raclette/add", + data={ + "date": "2017-01-01", + "what": "Poutine", + "payer": 3, + "payed_for": [2], + "amount": "18", + "original_currency": "CAD", + }, + ) + last_bill = project.get_bills().first() + expected_amount = converter.exchange_currency(last_bill.amount, "CAD", "EUR") + assert last_bill.converted_amount == expected_amount + + # Switch to USD. Now, NO bill should be in USD, since they already had a currency + project.switch_currency("USD") + for bill in project.get_bills(): + assert bill.original_currency != "USD" + expected_amount = converter.exchange_currency( + bill.amount, bill.original_currency, "USD" + ) + assert bill.converted_amount == expected_amount + + # Switching back to no currency must fail + with pytest.raises(ValueError): + project.switch_currency(CurrencyConverter.no_currency) + + # It also must fails with a nice error using the form + resp = self.client.post( + "/raclette/edit", + data={ + "name": "demonstration", + "password": "demo", + "contact_email": "demo@notmyidea.org", + "project_history": "y", + "default_currency": converter.no_currency, + }, + ) + # A user displayed error should be generated, and its currency should be the same. + self.assertStatus(200, resp) + self.assertIn('

', resp.data.decode("utf-8")) + self.assertEqual(models.Project.query.get("raclette").default_currency, "USD") + + def test_currency_switch_to_bill_currency(self): + + mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1} + converter = CurrencyConverter() + converter.get_rates = MagicMock(return_value=mock_data) + + # Default currency is 'XXX', but we should start from a project with a currency + self.post_project("raclette", default_currency="USD") + + # add members + self.client.post("/raclette/members/add", data={"name": "zorglub"}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + + # Bill with a different currency than project's default + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "fromage à raclette", + "payer": 1, + "payed_for": [1, 2], + "amount": "10.0", + "original_currency": "EUR", + }, + ) + + project = models.Project.query.get("raclette") + + bill = project.get_bills().first() + assert bill.converted_amount == converter.exchange_currency( + bill.amount, "EUR", "USD" + ) + + # And switch project to the currency from the bill we created + project.switch_currency("EUR") + bill = project.get_bills().first() + assert bill.converted_amount == bill.amount + + def test_currency_switch_to_no_currency(self): + + mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1} + converter = CurrencyConverter() + converter.get_rates = MagicMock(return_value=mock_data) + + # Default currency is 'XXX', but we should start from a project with a currency + self.post_project("raclette", default_currency="USD") + + # add members + self.client.post("/raclette/members/add", data={"name": "zorglub"}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + + # Bills with a different currency than project's default + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "fromage à raclette", + "payer": 1, + "payed_for": [1, 2], + "amount": "10.0", + "original_currency": "EUR", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2017-01-01", + "what": "aspirine", + "payer": 2, + "payed_for": [1, 2], + "amount": "5.0", + "original_currency": "EUR", + }, + ) + + project = models.Project.query.get("raclette") + + for bill in project.get_bills_unordered(): + assert bill.converted_amount == converter.exchange_currency( + bill.amount, "EUR", "USD" + ) + + # And switch project to no currency: amount should be equal to what was submitted + project.switch_currency(converter.no_currency) + no_currency_bills = [ + (bill.amount, bill.converted_amount) for bill in project.get_bills() + ] + assert no_currency_bills == [(5.0, 5.0), (10.0, 10.0)] + if __name__ == "__main__": unittest.main() diff --git a/ihatemoney/tests/common/ihatemoney_testcase.py b/ihatemoney/tests/common/ihatemoney_testcase.py index 2e590590..39b2919e 100644 --- a/ihatemoney/tests/common/ihatemoney_testcase.py +++ b/ihatemoney/tests/common/ihatemoney_testcase.py @@ -30,7 +30,7 @@ class BaseTestCase(TestCase): follow_redirects=True, ) - def post_project(self, name, follow_redirects=True): + def post_project(self, name, follow_redirects=True, default_currency="XXX"): """Create a fake project""" # create the project return self.client.post( @@ -40,18 +40,18 @@ class BaseTestCase(TestCase): "id": name, "password": name, "contact_email": f"{name}@notmyidea.org", - "default_currency": "USD", + "default_currency": default_currency, }, follow_redirects=follow_redirects, ) - def create_project(self, name): + def create_project(self, name, default_currency="XXX"): project = models.Project( id=name, name=str(name), password=generate_password_hash(name), contact_email=f"{name}@notmyidea.org", - default_currency="USD", + default_currency=default_currency, ) models.db.session.add(project) models.db.session.commit() diff --git a/ihatemoney/tests/history_test.py b/ihatemoney/tests/history_test.py index 3ffdbcc6..cee3dca6 100644 --- a/ihatemoney/tests/history_test.py +++ b/ihatemoney/tests/history_test.py @@ -25,7 +25,7 @@ class HistoryTestCase(IhatemoneyTestCase): "name": "demo", "contact_email": "demo@notmyidea.org", "password": "demo", - "default_currency": "USD", + "default_currency": "XXX", } if logging_preference != LoggingMode.DISABLED: @@ -78,7 +78,7 @@ class HistoryTestCase(IhatemoneyTestCase): "contact_email": "demo2@notmyidea.org", "password": "123456", "project_history": "y", - "default_currency": "USD", + "default_currency": "USD", # Currency changed from default } resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True) @@ -103,7 +103,7 @@ class HistoryTestCase(IhatemoneyTestCase): resp.data.decode("utf-8").index("Project renamed "), resp.data.decode("utf-8").index("Project private code changed"), ) - self.assertEqual(resp.data.decode("utf-8").count(" -- "), 4) + self.assertEqual(resp.data.decode("utf-8").count(" -- "), 5) self.assertNotIn("127.0.0.1", resp.data.decode("utf-8")) def test_project_privacy_edit(self): @@ -284,7 +284,7 @@ class HistoryTestCase(IhatemoneyTestCase): self.assertIn( "Some entries below contain IP addresses,", resp.data.decode("utf-8") ) - self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 10) + self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12) self.assertEqual(resp.data.decode("utf-8").count(" -- "), 1) # Generate more operations to confirm additional IP info isn't recorded @@ -292,8 +292,8 @@ class HistoryTestCase(IhatemoneyTestCase): resp = self.client.get("/demo/history") self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 10) - self.assertEqual(resp.data.decode("utf-8").count(" -- "), 6) + self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12) + self.assertEqual(resp.data.decode("utf-8").count(" -- "), 7) # Clear IP Data resp = self.client.post("/demo/strip_ip_addresses", follow_redirects=True) @@ -311,7 +311,7 @@ class HistoryTestCase(IhatemoneyTestCase): "Some entries below contain IP addresses,", resp.data.decode("utf-8") ) self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 0) - self.assertEqual(resp.data.decode("utf-8").count(" -- "), 16) + self.assertEqual(resp.data.decode("utf-8").count(" -- "), 19) def test_logs_for_common_actions(self): # adds a member to this project diff --git a/ihatemoney/tests/main_test.py b/ihatemoney/tests/main_test.py index 2e7742f9..2aa3cf29 100644 --- a/ihatemoney/tests/main_test.py +++ b/ihatemoney/tests/main_test.py @@ -240,7 +240,7 @@ class EmailFailureTestCase(IhatemoneyTestCase): class TestCurrencyConverter(unittest.TestCase): converter = CurrencyConverter() - mock_data = {"USD": 1, "EUR": 0.8115} + mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1} converter.get_rates = MagicMock(return_value=mock_data) def test_only_one_instance(self): @@ -249,11 +249,14 @@ class TestCurrencyConverter(unittest.TestCase): self.assertEqual(one, two) def test_get_currencies(self): - self.assertCountEqual(self.converter.get_currencies(), ["USD", "EUR"]) + self.assertCountEqual( + self.converter.get_currencies(), + ["USD", "EUR", "CAD", CurrencyConverter.no_currency], + ) def test_exchange_currency(self): result = self.converter.exchange_currency(100, "USD", "EUR") - self.assertEqual(result, 81.15) + self.assertEqual(result, 80.0) if __name__ == "__main__": diff --git a/ihatemoney/web.py b/ihatemoney/web.py index baacff47..41940b30 100644 --- a/ihatemoney/web.py +++ b/ihatemoney/web.py @@ -36,7 +36,6 @@ from sqlalchemy_continuum import Operation from werkzeug.exceptions import NotFound from werkzeug.security import check_password_hash, generate_password_hash -from ihatemoney.currency_convertor import CurrencyConverter from ihatemoney.forms import ( AdminAuthenticationForm, AuthenticationForm, @@ -400,7 +399,7 @@ def reset_password(): @main.route("//edit", methods=["GET", "POST"]) def edit_project(): - edit_form = EditProjectForm() + edit_form = EditProjectForm(id=g.project.id) import_form = UploadForm() # Import form if import_form.validate_on_submit(): @@ -415,17 +414,6 @@ def edit_project(): # Edit form if edit_form.validate_on_submit(): project = edit_form.update(g.project) - # Update converted currency - if project.default_currency != CurrencyConverter.no_currency: - for bill in project.get_bills(): - - if bill.original_currency == CurrencyConverter.no_currency: - bill.original_currency = project.default_currency - - bill.converted_amount = CurrencyConverter().exchange_currency( - bill.amount, bill.original_currency, project.default_currency - ) - db.session.add(bill) db.session.add(project) db.session.commit()