allow basic math ops in amount field for bills form

This commit is contained in:
Ullauri 2018-12-29 21:39:15 -05:00
parent 7cb339c0bb
commit 62fcc5a25b
2 changed files with 87 additions and 1 deletions

View file

@ -8,6 +8,7 @@ from flask import request
from werkzeug.security import generate_password_hash from werkzeug.security import generate_password_hash
from datetime import datetime from datetime import datetime
from re import match
from jinja2 import Markup from jinja2 import Markup
import email_validator import email_validator
@ -44,6 +45,28 @@ class CommaDecimalField(DecimalField):
return super(CommaDecimalField, self).process_formdata(value) return super(CommaDecimalField, self).process_formdata(value)
class CalculatorStringField(StringField):
"""
A class to deal with math ops (+, -, *, /)
in StringField
"""
def process_formdata(self, valuelist):
if valuelist:
error_msg = "Not a valid amount or expression"
value = str(valuelist[0]).replace(" ", "").replace(",", ".")
if not match(r'^[ 0-9\.\+\-\*/\(\)]{0,50}$', value) or "**" in value:
raise ValueError(error_msg)
try:
valuelist[0] = str(eval(value, {"__builtins__": None}, {}))
except (SyntaxError, NameError, TypeError, ZeroDivisionError):
raise ValueError(error_msg)
return super(CalculatorStringField, self).process_formdata(valuelist)
class EditProjectForm(FlaskForm): class EditProjectForm(FlaskForm):
name = StringField(_("Project name"), validators=[Required()]) name = StringField(_("Project name"), validators=[Required()])
password = StringField(_("Private code"), validators=[Required()]) password = StringField(_("Private code"), validators=[Required()])
@ -117,7 +140,7 @@ class BillForm(FlaskForm):
date = DateField(_("Date"), validators=[Required()], default=datetime.now) date = DateField(_("Date"), validators=[Required()], default=datetime.now)
what = StringField(_("What?"), validators=[Required()]) what = StringField(_("What?"), validators=[Required()])
payer = SelectField(_("Payer"), validators=[Required()], coerce=int) payer = SelectField(_("Payer"), validators=[Required()], coerce=int)
amount = CommaDecimalField(_("Amount paid"), validators=[Required()]) amount = CalculatorStringField(_("Amount paid"), validators=[Required()])
payed_for = SelectMultipleField(_("For whom?"), payed_for = SelectMultipleField(_("For whom?"),
validators=[Required()], coerce=int) validators=[Required()], coerce=int)
submit = SubmitField(_("Submit")) submit = SubmitField(_("Submit"))

View file

@ -1349,6 +1349,69 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette")) headers=self.get_auth("raclette"))
self.assertStatus(404, req) self.assertStatus(404, req)
def test_bills_with_calculation(self):
# create a project
self.api_create("raclette")
# add members
self.api_add_member("raclette", "alexis")
self.api_add_member("raclette", "fred")
# add a bill
req = self.client.post("/api/projects/raclette/bills", data={
'date': '2011-08-10',
'what': 'fromage',
'payer': "1",
'payed_for': ["1", "2"],
'amount': '((100 + 200.25) * 2 - 100) / 2',
}, headers=self.get_auth("raclette"))
# should return the id
self.assertStatus(201, req)
self.assertEqual(req.data.decode('utf-8'), "1\n")
# get this bill details
req = self.client.get("/api/projects/raclette/bills/1",
headers=self.get_auth("raclette"))
# compare with the added info
self.assertStatus(200, req)
expected = {
"what": "fromage",
"payer_id": 1,
"owers": [
{"activated": True, "id": 1, "name": "alexis", "weight": 1},
{"activated": True, "id": 2, "name": "fred", "weight": 1}],
"amount": 250.25,
"date": "2011-08-10",
"id": 1}
got = json.loads(req.data.decode('utf-8'))
self.assertEqual(
datetime.date.today(),
datetime.datetime.strptime(got["creation_date"], '%Y-%m-%d').date()
)
del got["creation_date"]
self.assertDictEqual(expected, got)
erroneous_amounts = [
"lambda ", # letters
"(20 + 2", # invalid expression
"20/0", # invalid calc
"9999**99999999999999999", # exponents
"2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2" # greater than 50 chars
]
for amount in erroneous_amounts:
req = self.client.post("/api/projects/raclette/bills", data={
'date': '2011-08-10',
'what': 'fromage',
'payer': "1",
'payed_for': ["1", "2"],
'amount': amount,
}, headers=self.get_auth("raclette"))
self.assertStatus(400, req)
def test_statistics(self): def test_statistics(self):
# create a project # create a project
self.api_create("raclette") self.api_create("raclette")