Rework currency switching (#661)

Co-authored-by: Alexis Métaireau <alexis@notmyidea.org>

Currency switching is both simpler and less powerful. This was done primarily for users, to have a clear and logical understanding, but the code is also simpler. The main change is that it is now forbidden to switch a project to "no currency" if bills don't share the same currency.

Also, tests assume that projects are created without currency, as in the web UI.
This commit is contained in:
Glandos 2021-07-06 21:51:32 +02:00 committed by GitHub
parent fec5a82b0c
commit 07b86bc580
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 322 additions and 40 deletions

View file

@ -69,7 +69,7 @@ class ProjectHandler(Resource):
return "DELETED" return "DELETED"
def put(self, project): def put(self, project):
form = EditProjectForm(meta={"csrf": False}) form = EditProjectForm(id=project.id, meta={"csrf": False})
if form.validate() and current_app.config.get("ALLOW_PUBLIC_PROJECT_CREATION"): if form.validate() and current_app.config.get("ALLOW_PUBLIC_PROJECT_CREATION"):
form.update(project) form.update(project)
db.session.commit() db.session.commit()

View file

@ -1,5 +1,6 @@
from datetime import datetime from datetime import datetime
from re import match from re import match
from types import SimpleNamespace
import email_validator import email_validator
from flask import request from flask import request
@ -110,6 +111,14 @@ class EditProjectForm(FlaskForm):
default_currency = SelectField(_("Default Currency"), validators=[DataRequired()]) default_currency = SelectField(_("Default Currency"), validators=[DataRequired()])
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if not hasattr(self, "id"):
# We must access the project to validate the default currency, using its id.
# In ProjectForm, 'id' is provided, but not in this base class, so it *must*
# be provided by callers.
# Since id can be defined as a WTForms.StringField, we mimics it,
# using an object that can have a 'data' attribute.
# It defaults to empty string to ensure that query run smoothly.
self.id = SimpleNamespace(data=kwargs.pop("id", ""))
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.default_currency.choices = [ self.default_currency.choices = [
(currency_name, render_localized_currency(currency_name, detailed=True)) (currency_name, render_localized_currency(currency_name, detailed=True))
@ -142,6 +151,22 @@ class EditProjectForm(FlaskForm):
) )
return project return project
def validate_default_currency(form, field):
project = Project.query.get(form.id.data)
if (
project is not None
and field.data == CurrencyConverter.no_currency
and project.has_multiple_currencies()
):
raise ValidationError(
_(
(
"This project cannot be set to 'no currency'"
" because it contains bills in multiple currencies."
)
)
)
def update(self, project): def update(self, project):
"""Update the project with the information from the form""" """Update the project with the information from the form"""
project.name = self.name.data project.name = self.name.data
@ -152,7 +177,7 @@ class EditProjectForm(FlaskForm):
project.contact_email = self.contact_email.data project.contact_email = self.contact_email.data
project.logging_preference = self.logging_preference project.logging_preference = self.logging_preference
project.default_currency = self.default_currency.data project.switch_currency(self.default_currency.data)
return project return project

View file

@ -17,6 +17,7 @@ from sqlalchemy_continuum import make_versioned, version_class
from sqlalchemy_continuum.plugins import FlaskPlugin from sqlalchemy_continuum.plugins import FlaskPlugin
from werkzeug.security import generate_password_hash from werkzeug.security import generate_password_hash
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder
from ihatemoney.versioning import ( from ihatemoney.versioning import (
ConditionalVersioningManager, ConditionalVersioningManager,
@ -139,7 +140,7 @@ class Project(db.Model):
"spent": sum( "spent": sum(
[ [
bill.pay_each() * member.weight bill.pay_each() * member.weight
for bill in self.get_bills().all() for bill in self.get_bills_unordered().all()
if member in bill.owers if member in bill.owers
] ]
), ),
@ -156,7 +157,7 @@ class Project(db.Model):
:rtype dict: :rtype dict:
""" """
monthly = defaultdict(lambda: defaultdict(float)) monthly = defaultdict(lambda: defaultdict(float))
for bill in self.get_bills().all(): for bill in self.get_bills_unordered().all():
monthly[bill.date.year][bill.date.month] += bill.converted_amount monthly[bill.date.year][bill.date.month] += bill.converted_amount
return monthly return monthly
@ -215,15 +216,25 @@ class Project(db.Model):
def has_bills(self): def has_bills(self):
"""return if the project do have bills or not""" """return if the project do have bills or not"""
return self.get_bills().count() > 0 return self.get_bills_unordered().count() > 0
def get_bills(self): def has_multiple_currencies(self):
"""Return the list of bills related to this project""" """Return if multiple currencies are used"""
return self.get_bills_unordered().group_by(Bill.original_currency).count() > 1
def get_bills_unordered(self):
"""Base query for bill list"""
return ( return (
Bill.query.join(Person, Project) Bill.query.join(Person, Project)
.filter(Bill.payer_id == Person.id) .filter(Bill.payer_id == Person.id)
.filter(Person.project_id == Project.id) .filter(Person.project_id == Project.id)
.filter(Project.id == self.id) .filter(Project.id == self.id)
)
def get_bills(self):
"""Return the list of bills related to this project"""
return (
self.get_bills_unordered()
.order_by(Bill.date.desc()) .order_by(Bill.date.desc())
.order_by(Bill.creation_date.desc()) .order_by(Bill.creation_date.desc())
.order_by(Bill.id.desc()) .order_by(Bill.id.desc())
@ -232,11 +243,8 @@ class Project(db.Model):
def get_member_bills(self, member_id): def get_member_bills(self, member_id):
"""Return the list of bills related to a specific member""" """Return the list of bills related to a specific member"""
return ( return (
Bill.query.join(Person, Project) self.get_bills_unordered()
.filter(Bill.payer_id == Person.id)
.filter(Person.project_id == Project.id)
.filter(Person.id == member_id) .filter(Person.id == member_id)
.filter(Project.id == self.id)
.order_by(Bill.date.desc()) .order_by(Bill.date.desc())
.order_by(Bill.id.desc()) .order_by(Bill.id.desc())
) )
@ -263,6 +271,41 @@ class Project(db.Model):
) )
return pretty_bills return pretty_bills
def switch_currency(self, new_currency):
if new_currency == self.default_currency:
return
# Update converted currency
if new_currency == CurrencyConverter.no_currency:
if self.has_multiple_currencies():
raise ValueError(f"Can't unset currency of project {self.id}")
for bill in self.get_bills_unordered():
# We are removing the currency, and we already checked that all bills
# had the same currency: it means that we can simply strip the currency
# without converting the amounts. We basically ignore the current default_currency
# Reset converted amount in case it was different from the original amount
bill.converted_amount = bill.amount
# Strip currency
bill.original_currency = CurrencyConverter.no_currency
db.session.add(bill)
else:
for bill in self.get_bills_unordered():
if bill.original_currency == CurrencyConverter.no_currency:
# Bills that were created without currency will be set to the new currency
bill.original_currency = new_currency
bill.converted_amount = bill.amount
else:
# Convert amount for others, without touching original_currency
bill.converted_amount = CurrencyConverter().exchange_currency(
bill.amount, bill.original_currency, new_currency
)
db.session.add(bill)
self.default_currency = new_currency
db.session.add(self)
db.session.commit()
def remove_member(self, member_id): def remove_member(self, member_id):
"""Remove a member from the project. """Remove a member from the project.

View file

@ -5,7 +5,7 @@
<thead><tr><th>{{ _("Project") }}</th><th>{{ _("Number of members") }}</th><th>{{ _("Number of bills") }}</th><th>{{_("Newest bill")}}</th><th>{{_("Oldest bill")}}</th><th>{{_("Actions")}}</th></tr></thead> <thead><tr><th>{{ _("Project") }}</th><th>{{ _("Number of members") }}</th><th>{{ _("Number of bills") }}</th><th>{{_("Newest bill")}}</th><th>{{_("Oldest bill")}}</th><th>{{_("Actions")}}</th></tr></thead>
<tbody>{% for project in projects|sort(attribute='name') %} <tbody>{% for project in projects|sort(attribute='name') %}
<tr> <tr>
<td><a href="{{ url_for(".list_bills", project_id=project.id) }}" title="{{ project.name }}">{{ project.name }}</a></td><td>{{ project.members | count }}</td><td>{{ project.get_bills().count() }}</td> <td><a href="{{ url_for(".list_bills", project_id=project.id) }}" title="{{ project.name }}">{{ project.name }}</a></td><td>{{ project.members | count }}</td><td>{{ project.get_bills_unordered().count() }}</td>
{% if project.has_bills() %} {% if project.has_bills() %}
<td>{{ project.get_bills().all()[0].date }}</td> <td>{{ project.get_bills().all()[0].date }}</td>
<td>{{ project.get_bills().all()[-1].date }}</td> <td>{{ project.get_bills().all()[-1].date }}</td>

View file

@ -4,11 +4,14 @@ import json
import re import re
from time import sleep from time import sleep
import unittest import unittest
from unittest.mock import MagicMock
from flask import session from flask import session
import pytest
from werkzeug.security import check_password_hash, generate_password_hash from werkzeug.security import check_password_hash, generate_password_hash
from ihatemoney import models from ihatemoney import models
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase
from ihatemoney.versioning import LoggingMode from ihatemoney.versioning import LoggingMode
@ -802,7 +805,8 @@ class BudgetTestCase(IhatemoneyTestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_statistics(self): def test_statistics(self):
self.post_project("raclette") # Output is checked with the USD sign
self.post_project("raclette", default_currency="USD")
# add members # add members
self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2})
@ -1443,6 +1447,225 @@ class BudgetTestCase(IhatemoneyTestCase):
member = models.Person.query.filter(models.Person.id == 1).one_or_none() member = models.Person.query.filter(models.Person.id == 1).one_or_none()
self.assertEqual(member, None) self.assertEqual(member, None)
def test_currency_switch(self):
mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1}
converter = CurrencyConverter()
converter.get_rates = MagicMock(return_value=mock_data)
# A project should be editable
self.post_project("raclette")
# add members
self.client.post("/raclette/members/add", data={"name": "zorglub"})
self.client.post("/raclette/members/add", data={"name": "fred"})
self.client.post("/raclette/members/add", data={"name": "tata"})
# create bills
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "fromage à raclette",
"payer": 1,
"payed_for": [1, 2, 3],
"amount": "10.0",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "red wine",
"payer": 2,
"payed_for": [1, 3],
"amount": "20",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2017-01-01",
"what": "refund",
"payer": 3,
"payed_for": [2],
"amount": "13.33",
},
)
project = models.Project.query.get("raclette")
# First all converted_amount should be the same as amount, with no currency
for bill in project.get_bills():
assert bill.original_currency == CurrencyConverter.no_currency
assert bill.amount == bill.converted_amount
# Then, switch to EUR, all bills must have been changed to this currency
project.switch_currency("EUR")
for bill in project.get_bills():
assert bill.original_currency == "EUR"
assert bill.amount == bill.converted_amount
# Add a bill in EUR, the current default currency
self.client.post(
"/raclette/add",
data={
"date": "2017-01-01",
"what": "refund from EUR",
"payer": 3,
"payed_for": [2],
"amount": "20",
"original_currency": "EUR",
},
)
last_bill = project.get_bills().first()
assert last_bill.converted_amount == last_bill.amount
# Erase all currencies
project.switch_currency(CurrencyConverter.no_currency)
for bill in project.get_bills():
assert bill.original_currency == CurrencyConverter.no_currency
assert bill.amount == bill.converted_amount
# Let's go back to EUR to test conversion
project.switch_currency("EUR")
# This is a bill in CAD
self.client.post(
"/raclette/add",
data={
"date": "2017-01-01",
"what": "Poutine",
"payer": 3,
"payed_for": [2],
"amount": "18",
"original_currency": "CAD",
},
)
last_bill = project.get_bills().first()
expected_amount = converter.exchange_currency(last_bill.amount, "CAD", "EUR")
assert last_bill.converted_amount == expected_amount
# Switch to USD. Now, NO bill should be in USD, since they already had a currency
project.switch_currency("USD")
for bill in project.get_bills():
assert bill.original_currency != "USD"
expected_amount = converter.exchange_currency(
bill.amount, bill.original_currency, "USD"
)
assert bill.converted_amount == expected_amount
# Switching back to no currency must fail
with pytest.raises(ValueError):
project.switch_currency(CurrencyConverter.no_currency)
# It also must fails with a nice error using the form
resp = self.client.post(
"/raclette/edit",
data={
"name": "demonstration",
"password": "demo",
"contact_email": "demo@notmyidea.org",
"project_history": "y",
"default_currency": converter.no_currency,
},
)
# A user displayed error should be generated, and its currency should be the same.
self.assertStatus(200, resp)
self.assertIn('<p class="alert alert-danger">', resp.data.decode("utf-8"))
self.assertEqual(models.Project.query.get("raclette").default_currency, "USD")
def test_currency_switch_to_bill_currency(self):
mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1}
converter = CurrencyConverter()
converter.get_rates = MagicMock(return_value=mock_data)
# Default currency is 'XXX', but we should start from a project with a currency
self.post_project("raclette", default_currency="USD")
# add members
self.client.post("/raclette/members/add", data={"name": "zorglub"})
self.client.post("/raclette/members/add", data={"name": "fred"})
# Bill with a different currency than project's default
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "fromage à raclette",
"payer": 1,
"payed_for": [1, 2],
"amount": "10.0",
"original_currency": "EUR",
},
)
project = models.Project.query.get("raclette")
bill = project.get_bills().first()
assert bill.converted_amount == converter.exchange_currency(
bill.amount, "EUR", "USD"
)
# And switch project to the currency from the bill we created
project.switch_currency("EUR")
bill = project.get_bills().first()
assert bill.converted_amount == bill.amount
def test_currency_switch_to_no_currency(self):
mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1}
converter = CurrencyConverter()
converter.get_rates = MagicMock(return_value=mock_data)
# Default currency is 'XXX', but we should start from a project with a currency
self.post_project("raclette", default_currency="USD")
# add members
self.client.post("/raclette/members/add", data={"name": "zorglub"})
self.client.post("/raclette/members/add", data={"name": "fred"})
# Bills with a different currency than project's default
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "fromage à raclette",
"payer": 1,
"payed_for": [1, 2],
"amount": "10.0",
"original_currency": "EUR",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2017-01-01",
"what": "aspirine",
"payer": 2,
"payed_for": [1, 2],
"amount": "5.0",
"original_currency": "EUR",
},
)
project = models.Project.query.get("raclette")
for bill in project.get_bills_unordered():
assert bill.converted_amount == converter.exchange_currency(
bill.amount, "EUR", "USD"
)
# And switch project to no currency: amount should be equal to what was submitted
project.switch_currency(converter.no_currency)
no_currency_bills = [
(bill.amount, bill.converted_amount) for bill in project.get_bills()
]
assert no_currency_bills == [(5.0, 5.0), (10.0, 10.0)]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -30,7 +30,7 @@ class BaseTestCase(TestCase):
follow_redirects=True, follow_redirects=True,
) )
def post_project(self, name, follow_redirects=True): def post_project(self, name, follow_redirects=True, default_currency="XXX"):
"""Create a fake project""" """Create a fake project"""
# create the project # create the project
return self.client.post( return self.client.post(
@ -40,18 +40,18 @@ class BaseTestCase(TestCase):
"id": name, "id": name,
"password": name, "password": name,
"contact_email": f"{name}@notmyidea.org", "contact_email": f"{name}@notmyidea.org",
"default_currency": "USD", "default_currency": default_currency,
}, },
follow_redirects=follow_redirects, follow_redirects=follow_redirects,
) )
def create_project(self, name): def create_project(self, name, default_currency="XXX"):
project = models.Project( project = models.Project(
id=name, id=name,
name=str(name), name=str(name),
password=generate_password_hash(name), password=generate_password_hash(name),
contact_email=f"{name}@notmyidea.org", contact_email=f"{name}@notmyidea.org",
default_currency="USD", default_currency=default_currency,
) )
models.db.session.add(project) models.db.session.add(project)
models.db.session.commit() models.db.session.commit()

View file

@ -25,7 +25,7 @@ class HistoryTestCase(IhatemoneyTestCase):
"name": "demo", "name": "demo",
"contact_email": "demo@notmyidea.org", "contact_email": "demo@notmyidea.org",
"password": "demo", "password": "demo",
"default_currency": "USD", "default_currency": "XXX",
} }
if logging_preference != LoggingMode.DISABLED: if logging_preference != LoggingMode.DISABLED:
@ -78,7 +78,7 @@ class HistoryTestCase(IhatemoneyTestCase):
"contact_email": "demo2@notmyidea.org", "contact_email": "demo2@notmyidea.org",
"password": "123456", "password": "123456",
"project_history": "y", "project_history": "y",
"default_currency": "USD", "default_currency": "USD", # Currency changed from default
} }
resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True) resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True)
@ -103,7 +103,7 @@ class HistoryTestCase(IhatemoneyTestCase):
resp.data.decode("utf-8").index("Project renamed "), resp.data.decode("utf-8").index("Project renamed "),
resp.data.decode("utf-8").index("Project private code changed"), resp.data.decode("utf-8").index("Project private code changed"),
) )
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 4) self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 5)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8")) self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
def test_project_privacy_edit(self): def test_project_privacy_edit(self):
@ -284,7 +284,7 @@ class HistoryTestCase(IhatemoneyTestCase):
self.assertIn( self.assertIn(
"Some entries below contain IP addresses,", resp.data.decode("utf-8") "Some entries below contain IP addresses,", resp.data.decode("utf-8")
) )
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 10) self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1) self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1)
# Generate more operations to confirm additional IP info isn't recorded # Generate more operations to confirm additional IP info isn't recorded
@ -292,8 +292,8 @@ class HistoryTestCase(IhatemoneyTestCase):
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 10) self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 6) self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 7)
# Clear IP Data # Clear IP Data
resp = self.client.post("/demo/strip_ip_addresses", follow_redirects=True) resp = self.client.post("/demo/strip_ip_addresses", follow_redirects=True)
@ -311,7 +311,7 @@ class HistoryTestCase(IhatemoneyTestCase):
"Some entries below contain IP addresses,", resp.data.decode("utf-8") "Some entries below contain IP addresses,", resp.data.decode("utf-8")
) )
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 0) self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 0)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 16) self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 19)
def test_logs_for_common_actions(self): def test_logs_for_common_actions(self):
# adds a member to this project # adds a member to this project

View file

@ -240,7 +240,7 @@ class EmailFailureTestCase(IhatemoneyTestCase):
class TestCurrencyConverter(unittest.TestCase): class TestCurrencyConverter(unittest.TestCase):
converter = CurrencyConverter() converter = CurrencyConverter()
mock_data = {"USD": 1, "EUR": 0.8115} mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1}
converter.get_rates = MagicMock(return_value=mock_data) converter.get_rates = MagicMock(return_value=mock_data)
def test_only_one_instance(self): def test_only_one_instance(self):
@ -249,11 +249,14 @@ class TestCurrencyConverter(unittest.TestCase):
self.assertEqual(one, two) self.assertEqual(one, two)
def test_get_currencies(self): def test_get_currencies(self):
self.assertCountEqual(self.converter.get_currencies(), ["USD", "EUR"]) self.assertCountEqual(
self.converter.get_currencies(),
["USD", "EUR", "CAD", CurrencyConverter.no_currency],
)
def test_exchange_currency(self): def test_exchange_currency(self):
result = self.converter.exchange_currency(100, "USD", "EUR") result = self.converter.exchange_currency(100, "USD", "EUR")
self.assertEqual(result, 81.15) self.assertEqual(result, 80.0)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -36,7 +36,6 @@ from sqlalchemy_continuum import Operation
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from werkzeug.security import check_password_hash, generate_password_hash from werkzeug.security import check_password_hash, generate_password_hash
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.forms import ( from ihatemoney.forms import (
AdminAuthenticationForm, AdminAuthenticationForm,
AuthenticationForm, AuthenticationForm,
@ -400,7 +399,7 @@ def reset_password():
@main.route("/<project_id>/edit", methods=["GET", "POST"]) @main.route("/<project_id>/edit", methods=["GET", "POST"])
def edit_project(): def edit_project():
edit_form = EditProjectForm() edit_form = EditProjectForm(id=g.project.id)
import_form = UploadForm() import_form = UploadForm()
# Import form # Import form
if import_form.validate_on_submit(): if import_form.validate_on_submit():
@ -415,17 +414,6 @@ def edit_project():
# Edit form # Edit form
if edit_form.validate_on_submit(): if edit_form.validate_on_submit():
project = edit_form.update(g.project) project = edit_form.update(g.project)
# Update converted currency
if project.default_currency != CurrencyConverter.no_currency:
for bill in project.get_bills():
if bill.original_currency == CurrencyConverter.no_currency:
bill.original_currency = project.default_currency
bill.converted_amount = CurrencyConverter().exchange_currency(
bill.amount, bill.original_currency, project.default_currency
)
db.session.add(bill)
db.session.add(project) db.session.add(project)
db.session.commit() db.session.commit()