diff --git a/ihatemoney/forms.py b/ihatemoney/forms.py
index 1bfb0fe2..17711989 100644
--- a/ihatemoney/forms.py
+++ b/ihatemoney/forms.py
@@ -121,6 +121,11 @@ class CalculatorStringField(StringField):
class EditProjectForm(FlaskForm):
name = StringField(_("Project name"), validators=[DataRequired()])
+ current_password = PasswordField(
+ _("Current private code"),
+ description=_("Enter existing private code to edit project"),
+ validators=[DataRequired()],
+ )
# If empty -> don't change the password
password = PasswordField(
_("New private code"),
@@ -154,6 +159,13 @@ class EditProjectForm(FlaskForm):
for currency_name in self.currency_helper.get_currencies()
]
+ def validate_current_password(self, field):
+ project = Project.query.get(self.id.data)
+ if project is None:
+ raise ValidationError(_("Unknown error"))
+ if not check_password_hash(project.password, self.current_password.data):
+ raise ValidationError(_("Invalid private code."))
+
@property
def logging_preference(self):
"""Get the LoggingMode object corresponding to current form data."""
@@ -212,7 +224,9 @@ class ImportProjectForm(FlaskForm):
class ProjectForm(EditProjectForm):
id = StringField(_("Project identifier"), validators=[DataRequired()])
- # This field overrides the one from EditProjectForm
+ # Remove this field that is inherited from EditProjectForm
+ current_password = None
+ # This field overrides the one from EditProjectForm (to make it mandatory)
password = PasswordField(_("Private code"), validators=[DataRequired()])
submit = SubmitField(_("Create the project"))
diff --git a/ihatemoney/templates/forms.html b/ihatemoney/templates/forms.html
index 26eb376a..e339268e 100644
--- a/ihatemoney/templates/forms.html
+++ b/ihatemoney/templates/forms.html
@@ -100,6 +100,7 @@
{{ input(form.default_currency) }}
+ {{ input(form.current_password) }}
diff --git a/ihatemoney/tests/api_test.py b/ihatemoney/tests/api_test.py
index b5f11e3e..73d917da 100644
--- a/ihatemoney/tests/api_test.py
+++ b/ihatemoney/tests/api_test.py
@@ -131,7 +131,7 @@ class APITestCase(IhatemoneyTestCase):
decoded_resp = json.loads(resp.data.decode("utf-8"))
self.assertDictEqual(decoded_resp, expected)
- # edit should work
+ # edit should fail if we don't provide the current private code
resp = self.client.put(
"/api/projects/raclette",
data={
@@ -143,7 +143,36 @@ class APITestCase(IhatemoneyTestCase):
},
headers=self.get_auth("raclette"),
)
+ self.assertEqual(400, resp.status_code)
+ # edit should fail if we provide the wrong private code
+ resp = self.client.put(
+ "/api/projects/raclette",
+ data={
+ "contact_email": "yeah@notmyidea.org",
+ "default_currency": "XXX",
+ "current_password": "fromage aux patates",
+ "password": "raclette",
+ "name": "The raclette party",
+ "project_history": "y",
+ },
+ headers=self.get_auth("raclette"),
+ )
+ self.assertEqual(400, resp.status_code)
+
+ # edit with the correct private code should work
+ resp = self.client.put(
+ "/api/projects/raclette",
+ data={
+ "contact_email": "yeah@notmyidea.org",
+ "default_currency": "XXX",
+ "current_password": "raclette",
+ "password": "raclette",
+ "name": "The raclette party",
+ "project_history": "y",
+ },
+ headers=self.get_auth("raclette"),
+ )
self.assertEqual(200, resp.status_code)
resp = self.client.get(
@@ -168,6 +197,7 @@ class APITestCase(IhatemoneyTestCase):
data={
"contact_email": "yeah@notmyidea.org",
"default_currency": "XXX",
+ "current_password": "raclette",
"password": "tartiflette",
"name": "The raclette party",
},
@@ -213,9 +243,23 @@ class APITestCase(IhatemoneyTestCase):
"/api/projects/raclette/token",
headers={"Authorization": f"Basic {decoded_resp['token']}"},
)
-
self.assertEqual(200, resp.status_code)
+ # We shouldn't be able to edit project without private code
+ resp = self.client.put(
+ "/api/projects/raclette",
+ data={
+ "contact_email": "yeah@notmyidea.org",
+ "default_currency": "XXX",
+ "password": "tartiflette",
+ "name": "The raclette party",
+ },
+ headers={"Authorization": f"Basic {decoded_resp['token']}"},
+ )
+ self.assertEqual(400, resp.status_code)
+ expected_resp = {"current_password": ["This field is required."]}
+ self.assertEqual(expected_resp, json.loads(resp.data.decode("utf-8")))
+
def test_token_login(self):
resp = self.api_create("raclette")
# Get token
@@ -719,6 +763,7 @@ class APITestCase(IhatemoneyTestCase):
data={
"contact_email": "yeah@notmyidea.org",
"default_currency": "XXX",
+ "current_password": "raclette",
"password": "raclette",
"name": "The raclette party",
},
diff --git a/ihatemoney/tests/budget_test.py b/ihatemoney/tests/budget_test.py
index bac56507..1b979223 100644
--- a/ihatemoney/tests/budget_test.py
+++ b/ihatemoney/tests/budget_test.py
@@ -181,6 +181,7 @@ class BudgetTestCase(IhatemoneyTestCase):
data={
"name": "raclette",
"contact_email": "zorglub@notmyidea.org",
+ "current_password": "raclette",
"password": "didoudida",
"default_currency": "XXX",
},
@@ -922,10 +923,30 @@ class BudgetTestCase(IhatemoneyTestCase):
"default_currency": "USD",
}
- resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=True)
- self.assertEqual(resp.status_code, 200)
+ # It should fail if we don't provide the current password
+ resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=False)
+ self.assertIn("This field is required", resp.data.decode("utf-8"))
project = self.get_project("raclette")
+ self.assertNotEqual(project.name, new_data["name"])
+ self.assertNotEqual(project.contact_email, new_data["contact_email"])
+ self.assertNotEqual(project.default_currency, new_data["default_currency"])
+ self.assertFalse(check_password_hash(project.password, new_data["password"]))
+ # It should fail if we provide the wrong current password
+ new_data["current_password"] = "patates au fromage"
+ resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=False)
+ self.assertIn("Invalid private code", resp.data.decode("utf-8"))
+ project = self.get_project("raclette")
+ self.assertNotEqual(project.name, new_data["name"])
+ self.assertNotEqual(project.contact_email, new_data["contact_email"])
+ self.assertNotEqual(project.default_currency, new_data["default_currency"])
+ self.assertFalse(check_password_hash(project.password, new_data["password"]))
+
+ # It should work if we give the current private code
+ new_data["current_password"] = "raclette"
+ resp = self.client.post("/raclette/edit", data=new_data)
+ self.assertEqual(resp.status_code, 302)
+ project = self.get_project("raclette")
self.assertEqual(project.name, new_data["name"])
self.assertEqual(project.contact_email, new_data["contact_email"])
self.assertEqual(project.default_currency, new_data["default_currency"])
@@ -934,7 +955,7 @@ class BudgetTestCase(IhatemoneyTestCase):
# Editing a project with a wrong email address should fail
new_data["contact_email"] = "wrong_email"
- resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=True)
+ resp = self.client.post("/raclette/edit", data=new_data)
self.assertIn("Invalid email address", resp.data.decode("utf-8"))
def test_dashboard(self):
@@ -2039,6 +2060,7 @@ class BudgetTestCase(IhatemoneyTestCase):
data={
"name": "raclette",
"contact_email": "zorglub@notmyidea.org",
+ "current_password": "raclette",
"password": "didoudida",
"default_currency": "XXX",
},
diff --git a/ihatemoney/tests/history_test.py b/ihatemoney/tests/history_test.py
index 1cc15ced..9b4ec33f 100644
--- a/ihatemoney/tests/history_test.py
+++ b/ihatemoney/tests/history_test.py
@@ -19,11 +19,12 @@ class HistoryTestCase(IhatemoneyTestCase):
self.assertEqual(resp.data.decode("utf-8").count("