Rework currency switching (#661)

Co-authored-by: Alexis Métaireau <alexis@notmyidea.org>

Currency switching is both simpler and less powerful. This was done primarily for users, to have a clear and logical understanding, but the code is also simpler. The main change is that it is now forbidden to switch a project to "no currency" if bills don't share the same currency.

Also, tests assume that projects are created without currency, as in the web UI.
This commit is contained in:
Glandos 2021-07-06 21:51:32 +02:00 committed by GitHub
parent fec5a82b0c
commit 07b86bc580
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 322 additions and 40 deletions

View file

@ -69,7 +69,7 @@ class ProjectHandler(Resource):
return "DELETED"
def put(self, project):
form = EditProjectForm(meta={"csrf": False})
form = EditProjectForm(id=project.id, meta={"csrf": False})
if form.validate() and current_app.config.get("ALLOW_PUBLIC_PROJECT_CREATION"):
form.update(project)
db.session.commit()

View file

@ -1,5 +1,6 @@
from datetime import datetime
from re import match
from types import SimpleNamespace
import email_validator
from flask import request
@ -110,6 +111,14 @@ class EditProjectForm(FlaskForm):
default_currency = SelectField(_("Default Currency"), validators=[DataRequired()])
def __init__(self, *args, **kwargs):
if not hasattr(self, "id"):
# We must access the project to validate the default currency, using its id.
# In ProjectForm, 'id' is provided, but not in this base class, so it *must*
# be provided by callers.
# Since id can be defined as a WTForms.StringField, we mimics it,
# using an object that can have a 'data' attribute.
# It defaults to empty string to ensure that query run smoothly.
self.id = SimpleNamespace(data=kwargs.pop("id", ""))
super().__init__(*args, **kwargs)
self.default_currency.choices = [
(currency_name, render_localized_currency(currency_name, detailed=True))
@ -142,6 +151,22 @@ class EditProjectForm(FlaskForm):
)
return project
def validate_default_currency(form, field):
project = Project.query.get(form.id.data)
if (
project is not None
and field.data == CurrencyConverter.no_currency
and project.has_multiple_currencies()
):
raise ValidationError(
_(
(
"This project cannot be set to 'no currency'"
" because it contains bills in multiple currencies."
)
)
)
def update(self, project):
"""Update the project with the information from the form"""
project.name = self.name.data
@ -152,7 +177,7 @@ class EditProjectForm(FlaskForm):
project.contact_email = self.contact_email.data
project.logging_preference = self.logging_preference
project.default_currency = self.default_currency.data
project.switch_currency(self.default_currency.data)
return project

View file

@ -17,6 +17,7 @@ from sqlalchemy_continuum import make_versioned, version_class
from sqlalchemy_continuum.plugins import FlaskPlugin
from werkzeug.security import generate_password_hash
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder
from ihatemoney.versioning import (
ConditionalVersioningManager,
@ -139,7 +140,7 @@ class Project(db.Model):
"spent": sum(
[
bill.pay_each() * member.weight
for bill in self.get_bills().all()
for bill in self.get_bills_unordered().all()
if member in bill.owers
]
),
@ -156,7 +157,7 @@ class Project(db.Model):
:rtype dict:
"""
monthly = defaultdict(lambda: defaultdict(float))
for bill in self.get_bills().all():
for bill in self.get_bills_unordered().all():
monthly[bill.date.year][bill.date.month] += bill.converted_amount
return monthly
@ -215,15 +216,25 @@ class Project(db.Model):
def has_bills(self):
"""return if the project do have bills or not"""
return self.get_bills().count() > 0
return self.get_bills_unordered().count() > 0
def get_bills(self):
"""Return the list of bills related to this project"""
def has_multiple_currencies(self):
"""Return if multiple currencies are used"""
return self.get_bills_unordered().group_by(Bill.original_currency).count() > 1
def get_bills_unordered(self):
"""Base query for bill list"""
return (
Bill.query.join(Person, Project)
.filter(Bill.payer_id == Person.id)
.filter(Person.project_id == Project.id)
.filter(Project.id == self.id)
)
def get_bills(self):
"""Return the list of bills related to this project"""
return (
self.get_bills_unordered()
.order_by(Bill.date.desc())
.order_by(Bill.creation_date.desc())
.order_by(Bill.id.desc())
@ -232,11 +243,8 @@ class Project(db.Model):
def get_member_bills(self, member_id):
"""Return the list of bills related to a specific member"""
return (
Bill.query.join(Person, Project)
.filter(Bill.payer_id == Person.id)
.filter(Person.project_id == Project.id)
self.get_bills_unordered()
.filter(Person.id == member_id)
.filter(Project.id == self.id)
.order_by(Bill.date.desc())
.order_by(Bill.id.desc())
)
@ -263,6 +271,41 @@ class Project(db.Model):
)
return pretty_bills
def switch_currency(self, new_currency):
if new_currency == self.default_currency:
return
# Update converted currency
if new_currency == CurrencyConverter.no_currency:
if self.has_multiple_currencies():
raise ValueError(f"Can't unset currency of project {self.id}")
for bill in self.get_bills_unordered():
# We are removing the currency, and we already checked that all bills
# had the same currency: it means that we can simply strip the currency
# without converting the amounts. We basically ignore the current default_currency
# Reset converted amount in case it was different from the original amount
bill.converted_amount = bill.amount
# Strip currency
bill.original_currency = CurrencyConverter.no_currency
db.session.add(bill)
else:
for bill in self.get_bills_unordered():
if bill.original_currency == CurrencyConverter.no_currency:
# Bills that were created without currency will be set to the new currency
bill.original_currency = new_currency
bill.converted_amount = bill.amount
else:
# Convert amount for others, without touching original_currency
bill.converted_amount = CurrencyConverter().exchange_currency(
bill.amount, bill.original_currency, new_currency
)
db.session.add(bill)
self.default_currency = new_currency
db.session.add(self)
db.session.commit()
def remove_member(self, member_id):
"""Remove a member from the project.

View file

@ -5,7 +5,7 @@
<thead><tr><th>{{ _("Project") }}</th><th>{{ _("Number of members") }}</th><th>{{ _("Number of bills") }}</th><th>{{_("Newest bill")}}</th><th>{{_("Oldest bill")}}</th><th>{{_("Actions")}}</th></tr></thead>
<tbody>{% for project in projects|sort(attribute='name') %}
<tr>
<td><a href="{{ url_for(".list_bills", project_id=project.id) }}" title="{{ project.name }}">{{ project.name }}</a></td><td>{{ project.members | count }}</td><td>{{ project.get_bills().count() }}</td>
<td><a href="{{ url_for(".list_bills", project_id=project.id) }}" title="{{ project.name }}">{{ project.name }}</a></td><td>{{ project.members | count }}</td><td>{{ project.get_bills_unordered().count() }}</td>
{% if project.has_bills() %}
<td>{{ project.get_bills().all()[0].date }}</td>
<td>{{ project.get_bills().all()[-1].date }}</td>

View file

@ -4,11 +4,14 @@ import json
import re
from time import sleep
import unittest
from unittest.mock import MagicMock
from flask import session
import pytest
from werkzeug.security import check_password_hash, generate_password_hash
from ihatemoney import models
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase
from ihatemoney.versioning import LoggingMode
@ -802,7 +805,8 @@ class BudgetTestCase(IhatemoneyTestCase):
self.assertEqual(response.status_code, 200)
def test_statistics(self):
self.post_project("raclette")
# Output is checked with the USD sign
self.post_project("raclette", default_currency="USD")
# add members
self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2})
@ -1443,6 +1447,225 @@ class BudgetTestCase(IhatemoneyTestCase):
member = models.Person.query.filter(models.Person.id == 1).one_or_none()
self.assertEqual(member, None)
def test_currency_switch(self):
mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1}
converter = CurrencyConverter()
converter.get_rates = MagicMock(return_value=mock_data)
# A project should be editable
self.post_project("raclette")
# add members
self.client.post("/raclette/members/add", data={"name": "zorglub"})
self.client.post("/raclette/members/add", data={"name": "fred"})
self.client.post("/raclette/members/add", data={"name": "tata"})
# create bills
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "fromage à raclette",
"payer": 1,
"payed_for": [1, 2, 3],
"amount": "10.0",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "red wine",
"payer": 2,
"payed_for": [1, 3],
"amount": "20",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2017-01-01",
"what": "refund",
"payer": 3,
"payed_for": [2],
"amount": "13.33",
},
)
project = models.Project.query.get("raclette")
# First all converted_amount should be the same as amount, with no currency
for bill in project.get_bills():
assert bill.original_currency == CurrencyConverter.no_currency
assert bill.amount == bill.converted_amount
# Then, switch to EUR, all bills must have been changed to this currency
project.switch_currency("EUR")
for bill in project.get_bills():
assert bill.original_currency == "EUR"
assert bill.amount == bill.converted_amount
# Add a bill in EUR, the current default currency
self.client.post(
"/raclette/add",
data={
"date": "2017-01-01",
"what": "refund from EUR",
"payer": 3,
"payed_for": [2],
"amount": "20",
"original_currency": "EUR",
},
)
last_bill = project.get_bills().first()
assert last_bill.converted_amount == last_bill.amount
# Erase all currencies
project.switch_currency(CurrencyConverter.no_currency)
for bill in project.get_bills():
assert bill.original_currency == CurrencyConverter.no_currency
assert bill.amount == bill.converted_amount
# Let's go back to EUR to test conversion
project.switch_currency("EUR")
# This is a bill in CAD
self.client.post(
"/raclette/add",
data={
"date": "2017-01-01",
"what": "Poutine",
"payer": 3,
"payed_for": [2],
"amount": "18",
"original_currency": "CAD",
},
)
last_bill = project.get_bills().first()
expected_amount = converter.exchange_currency(last_bill.amount, "CAD", "EUR")
assert last_bill.converted_amount == expected_amount
# Switch to USD. Now, NO bill should be in USD, since they already had a currency
project.switch_currency("USD")
for bill in project.get_bills():
assert bill.original_currency != "USD"
expected_amount = converter.exchange_currency(
bill.amount, bill.original_currency, "USD"
)
assert bill.converted_amount == expected_amount
# Switching back to no currency must fail
with pytest.raises(ValueError):
project.switch_currency(CurrencyConverter.no_currency)
# It also must fails with a nice error using the form
resp = self.client.post(
"/raclette/edit",
data={
"name": "demonstration",
"password": "demo",
"contact_email": "demo@notmyidea.org",
"project_history": "y",
"default_currency": converter.no_currency,
},
)
# A user displayed error should be generated, and its currency should be the same.
self.assertStatus(200, resp)
self.assertIn('<p class="alert alert-danger">', resp.data.decode("utf-8"))
self.assertEqual(models.Project.query.get("raclette").default_currency, "USD")
def test_currency_switch_to_bill_currency(self):
mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1}
converter = CurrencyConverter()
converter.get_rates = MagicMock(return_value=mock_data)
# Default currency is 'XXX', but we should start from a project with a currency
self.post_project("raclette", default_currency="USD")
# add members
self.client.post("/raclette/members/add", data={"name": "zorglub"})
self.client.post("/raclette/members/add", data={"name": "fred"})
# Bill with a different currency than project's default
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "fromage à raclette",
"payer": 1,
"payed_for": [1, 2],
"amount": "10.0",
"original_currency": "EUR",
},
)
project = models.Project.query.get("raclette")
bill = project.get_bills().first()
assert bill.converted_amount == converter.exchange_currency(
bill.amount, "EUR", "USD"
)
# And switch project to the currency from the bill we created
project.switch_currency("EUR")
bill = project.get_bills().first()
assert bill.converted_amount == bill.amount
def test_currency_switch_to_no_currency(self):
mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1}
converter = CurrencyConverter()
converter.get_rates = MagicMock(return_value=mock_data)
# Default currency is 'XXX', but we should start from a project with a currency
self.post_project("raclette", default_currency="USD")
# add members
self.client.post("/raclette/members/add", data={"name": "zorglub"})
self.client.post("/raclette/members/add", data={"name": "fred"})
# Bills with a different currency than project's default
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "fromage à raclette",
"payer": 1,
"payed_for": [1, 2],
"amount": "10.0",
"original_currency": "EUR",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2017-01-01",
"what": "aspirine",
"payer": 2,
"payed_for": [1, 2],
"amount": "5.0",
"original_currency": "EUR",
},
)
project = models.Project.query.get("raclette")
for bill in project.get_bills_unordered():
assert bill.converted_amount == converter.exchange_currency(
bill.amount, "EUR", "USD"
)
# And switch project to no currency: amount should be equal to what was submitted
project.switch_currency(converter.no_currency)
no_currency_bills = [
(bill.amount, bill.converted_amount) for bill in project.get_bills()
]
assert no_currency_bills == [(5.0, 5.0), (10.0, 10.0)]
if __name__ == "__main__":
unittest.main()

View file

@ -30,7 +30,7 @@ class BaseTestCase(TestCase):
follow_redirects=True,
)
def post_project(self, name, follow_redirects=True):
def post_project(self, name, follow_redirects=True, default_currency="XXX"):
"""Create a fake project"""
# create the project
return self.client.post(
@ -40,18 +40,18 @@ class BaseTestCase(TestCase):
"id": name,
"password": name,
"contact_email": f"{name}@notmyidea.org",
"default_currency": "USD",
"default_currency": default_currency,
},
follow_redirects=follow_redirects,
)
def create_project(self, name):
def create_project(self, name, default_currency="XXX"):
project = models.Project(
id=name,
name=str(name),
password=generate_password_hash(name),
contact_email=f"{name}@notmyidea.org",
default_currency="USD",
default_currency=default_currency,
)
models.db.session.add(project)
models.db.session.commit()

View file

@ -25,7 +25,7 @@ class HistoryTestCase(IhatemoneyTestCase):
"name": "demo",
"contact_email": "demo@notmyidea.org",
"password": "demo",
"default_currency": "USD",
"default_currency": "XXX",
}
if logging_preference != LoggingMode.DISABLED:
@ -78,7 +78,7 @@ class HistoryTestCase(IhatemoneyTestCase):
"contact_email": "demo2@notmyidea.org",
"password": "123456",
"project_history": "y",
"default_currency": "USD",
"default_currency": "USD", # Currency changed from default
}
resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True)
@ -103,7 +103,7 @@ class HistoryTestCase(IhatemoneyTestCase):
resp.data.decode("utf-8").index("Project renamed "),
resp.data.decode("utf-8").index("Project private code changed"),
)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 4)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 5)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
def test_project_privacy_edit(self):
@ -284,7 +284,7 @@ class HistoryTestCase(IhatemoneyTestCase):
self.assertIn(
"Some entries below contain IP addresses,", resp.data.decode("utf-8")
)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 10)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1)
# Generate more operations to confirm additional IP info isn't recorded
@ -292,8 +292,8 @@ class HistoryTestCase(IhatemoneyTestCase):
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 10)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 6)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 7)
# Clear IP Data
resp = self.client.post("/demo/strip_ip_addresses", follow_redirects=True)
@ -311,7 +311,7 @@ class HistoryTestCase(IhatemoneyTestCase):
"Some entries below contain IP addresses,", resp.data.decode("utf-8")
)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 0)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 16)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 19)
def test_logs_for_common_actions(self):
# adds a member to this project

View file

@ -240,7 +240,7 @@ class EmailFailureTestCase(IhatemoneyTestCase):
class TestCurrencyConverter(unittest.TestCase):
converter = CurrencyConverter()
mock_data = {"USD": 1, "EUR": 0.8115}
mock_data = {"USD": 1, "EUR": 0.8, "CAD": 1.2, CurrencyConverter.no_currency: 1}
converter.get_rates = MagicMock(return_value=mock_data)
def test_only_one_instance(self):
@ -249,11 +249,14 @@ class TestCurrencyConverter(unittest.TestCase):
self.assertEqual(one, two)
def test_get_currencies(self):
self.assertCountEqual(self.converter.get_currencies(), ["USD", "EUR"])
self.assertCountEqual(
self.converter.get_currencies(),
["USD", "EUR", "CAD", CurrencyConverter.no_currency],
)
def test_exchange_currency(self):
result = self.converter.exchange_currency(100, "USD", "EUR")
self.assertEqual(result, 81.15)
self.assertEqual(result, 80.0)
if __name__ == "__main__":

View file

@ -36,7 +36,6 @@ from sqlalchemy_continuum import Operation
from werkzeug.exceptions import NotFound
from werkzeug.security import check_password_hash, generate_password_hash
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.forms import (
AdminAuthenticationForm,
AuthenticationForm,
@ -400,7 +399,7 @@ def reset_password():
@main.route("/<project_id>/edit", methods=["GET", "POST"])
def edit_project():
edit_form = EditProjectForm()
edit_form = EditProjectForm(id=g.project.id)
import_form = UploadForm()
# Import form
if import_form.validate_on_submit():
@ -415,17 +414,6 @@ def edit_project():
# Edit form
if edit_form.validate_on_submit():
project = edit_form.update(g.project)
# Update converted currency
if project.default_currency != CurrencyConverter.no_currency:
for bill in project.get_bills():
if bill.original_currency == CurrencyConverter.no_currency:
bill.original_currency = project.default_currency
bill.converted_amount = CurrencyConverter().exchange_currency(
bill.amount, bill.original_currency, project.default_currency
)
db.session.add(bill)
db.session.add(project)
db.session.commit()