diff --git a/ihatemoney/api/common.py b/ihatemoney/api/common.py
index cd247cdf..ede76e46 100644
--- a/ihatemoney/api/common.py
+++ b/ihatemoney/api/common.py
@@ -69,7 +69,7 @@ class ProjectHandler(Resource):
return "DELETED"
def put(self, project):
- form = EditProjectForm(meta={"csrf": False})
+ form = EditProjectForm(id=project.id, meta={"csrf": False})
if form.validate() and current_app.config.get("ALLOW_PUBLIC_PROJECT_CREATION"):
form.update(project)
db.session.commit()
diff --git a/ihatemoney/forms.py b/ihatemoney/forms.py
index 9f8eefa3..866e5e2e 100644
--- a/ihatemoney/forms.py
+++ b/ihatemoney/forms.py
@@ -1,5 +1,6 @@
from datetime import datetime
from re import match
+from types import SimpleNamespace
import email_validator
from flask import request
@@ -110,6 +111,14 @@ class EditProjectForm(FlaskForm):
default_currency = SelectField(_("Default Currency"), validators=[DataRequired()])
def __init__(self, *args, **kwargs):
+ if not hasattr(self, "id"):
+ # We must access the project to validate the default currency, using its id.
+ # In ProjectForm, 'id' is provided, but not in this base class, so it *must*
+ # be provided by callers.
+ # Since id can be defined as a WTForms.StringField, we mimics it,
+ # using an object that can have a 'data' attribute.
+ # It defaults to empty string to ensure that query run smoothly.
+ self.id = SimpleNamespace(data=kwargs.pop("id", ""))
super().__init__(*args, **kwargs)
self.default_currency.choices = [
(currency_name, render_localized_currency(currency_name, detailed=True))
@@ -142,6 +151,22 @@ class EditProjectForm(FlaskForm):
)
return project
+ def validate_default_currency(form, field):
+ project = Project.query.get(form.id.data)
+ if (
+ project is not None
+ and field.data == CurrencyConverter.no_currency
+ and project.has_multiple_currencies()
+ ):
+ raise ValidationError(
+ _(
+ (
+ "This project cannot be set to 'no currency'"
+ " because it contains bills in multiple currencies."
+ )
+ )
+ )
+
def update(self, project):
"""Update the project with the information from the form"""
project.name = self.name.data
@@ -152,7 +177,7 @@ class EditProjectForm(FlaskForm):
project.contact_email = self.contact_email.data
project.logging_preference = self.logging_preference
- project.default_currency = self.default_currency.data
+ project.switch_currency(self.default_currency.data)
return project
diff --git a/ihatemoney/models.py b/ihatemoney/models.py
index 7984ab76..0ae2ef72 100644
--- a/ihatemoney/models.py
+++ b/ihatemoney/models.py
@@ -17,6 +17,7 @@ from sqlalchemy_continuum import make_versioned, version_class
from sqlalchemy_continuum.plugins import FlaskPlugin
from werkzeug.security import generate_password_hash
+from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder
from ihatemoney.versioning import (
ConditionalVersioningManager,
@@ -139,7 +140,7 @@ class Project(db.Model):
"spent": sum(
[
bill.pay_each() * member.weight
- for bill in self.get_bills().all()
+ for bill in self.get_bills_unordered().all()
if member in bill.owers
]
),
@@ -156,7 +157,7 @@ class Project(db.Model):
:rtype dict:
"""
monthly = defaultdict(lambda: defaultdict(float))
- for bill in self.get_bills().all():
+ for bill in self.get_bills_unordered().all():
monthly[bill.date.year][bill.date.month] += bill.converted_amount
return monthly
@@ -215,15 +216,25 @@ class Project(db.Model):
def has_bills(self):
"""return if the project do have bills or not"""
- return self.get_bills().count() > 0
+ return self.get_bills_unordered().count() > 0
- def get_bills(self):
- """Return the list of bills related to this project"""
+ def has_multiple_currencies(self):
+ """Return if multiple currencies are used"""
+ return self.get_bills_unordered().group_by(Bill.original_currency).count() > 1
+
+ def get_bills_unordered(self):
+ """Base query for bill list"""
return (
Bill.query.join(Person, Project)
.filter(Bill.payer_id == Person.id)
.filter(Person.project_id == Project.id)
.filter(Project.id == self.id)
+ )
+
+ def get_bills(self):
+ """Return the list of bills related to this project"""
+ return (
+ self.get_bills_unordered()
.order_by(Bill.date.desc())
.order_by(Bill.creation_date.desc())
.order_by(Bill.id.desc())
@@ -232,11 +243,8 @@ class Project(db.Model):
def get_member_bills(self, member_id):
"""Return the list of bills related to a specific member"""
return (
- Bill.query.join(Person, Project)
- .filter(Bill.payer_id == Person.id)
- .filter(Person.project_id == Project.id)
+ self.get_bills_unordered()
.filter(Person.id == member_id)
- .filter(Project.id == self.id)
.order_by(Bill.date.desc())
.order_by(Bill.id.desc())
)
@@ -263,6 +271,41 @@ class Project(db.Model):
)
return pretty_bills
+ def switch_currency(self, new_currency):
+ if new_currency == self.default_currency:
+ return
+ # Update converted currency
+ if new_currency == CurrencyConverter.no_currency:
+ if self.has_multiple_currencies():
+ raise ValueError(f"Can't unset currency of project {self.id}")
+
+ for bill in self.get_bills_unordered():
+ # We are removing the currency, and we already checked that all bills
+ # had the same currency: it means that we can simply strip the currency
+ # without converting the amounts. We basically ignore the current default_currency
+
+ # Reset converted amount in case it was different from the original amount
+ bill.converted_amount = bill.amount
+ # Strip currency
+ bill.original_currency = CurrencyConverter.no_currency
+ db.session.add(bill)
+ else:
+ for bill in self.get_bills_unordered():
+ if bill.original_currency == CurrencyConverter.no_currency:
+ # Bills that were created without currency will be set to the new currency
+ bill.original_currency = new_currency
+ bill.converted_amount = bill.amount
+ else:
+ # Convert amount for others, without touching original_currency
+ bill.converted_amount = CurrencyConverter().exchange_currency(
+ bill.amount, bill.original_currency, new_currency
+ )
+ db.session.add(bill)
+
+ self.default_currency = new_currency
+ db.session.add(self)
+ db.session.commit()
+
def remove_member(self, member_id):
"""Remove a member from the project.
diff --git a/ihatemoney/templates/dashboard.html b/ihatemoney/templates/dashboard.html
index d9c150c4..3e26441a 100644
--- a/ihatemoney/templates/dashboard.html
+++ b/ihatemoney/templates/dashboard.html
@@ -5,7 +5,7 @@
{{ _("Project") }} {{ _("Number of members") }} {{ _("Number of bills") }} {{_("Newest bill")}} {{_("Oldest bill")}} {{_("Actions")}}
', resp.data.decode("utf-8")) + self.assertEqual(models.Project.query.get("raclette").default_currency, "USD") + + def test_currency_switch_to_bill_currency(self): + + mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1} + converter = CurrencyConverter() + converter.get_rates = MagicMock(return_value=mock_data) + + # Default currency is 'XXX', but we should start from a project with a currency + self.post_project("raclette", default_currency="USD") + + # add members + self.client.post("/raclette/members/add", data={"name": "zorglub"}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + + # Bill with a different currency than project's default + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "fromage à raclette", + "payer": 1, + "payed_for": [1, 2], + "amount": "10.0", + "original_currency": "EUR", + }, + ) + + project = models.Project.query.get("raclette") + + bill = project.get_bills().first() + assert bill.converted_amount == converter.exchange_currency( + bill.amount, "EUR", "USD" + ) + + # And switch project to the currency from the bill we created + project.switch_currency("EUR") + bill = project.get_bills().first() + assert bill.converted_amount == bill.amount + + def test_currency_switch_to_no_currency(self): + + mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1} + converter = CurrencyConverter() + converter.get_rates = MagicMock(return_value=mock_data) + + # Default currency is 'XXX', but we should start from a project with a currency + self.post_project("raclette", default_currency="USD") + + # add members + self.client.post("/raclette/members/add", data={"name": "zorglub"}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + + # Bills with a different currency than project's default + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "fromage à raclette", + "payer": 1, + "payed_for": [1, 2], + "amount": "10.0", + "original_currency": "EUR", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2017-01-01", + "what": "aspirine", + "payer": 2, + "payed_for": [1, 2], + "amount": "5.0", + "original_currency": "EUR", + }, + ) + + project = models.Project.query.get("raclette") + + for bill in project.get_bills_unordered(): + assert bill.converted_amount == converter.exchange_currency( + bill.amount, "EUR", "USD" + ) + + # And switch project to no currency: amount should be equal to what was submitted + project.switch_currency(converter.no_currency) + no_currency_bills = [ + (bill.amount, bill.converted_amount) for bill in project.get_bills() + ] + assert no_currency_bills == [(5.0, 5.0), (10.0, 10.0)] + if __name__ == "__main__": unittest.main() diff --git a/ihatemoney/tests/common/ihatemoney_testcase.py b/ihatemoney/tests/common/ihatemoney_testcase.py index 2e590590..39b2919e 100644 --- a/ihatemoney/tests/common/ihatemoney_testcase.py +++ b/ihatemoney/tests/common/ihatemoney_testcase.py @@ -30,7 +30,7 @@ class BaseTestCase(TestCase): follow_redirects=True, ) - def post_project(self, name, follow_redirects=True): + def post_project(self, name, follow_redirects=True, default_currency="XXX"): """Create a fake project""" # create the project return self.client.post( @@ -40,18 +40,18 @@ class BaseTestCase(TestCase): "id": name, "password": name, "contact_email": f"{name}@notmyidea.org", - "default_currency": "USD", + "default_currency": default_currency, }, follow_redirects=follow_redirects, ) - def create_project(self, name): + def create_project(self, name, default_currency="XXX"): project = models.Project( id=name, name=str(name), password=generate_password_hash(name), contact_email=f"{name}@notmyidea.org", - default_currency="USD", + default_currency=default_currency, ) models.db.session.add(project) models.db.session.commit() diff --git a/ihatemoney/tests/history_test.py b/ihatemoney/tests/history_test.py index 3ffdbcc6..cee3dca6 100644 --- a/ihatemoney/tests/history_test.py +++ b/ihatemoney/tests/history_test.py @@ -25,7 +25,7 @@ class HistoryTestCase(IhatemoneyTestCase): "name": "demo", "contact_email": "demo@notmyidea.org", "password": "demo", - "default_currency": "USD", + "default_currency": "XXX", } if logging_preference != LoggingMode.DISABLED: @@ -78,7 +78,7 @@ class HistoryTestCase(IhatemoneyTestCase): "contact_email": "demo2@notmyidea.org", "password": "123456", "project_history": "y", - "default_currency": "USD", + "default_currency": "USD", # Currency changed from default } resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True) @@ -103,7 +103,7 @@ class HistoryTestCase(IhatemoneyTestCase): resp.data.decode("utf-8").index("Project renamed "), resp.data.decode("utf-8").index("Project private code changed"), ) - self.assertEqual(resp.data.decode("utf-8").count("