diff --git a/ihatemoney/forms.py b/ihatemoney/forms.py index 1bfb0fe2..17711989 100644 --- a/ihatemoney/forms.py +++ b/ihatemoney/forms.py @@ -121,6 +121,11 @@ class CalculatorStringField(StringField): class EditProjectForm(FlaskForm): name = StringField(_("Project name"), validators=[DataRequired()]) + current_password = PasswordField( + _("Current private code"), + description=_("Enter existing private code to edit project"), + validators=[DataRequired()], + ) # If empty -> don't change the password password = PasswordField( _("New private code"), @@ -154,6 +159,13 @@ class EditProjectForm(FlaskForm): for currency_name in self.currency_helper.get_currencies() ] + def validate_current_password(self, field): + project = Project.query.get(self.id.data) + if project is None: + raise ValidationError(_("Unknown error")) + if not check_password_hash(project.password, self.current_password.data): + raise ValidationError(_("Invalid private code.")) + @property def logging_preference(self): """Get the LoggingMode object corresponding to current form data.""" @@ -212,7 +224,9 @@ class ImportProjectForm(FlaskForm): class ProjectForm(EditProjectForm): id = StringField(_("Project identifier"), validators=[DataRequired()]) - # This field overrides the one from EditProjectForm + # Remove this field that is inherited from EditProjectForm + current_password = None + # This field overrides the one from EditProjectForm (to make it mandatory) password = PasswordField(_("Private code"), validators=[DataRequired()]) submit = SubmitField(_("Create the project")) diff --git a/ihatemoney/templates/forms.html b/ihatemoney/templates/forms.html index 26eb376a..e339268e 100644 --- a/ihatemoney/templates/forms.html +++ b/ihatemoney/templates/forms.html @@ -100,6 +100,7 @@ {{ input(form.default_currency) }} + {{ input(form.current_password) }}
diff --git a/ihatemoney/tests/api_test.py b/ihatemoney/tests/api_test.py index b5f11e3e..73d917da 100644 --- a/ihatemoney/tests/api_test.py +++ b/ihatemoney/tests/api_test.py @@ -131,7 +131,7 @@ class APITestCase(IhatemoneyTestCase): decoded_resp = json.loads(resp.data.decode("utf-8")) self.assertDictEqual(decoded_resp, expected) - # edit should work + # edit should fail if we don't provide the current private code resp = self.client.put( "/api/projects/raclette", data={ @@ -143,7 +143,36 @@ class APITestCase(IhatemoneyTestCase): }, headers=self.get_auth("raclette"), ) + self.assertEqual(400, resp.status_code) + # edit should fail if we provide the wrong private code + resp = self.client.put( + "/api/projects/raclette", + data={ + "contact_email": "yeah@notmyidea.org", + "default_currency": "XXX", + "current_password": "fromage aux patates", + "password": "raclette", + "name": "The raclette party", + "project_history": "y", + }, + headers=self.get_auth("raclette"), + ) + self.assertEqual(400, resp.status_code) + + # edit with the correct private code should work + resp = self.client.put( + "/api/projects/raclette", + data={ + "contact_email": "yeah@notmyidea.org", + "default_currency": "XXX", + "current_password": "raclette", + "password": "raclette", + "name": "The raclette party", + "project_history": "y", + }, + headers=self.get_auth("raclette"), + ) self.assertEqual(200, resp.status_code) resp = self.client.get( @@ -168,6 +197,7 @@ class APITestCase(IhatemoneyTestCase): data={ "contact_email": "yeah@notmyidea.org", "default_currency": "XXX", + "current_password": "raclette", "password": "tartiflette", "name": "The raclette party", }, @@ -213,9 +243,23 @@ class APITestCase(IhatemoneyTestCase): "/api/projects/raclette/token", headers={"Authorization": f"Basic {decoded_resp['token']}"}, ) - self.assertEqual(200, resp.status_code) + # We shouldn't be able to edit project without private code + resp = self.client.put( + "/api/projects/raclette", + data={ + "contact_email": "yeah@notmyidea.org", + "default_currency": "XXX", + "password": "tartiflette", + "name": "The raclette party", + }, + headers={"Authorization": f"Basic {decoded_resp['token']}"}, + ) + self.assertEqual(400, resp.status_code) + expected_resp = {"current_password": ["This field is required."]} + self.assertEqual(expected_resp, json.loads(resp.data.decode("utf-8"))) + def test_token_login(self): resp = self.api_create("raclette") # Get token @@ -719,6 +763,7 @@ class APITestCase(IhatemoneyTestCase): data={ "contact_email": "yeah@notmyidea.org", "default_currency": "XXX", + "current_password": "raclette", "password": "raclette", "name": "The raclette party", }, diff --git a/ihatemoney/tests/budget_test.py b/ihatemoney/tests/budget_test.py index bac56507..1b979223 100644 --- a/ihatemoney/tests/budget_test.py +++ b/ihatemoney/tests/budget_test.py @@ -181,6 +181,7 @@ class BudgetTestCase(IhatemoneyTestCase): data={ "name": "raclette", "contact_email": "zorglub@notmyidea.org", + "current_password": "raclette", "password": "didoudida", "default_currency": "XXX", }, @@ -922,10 +923,30 @@ class BudgetTestCase(IhatemoneyTestCase): "default_currency": "USD", } - resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=True) - self.assertEqual(resp.status_code, 200) + # It should fail if we don't provide the current password + resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=False) + self.assertIn("This field is required", resp.data.decode("utf-8")) project = self.get_project("raclette") + self.assertNotEqual(project.name, new_data["name"]) + self.assertNotEqual(project.contact_email, new_data["contact_email"]) + self.assertNotEqual(project.default_currency, new_data["default_currency"]) + self.assertFalse(check_password_hash(project.password, new_data["password"])) + # It should fail if we provide the wrong current password + new_data["current_password"] = "patates au fromage" + resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=False) + self.assertIn("Invalid private code", resp.data.decode("utf-8")) + project = self.get_project("raclette") + self.assertNotEqual(project.name, new_data["name"]) + self.assertNotEqual(project.contact_email, new_data["contact_email"]) + self.assertNotEqual(project.default_currency, new_data["default_currency"]) + self.assertFalse(check_password_hash(project.password, new_data["password"])) + + # It should work if we give the current private code + new_data["current_password"] = "raclette" + resp = self.client.post("/raclette/edit", data=new_data) + self.assertEqual(resp.status_code, 302) + project = self.get_project("raclette") self.assertEqual(project.name, new_data["name"]) self.assertEqual(project.contact_email, new_data["contact_email"]) self.assertEqual(project.default_currency, new_data["default_currency"]) @@ -934,7 +955,7 @@ class BudgetTestCase(IhatemoneyTestCase): # Editing a project with a wrong email address should fail new_data["contact_email"] = "wrong_email" - resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=True) + resp = self.client.post("/raclette/edit", data=new_data) self.assertIn("Invalid email address", resp.data.decode("utf-8")) def test_dashboard(self): @@ -2039,6 +2060,7 @@ class BudgetTestCase(IhatemoneyTestCase): data={ "name": "raclette", "contact_email": "zorglub@notmyidea.org", + "current_password": "raclette", "password": "didoudida", "default_currency": "XXX", }, diff --git a/ihatemoney/tests/history_test.py b/ihatemoney/tests/history_test.py index 1cc15ced..9b4ec33f 100644 --- a/ihatemoney/tests/history_test.py +++ b/ihatemoney/tests/history_test.py @@ -19,11 +19,12 @@ class HistoryTestCase(IhatemoneyTestCase): self.assertEqual(resp.data.decode("utf-8").count(" -- "), 1) self.assertNotIn("127.0.0.1", resp.data.decode("utf-8")) - def change_privacy_to(self, logging_preference): + def change_privacy_to(self, current_password, logging_preference): # Change only logging_preferences new_data = { "name": "demo", "contact_email": "demo@notmyidea.org", + "current_password": current_password, "password": "demo", "default_currency": "XXX", } @@ -76,6 +77,7 @@ class HistoryTestCase(IhatemoneyTestCase): new_data = { "name": "demo2", "contact_email": "demo2@notmyidea.org", + "current_password": "demo", "password": "123456", "project_history": "y", "default_currency": "USD", # Currency changed from default @@ -114,7 +116,7 @@ class HistoryTestCase(IhatemoneyTestCase): resp.data.decode("utf-8"), ) - self.change_privacy_to(LoggingMode.DISABLED) + self.change_privacy_to("demo", LoggingMode.DISABLED) resp = self.client.get("/demo/history") self.assertEqual(resp.status_code, 200) @@ -122,7 +124,7 @@ class HistoryTestCase(IhatemoneyTestCase): self.assertEqual(resp.data.decode("utf-8").count(" -- "), 2) self.assertNotIn("127.0.0.1", resp.data.decode("utf-8")) - self.change_privacy_to(LoggingMode.RECORD_IP) + self.change_privacy_to("demo", LoggingMode.RECORD_IP) resp = self.client.get("/demo/history") self.assertEqual(resp.status_code, 200) @@ -132,7 +134,7 @@ class HistoryTestCase(IhatemoneyTestCase): self.assertEqual(resp.data.decode("utf-8").count(" -- "), 2) self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 1) - self.change_privacy_to(LoggingMode.ENABLED) + self.change_privacy_to("demo", LoggingMode.ENABLED) resp = self.client.get("/demo/history") self.assertEqual(resp.status_code, 200) @@ -141,7 +143,7 @@ class HistoryTestCase(IhatemoneyTestCase): self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 2) def test_project_privacy_edit2(self): - self.change_privacy_to(LoggingMode.RECORD_IP) + self.change_privacy_to("demo", LoggingMode.RECORD_IP) resp = self.client.get("/demo/history") self.assertEqual(resp.status_code, 200) @@ -149,7 +151,7 @@ class HistoryTestCase(IhatemoneyTestCase): self.assertEqual(resp.data.decode("utf-8").count(" -- "), 1) self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 1) - self.change_privacy_to(LoggingMode.DISABLED) + self.change_privacy_to("demo", LoggingMode.DISABLED) resp = self.client.get("/demo/history") self.assertEqual(resp.status_code, 200) @@ -159,7 +161,7 @@ class HistoryTestCase(IhatemoneyTestCase): self.assertEqual(resp.data.decode("utf-8").count(" -- "), 1) self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 2) - self.change_privacy_to(LoggingMode.ENABLED) + self.change_privacy_to("demo", LoggingMode.ENABLED) resp = self.client.get("/demo/history") self.assertEqual(resp.status_code, 200) @@ -171,6 +173,7 @@ class HistoryTestCase(IhatemoneyTestCase): new_data = { "name": "demo2", "contact_email": "demo2@notmyidea.org", + "current_password": "demo", "password": "123456", "default_currency": "USD", } @@ -233,7 +236,7 @@ class HistoryTestCase(IhatemoneyTestCase): def test_disable_clear_no_new_records(self): # Disable logging - self.change_privacy_to(LoggingMode.DISABLED) + self.change_privacy_to("demo", LoggingMode.DISABLED) # Ensure we can't clear history with a GET or with a password-less POST resp = self.client.get("/demo/erase_history") @@ -276,13 +279,13 @@ class HistoryTestCase(IhatemoneyTestCase): def test_clear_ip_records(self): # Enable IP Recording - self.change_privacy_to(LoggingMode.RECORD_IP) + self.change_privacy_to("demo", LoggingMode.RECORD_IP) # Do lots of database operations to generate IP address entries self.do_misc_database_operations(LoggingMode.RECORD_IP) # Disable IP Recording - self.change_privacy_to(LoggingMode.ENABLED) + self.change_privacy_to("123456", LoggingMode.ENABLED) resp = self.client.get("/demo/history") self.assertEqual(resp.status_code, 200)