diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 318de686..a1d8b751 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,6 +10,7 @@ Added ===== - Add CORS headers in the API (#407) - Document database migrations (#390) +- Allow basic math operations in amount field (#413) 3.0 (2018-11-25) ---------------- diff --git a/docs/contributing.rst b/docs/contributing.rst index bcb3f162..0350f01c 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -116,7 +116,7 @@ Collect all new strings to translate:: Compile them into *.mo* files:: - $ make compile-translations + $ make build-translations Commit both *.mo* and *.po*. diff --git a/ihatemoney/forms.py b/ihatemoney/forms.py index e8d437bf..5374fd9f 100644 --- a/ihatemoney/forms.py +++ b/ihatemoney/forms.py @@ -8,12 +8,13 @@ from flask import request from werkzeug.security import generate_password_hash from datetime import datetime +from re import match from jinja2 import Markup import email_validator from ihatemoney.models import Project, Person -from ihatemoney.utils import slugify +from ihatemoney.utils import slugify, eval_arithmetic_expression def get_billform_for(project, set_default=True, **kwargs): @@ -44,6 +45,30 @@ class CommaDecimalField(DecimalField): return super(CommaDecimalField, self).process_formdata(value) +class CalculatorStringField(StringField): + """ + A class to deal with math ops (+, -, *, /) + in StringField + """ + + def process_formdata(self, valuelist): + if valuelist: + message = _( + "Not a valid amount or expression." + "Only numbers and + - * / operators" + "are accepted." + ) + value = str(valuelist[0]).replace(",", ".") + + # avoid exponents to prevent expensive calculations i.e 2**9999999999**9999999 + if not match(r'^[ 0-9\.\+\-\*/\(\)]{0,200}$', value) or "**" in value: + raise ValueError(Markup(message)) + + valuelist[0] = str(eval_arithmetic_expression(value)) + + return super(CalculatorStringField, self).process_formdata(valuelist) + + class EditProjectForm(FlaskForm): name = StringField(_("Project name"), validators=[Required()]) password = StringField(_("Private code"), validators=[Required()]) @@ -117,7 +142,7 @@ class BillForm(FlaskForm): date = DateField(_("Date"), validators=[Required()], default=datetime.now) what = StringField(_("What?"), validators=[Required()]) payer = SelectField(_("Payer"), validators=[Required()], coerce=int) - amount = CommaDecimalField(_("Amount paid"), validators=[Required()]) + amount = CalculatorStringField(_("Amount paid"), validators=[Required()]) payed_for = SelectMultipleField(_("For whom?"), validators=[Required()], coerce=int) submit = SubmitField(_("Submit")) diff --git a/ihatemoney/messages.pot b/ihatemoney/messages.pot index bd4dcc49..ba81d225 100644 --- a/ihatemoney/messages.pot +++ b/ihatemoney/messages.pot @@ -1,3 +1,8 @@ +msgid "" +"Not a valid amount or expression.Only numbers and + - * / operatorsare " +"accepted." +msgstr "" + msgid "Project name" msgstr "" @@ -364,6 +369,9 @@ msgstr "" msgid "Add a new bill" msgstr "" +msgid "Added on" +msgstr "" + msgid "When?" msgstr "" diff --git a/ihatemoney/tests/tests.py b/ihatemoney/tests/tests.py index 2f3d4ac2..9f9d8fab 100644 --- a/ihatemoney/tests/tests.py +++ b/ihatemoney/tests/tests.py @@ -1349,6 +1349,87 @@ class APITestCase(IhatemoneyTestCase): headers=self.get_auth("raclette")) self.assertStatus(404, req) + def test_bills_with_calculation(self): + # create a project + self.api_create("raclette") + + # add members + self.api_add_member("raclette", "alexis") + self.api_add_member("raclette", "fred") + + # valid amounts + input_expected = [ + ("((100 + 200.25) * 2 - 100) / 2", 250.25), + ("3/2", 1.5), + ("2 + 1 * 5 - 2 / 1", 5), + ] + + for i, pair in enumerate(input_expected): + input_amount, expected_amount = pair + id = i + 1 + + req = self.client.post( + "/api/projects/raclette/bills", + data={ + 'date': '2011-08-10', + 'what': 'fromage', + 'payer': "1", + 'payed_for': ["1", "2"], + 'amount': input_amount, + }, + headers=self.get_auth("raclette") + ) + + # should return the id + self.assertStatus(201, req) + self.assertEqual(req.data.decode('utf-8'), "{}\n".format(id)) + + # get this bill's details + req = self.client.get( + "/api/projects/raclette/bills/{}".format(id), + headers=self.get_auth("raclette") + ) + + # compare with the added info + self.assertStatus(200, req) + expected = { + "what": "fromage", + "payer_id": 1, + "owers": [ + {"activated": True, "id": 1, "name": "alexis", "weight": 1}, + {"activated": True, "id": 2, "name": "fred", "weight": 1}], + "amount": expected_amount, + "date": "2011-08-10", + "id": id, + } + + got = json.loads(req.data.decode('utf-8')) + self.assertEqual( + datetime.date.today(), + datetime.datetime.strptime(got["creation_date"], '%Y-%m-%d').date() + ) + del got["creation_date"] + self.assertDictEqual(expected, got) + + # should raise errors + erroneous_amounts = [ + "lambda ", # letters + "(20 + 2", # invalid expression + "20/0", # invalid calc + "9999**99999999999999999", # exponents + "2" * 201, # greater than 200 chars, + ] + + for amount in erroneous_amounts: + req = self.client.post("/api/projects/raclette/bills", data={ + 'date': '2011-08-10', + 'what': 'fromage', + 'payer': "1", + 'payed_for': ["1", "2"], + 'amount': amount, + }, headers=self.get_auth("raclette")) + self.assertStatus(400, req) + def test_statistics(self): # create a project self.api_create("raclette") diff --git a/ihatemoney/translations/fr/LC_MESSAGES/messages.mo b/ihatemoney/translations/fr/LC_MESSAGES/messages.mo index ab8a8316..44b0c6ae 100644 Binary files a/ihatemoney/translations/fr/LC_MESSAGES/messages.mo and b/ihatemoney/translations/fr/LC_MESSAGES/messages.mo differ diff --git a/ihatemoney/translations/fr/LC_MESSAGES/messages.po b/ihatemoney/translations/fr/LC_MESSAGES/messages.po index b3e2fdbe..2516d568 100644 --- a/ihatemoney/translations/fr/LC_MESSAGES/messages.po +++ b/ihatemoney/translations/fr/LC_MESSAGES/messages.po @@ -7,7 +7,7 @@ msgid "" msgstr "" "Project-Id-Version: PROJECT VERSION\n" "Report-Msgid-Bugs-To: EMAIL@ADDRESS\n" -"POT-Creation-Date: 2018-08-05 23:41+0200\n" +"POT-Creation-Date: 2019-01-02 23:26-0500\n" "PO-Revision-Date: 2018-05-15 22:00+0200\n" "Last-Translator: Adrien CLERC <>\n" "Language: fr\n" @@ -18,6 +18,13 @@ msgstr "" "Content-Transfer-Encoding: 8bit\n" "Generated-By: Babel 2.6.0\n" +msgid "" +"Not a valid amount or expression.Only numbers and + - * / operators are " +"accepted." +msgstr "" +"Pas un montant ou une expression valide. Seuls les nombres et les opérateurs" +"+ - * / sont acceptés" + msgid "Project name" msgstr "Nom de projet" @@ -394,6 +401,9 @@ msgstr "Invitez d’autres personnes à rejoindre ce projet !" msgid "Add a new bill" msgstr "Nouvelle facture" +msgid "Added on" +msgstr "Ajouté Sur" + msgid "When?" msgstr "Quand ?" @@ -494,3 +504,4 @@ msgstr "Solde" #~ msgid "Invite" #~ msgstr "Invitez" + diff --git a/ihatemoney/utils.py b/ihatemoney/utils.py index ec228343..2fac4efb 100644 --- a/ihatemoney/utils.py +++ b/ihatemoney/utils.py @@ -1,5 +1,8 @@ +from __future__ import division import base64 import re +import ast +import operator from io import BytesIO, StringIO import jinja2 @@ -206,3 +209,33 @@ class IhmJSONEncoder(JSONEncoder): except ImportError: pass return JSONEncoder.default(self, o) + + +def eval_arithmetic_expression(expr): + def _eval(node): + # supported operators + operators = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.USub: operator.neg, + } + + if isinstance(node, ast.Num): # + return node.n + elif isinstance(node, ast.BinOp): # + return operators[type(node.op)](_eval(node.left), _eval(node.right)) + elif isinstance(node, ast.UnaryOp): # e.g., -1 + return operators[type(node.op)](_eval(node.operand)) + else: + raise TypeError(node) + + expr = str(expr) + + try: + result = _eval(ast.parse(expr, mode='eval').body) + except (SyntaxError, TypeError, ZeroDivisionError, KeyError): + raise ValueError("Error evaluating expression: {}".format(expr)) + + return result