diff --git a/ihatemoney/forms.py b/ihatemoney/forms.py index e8d437bf..8b7f6efd 100644 --- a/ihatemoney/forms.py +++ b/ihatemoney/forms.py @@ -8,6 +8,7 @@ from flask import request from werkzeug.security import generate_password_hash from datetime import datetime +from re import match from jinja2 import Markup import email_validator @@ -44,6 +45,28 @@ class CommaDecimalField(DecimalField): return super(CommaDecimalField, self).process_formdata(value) +class CalculatorStringField(StringField): + """ + A class to deal with math ops (+, -, *, /) + in StringField + """ + + def process_formdata(self, valuelist): + if valuelist: + error_msg = "Not a valid amount or expression" + value = str(valuelist[0]).replace(" ", "").replace(",", ".") + + if not match(r'^[ 0-9\.\+\-\*/\(\)]{0,50}$', value) or "**" in value: + raise ValueError(error_msg) + + try: + valuelist[0] = str(eval(value, {"__builtins__": None}, {})) + except (SyntaxError, NameError, TypeError, ZeroDivisionError): + raise ValueError(error_msg) + + return super(CalculatorStringField, self).process_formdata(valuelist) + + class EditProjectForm(FlaskForm): name = StringField(_("Project name"), validators=[Required()]) password = StringField(_("Private code"), validators=[Required()]) @@ -117,7 +140,7 @@ class BillForm(FlaskForm): date = DateField(_("Date"), validators=[Required()], default=datetime.now) what = StringField(_("What?"), validators=[Required()]) payer = SelectField(_("Payer"), validators=[Required()], coerce=int) - amount = CommaDecimalField(_("Amount paid"), validators=[Required()]) + amount = CalculatorStringField(_("Amount paid"), validators=[Required()]) payed_for = SelectMultipleField(_("For whom?"), validators=[Required()], coerce=int) submit = SubmitField(_("Submit")) diff --git a/ihatemoney/tests/tests.py b/ihatemoney/tests/tests.py index 2f3d4ac2..d7af9d63 100644 --- a/ihatemoney/tests/tests.py +++ b/ihatemoney/tests/tests.py @@ -1349,6 +1349,69 @@ class APITestCase(IhatemoneyTestCase): headers=self.get_auth("raclette")) self.assertStatus(404, req) + def test_bills_with_calculation(self): + # create a project + self.api_create("raclette") + + # add members + self.api_add_member("raclette", "alexis") + self.api_add_member("raclette", "fred") + + # add a bill + req = self.client.post("/api/projects/raclette/bills", data={ + 'date': '2011-08-10', + 'what': 'fromage', + 'payer': "1", + 'payed_for': ["1", "2"], + 'amount': '((100 + 200.25) * 2 - 100) / 2', + }, headers=self.get_auth("raclette")) + + # should return the id + self.assertStatus(201, req) + self.assertEqual(req.data.decode('utf-8'), "1\n") + + # get this bill details + req = self.client.get("/api/projects/raclette/bills/1", + headers=self.get_auth("raclette")) + + # compare with the added info + self.assertStatus(200, req) + expected = { + "what": "fromage", + "payer_id": 1, + "owers": [ + {"activated": True, "id": 1, "name": "alexis", "weight": 1}, + {"activated": True, "id": 2, "name": "fred", "weight": 1}], + "amount": 250.25, + "date": "2011-08-10", + "id": 1} + + got = json.loads(req.data.decode('utf-8')) + self.assertEqual( + datetime.date.today(), + datetime.datetime.strptime(got["creation_date"], '%Y-%m-%d').date() + ) + del got["creation_date"] + self.assertDictEqual(expected, got) + + erroneous_amounts = [ + "lambda ", # letters + "(20 + 2", # invalid expression + "20/0", # invalid calc + "9999**99999999999999999", # exponents + "2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2" # greater than 50 chars + ] + + for amount in erroneous_amounts: + req = self.client.post("/api/projects/raclette/bills", data={ + 'date': '2011-08-10', + 'what': 'fromage', + 'payer': "1", + 'payed_for': ["1", "2"], + 'amount': amount, + }, headers=self.get_auth("raclette")) + self.assertStatus(400, req) + def test_statistics(self): # create a project self.api_create("raclette")