replacing eval with ast parse

This commit is contained in:
Ullauri 2018-12-31 00:31:14 -05:00
parent ca0f5a8bf1
commit a5819e9b48
3 changed files with 36 additions and 7 deletions

View file

@ -14,7 +14,7 @@ from jinja2 import Markup
import email_validator import email_validator
from ihatemoney.models import Project, Person from ihatemoney.models import Project, Person
from ihatemoney.utils import slugify from ihatemoney.utils import slugify, eval_arithmetic_expression
def get_billform_for(project, set_default=True, **kwargs): def get_billform_for(project, set_default=True, **kwargs):
@ -54,15 +54,12 @@ class CalculatorStringField(StringField):
def process_formdata(self, valuelist): def process_formdata(self, valuelist):
if valuelist: if valuelist:
error_msg = "Not a valid amount or expression" error_msg = "Not a valid amount or expression"
value = str(valuelist[0]).replace(" ", "").replace(",", ".") value = str(valuelist[0]).replace(",", ".")
if not match(r'^[ 0-9\.\+\-\*/\(\)]{0,50}$', value) or "**" in value: if not match(r'^[ 0-9\.\+\-\*/\(\)]{0,50}$', value) or "**" in value:
raise ValueError(error_msg) raise ValueError(error_msg)
try: valuelist[0] = str(eval_arithmetic_expression(value))
valuelist[0] = str(eval(value, {"__builtins__": None}, {}))
except (SyntaxError, NameError, TypeError, ZeroDivisionError):
raise ValueError(error_msg)
return super(CalculatorStringField, self).process_formdata(valuelist) return super(CalculatorStringField, self).process_formdata(valuelist)

View file

@ -1399,7 +1399,7 @@ class APITestCase(IhatemoneyTestCase):
"(20 + 2", # invalid expression "(20 + 2", # invalid expression
"20/0", # invalid calc "20/0", # invalid calc
"9999**99999999999999999", # exponents "9999**99999999999999999", # exponents
"2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2" # greater than 50 chars "2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2*2", # greater than 50 chars,
] ]
for amount in erroneous_amounts: for amount in erroneous_amounts:

View file

@ -1,5 +1,7 @@
import base64 import base64
import re import re
import ast
import operator
from io import BytesIO, StringIO from io import BytesIO, StringIO
import jinja2 import jinja2
@ -206,3 +208,33 @@ class IhmJSONEncoder(JSONEncoder):
except ImportError: except ImportError:
pass pass
return JSONEncoder.default(self, o) return JSONEncoder.default(self, o)
def eval_arithmetic_expression(expr):
def _eval(node):
# supported operators
operators = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.USub: operator.neg,
}
if isinstance(node, ast.Num): # <number>
return node.n
elif isinstance(node, ast.BinOp): # <left> <operator> <right>
return operators[type(node.op)](_eval(node.left), _eval(node.right))
elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1
return operators[type(node.op)](_eval(node.operand))
else:
raise TypeError(node)
expr = str(expr)
try:
result = _eval(ast.parse(expr, mode='eval').body)
except (SyntaxError, TypeError, ZeroDivisionError, KeyError):
raise ValueError("Error evaluating expression: {}".format(expr))
return result