tests: migrate to pytest

- replace setUp/tearDown with pytest fixtures
- rename test classes to use the pytest convention
- use pytest assertions

Co-authored-by: Glandos <bugs-github@antipoul.fr>
This commit is contained in:
Éloi Rivard 2023-08-12 13:09:28 +02:00 committed by zorun
parent 2ce76158d2
commit 21408f8bc9
9 changed files with 741 additions and 766 deletions

3
.gitignore vendored
View file

@ -16,4 +16,5 @@ ihatemoney/budget.db
.DS_Store .DS_Store
.idea .idea
.python-version .python-version
.coverage*
prof

View file

@ -1,13 +1,12 @@
import base64 import base64
import datetime import datetime
import json import json
import unittest
from ihatemoney.tests.common.help_functions import em_surround from ihatemoney.tests.common.help_functions import em_surround
from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase
class APITestCase(IhatemoneyTestCase): class TestAPI(IhatemoneyTestCase):
"""Tests the API""" """Tests the API"""
@ -57,7 +56,7 @@ class APITestCase(IhatemoneyTestCase):
resp = self.client.options( resp = self.client.options(
"/api/projects/raclette", headers=self.get_auth("raclette") "/api/projects/raclette", headers=self.get_auth("raclette")
) )
self.assertEqual(resp.headers["Access-Control-Allow-Origin"], "*") assert resp.headers["Access-Control-Allow-Origin"] == "*"
def test_basic_auth(self): def test_basic_auth(self):
# create a project # create a project
@ -94,32 +93,32 @@ class APITestCase(IhatemoneyTestCase):
}, },
) )
self.assertEqual(400, resp.status_code) assert 400 == resp.status_code
self.assertEqual( assert '{"contact_email": ["Invalid email address."]}\n' == resp.data.decode(
'{"contact_email": ["Invalid email address."]}\n', resp.data.decode("utf-8") "utf-8"
) )
# create it # create it
with self.app.mail.record_messages() as outbox: with self.app.mail.record_messages() as outbox:
resp = self.api_create("raclette") resp = self.api_create("raclette")
self.assertEqual(201, resp.status_code) assert 201 == resp.status_code
# Check that email messages have been sent. # Check that email messages have been sent.
self.assertEqual(len(outbox), 1) assert len(outbox) == 1
self.assertEqual(outbox[0].recipients, ["raclette@notmyidea.org"]) assert outbox[0].recipients == ["raclette@notmyidea.org"]
# create it twice should return a 400 # create it twice should return a 400
resp = self.api_create("raclette") resp = self.api_create("raclette")
self.assertEqual(400, resp.status_code) assert 400 == resp.status_code
self.assertIn("id", json.loads(resp.data.decode("utf-8"))) assert "id" in json.loads(resp.data.decode("utf-8"))
# get information about it # get information about it
resp = self.client.get( resp = self.client.get(
"/api/projects/raclette", headers=self.get_auth("raclette") "/api/projects/raclette", headers=self.get_auth("raclette")
) )
self.assertEqual(200, resp.status_code) assert 200 == resp.status_code
expected = { expected = {
"members": [], "members": [],
"name": "raclette", "name": "raclette",
@ -129,7 +128,7 @@ class APITestCase(IhatemoneyTestCase):
"logging_preference": 1, "logging_preference": 1,
} }
decoded_resp = json.loads(resp.data.decode("utf-8")) decoded_resp = json.loads(resp.data.decode("utf-8"))
self.assertDictEqual(decoded_resp, expected) assert decoded_resp == expected
# edit should fail if we don't provide the current private code # edit should fail if we don't provide the current private code
resp = self.client.put( resp = self.client.put(
@ -143,7 +142,7 @@ class APITestCase(IhatemoneyTestCase):
}, },
headers=self.get_auth("raclette"), headers=self.get_auth("raclette"),
) )
self.assertEqual(400, resp.status_code) assert 400 == resp.status_code
# edit should fail if we provide the wrong private code # edit should fail if we provide the wrong private code
resp = self.client.put( resp = self.client.put(
@ -158,7 +157,7 @@ class APITestCase(IhatemoneyTestCase):
}, },
headers=self.get_auth("raclette"), headers=self.get_auth("raclette"),
) )
self.assertEqual(400, resp.status_code) assert 400 == resp.status_code
# edit with the correct private code should work # edit with the correct private code should work
resp = self.client.put( resp = self.client.put(
@ -173,13 +172,13 @@ class APITestCase(IhatemoneyTestCase):
}, },
headers=self.get_auth("raclette"), headers=self.get_auth("raclette"),
) )
self.assertEqual(200, resp.status_code) assert 200 == resp.status_code
resp = self.client.get( resp = self.client.get(
"/api/projects/raclette", headers=self.get_auth("raclette") "/api/projects/raclette", headers=self.get_auth("raclette")
) )
self.assertEqual(200, resp.status_code) assert 200 == resp.status_code
expected = { expected = {
"name": "The raclette party", "name": "The raclette party",
"contact_email": "yeah@notmyidea.org", "contact_email": "yeah@notmyidea.org",
@ -189,7 +188,7 @@ class APITestCase(IhatemoneyTestCase):
"logging_preference": 1, "logging_preference": 1,
} }
decoded_resp = json.loads(resp.data.decode("utf-8")) decoded_resp = json.loads(resp.data.decode("utf-8"))
self.assertDictEqual(decoded_resp, expected) assert decoded_resp == expected
# password change is possible via API # password change is possible via API
resp = self.client.put( resp = self.client.put(
@ -204,12 +203,12 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"), headers=self.get_auth("raclette"),
) )
self.assertEqual(200, resp.status_code) assert 200 == resp.status_code
resp = self.client.get( resp = self.client.get(
"/api/projects/raclette", headers=self.get_auth("raclette", "tartiflette") "/api/projects/raclette", headers=self.get_auth("raclette", "tartiflette")
) )
self.assertEqual(200, resp.status_code) assert 200 == resp.status_code
# delete should work # delete should work
resp = self.client.delete( resp = self.client.delete(
@ -220,21 +219,21 @@ class APITestCase(IhatemoneyTestCase):
resp = self.client.get( resp = self.client.get(
"/api/projects/raclette", headers=self.get_auth("raclette") "/api/projects/raclette", headers=self.get_auth("raclette")
) )
self.assertEqual(401, resp.status_code) assert 401 == resp.status_code
def test_token_creation(self): def test_token_creation(self):
"""Test that token of project is generated""" """Test that token of project is generated"""
# Create project # Create project
resp = self.api_create("raclette") resp = self.api_create("raclette")
self.assertEqual(201, resp.status_code) assert 201 == resp.status_code
# Get token # Get token
resp = self.client.get( resp = self.client.get(
"/api/projects/raclette/token", headers=self.get_auth("raclette") "/api/projects/raclette/token", headers=self.get_auth("raclette")
) )
self.assertEqual(200, resp.status_code) assert 200 == resp.status_code
decoded_resp = json.loads(resp.data.decode("utf-8")) decoded_resp = json.loads(resp.data.decode("utf-8"))
@ -243,7 +242,7 @@ class APITestCase(IhatemoneyTestCase):
"/api/projects/raclette/token", "/api/projects/raclette/token",
headers={"Authorization": f"Basic {decoded_resp['token']}"}, headers={"Authorization": f"Basic {decoded_resp['token']}"},
) )
self.assertEqual(200, resp.status_code) assert 200 == resp.status_code
# We shouldn't be able to edit project without private code # We shouldn't be able to edit project without private code
resp = self.client.put( resp = self.client.put(
@ -256,9 +255,9 @@ class APITestCase(IhatemoneyTestCase):
}, },
headers={"Authorization": f"Basic {decoded_resp['token']}"}, headers={"Authorization": f"Basic {decoded_resp['token']}"},
) )
self.assertEqual(400, resp.status_code) assert 400 == resp.status_code
expected_resp = {"current_password": ["This field is required."]} expected_resp = {"current_password": ["This field is required."]}
self.assertEqual(expected_resp, json.loads(resp.data.decode("utf-8"))) assert expected_resp == json.loads(resp.data.decode("utf-8"))
def test_token_login(self): def test_token_login(self):
resp = self.api_create("raclette") resp = self.api_create("raclette")
@ -269,7 +268,7 @@ class APITestCase(IhatemoneyTestCase):
decoded_resp = json.loads(resp.data.decode("utf-8")) decoded_resp = json.loads(resp.data.decode("utf-8"))
resp = self.client.get(f"/raclette/join/{decoded_resp['token']}") resp = self.client.get(f"/raclette/join/{decoded_resp['token']}")
# Test that we are redirected. # Test that we are redirected.
self.assertEqual(302, resp.status_code) assert 302 == resp.status_code
def test_member(self): def test_member(self):
# create a project # create a project
@ -281,7 +280,7 @@ class APITestCase(IhatemoneyTestCase):
) )
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual("[]\n", req.data.decode("utf-8")) assert "[]\n" == req.data.decode("utf-8")
# add a member # add a member
req = self.client.post( req = self.client.post(
@ -292,7 +291,7 @@ class APITestCase(IhatemoneyTestCase):
# the id of the new member should be returned # the id of the new member should be returned
self.assertStatus(201, req) self.assertStatus(201, req)
self.assertEqual("1\n", req.data.decode("utf-8")) assert "1\n" == req.data.decode("utf-8")
# the list of participants should contain one member # the list of participants should contain one member
req = self.client.get( req = self.client.get(
@ -300,7 +299,7 @@ class APITestCase(IhatemoneyTestCase):
) )
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual(len(json.loads(req.data.decode("utf-8"))), 1) assert len(json.loads(req.data.decode("utf-8"))) == 1
# Try to add another member with the same name. # Try to add another member with the same name.
req = self.client.post( req = self.client.post(
@ -325,8 +324,8 @@ class APITestCase(IhatemoneyTestCase):
) )
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual("Fred", json.loads(req.data.decode("utf-8"))["name"]) assert "Fred" == json.loads(req.data.decode("utf-8"))["name"]
self.assertEqual(2, json.loads(req.data.decode("utf-8"))["weight"]) assert 2 == json.loads(req.data.decode("utf-8"))["weight"]
# edit this member with same information # edit this member with same information
# (test PUT idemopotence) # (test PUT idemopotence)
@ -350,7 +349,7 @@ class APITestCase(IhatemoneyTestCase):
"/api/projects/raclette/members/1", headers=self.get_auth("raclette") "/api/projects/raclette/members/1", headers=self.get_auth("raclette")
) )
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual(False, json.loads(req.data.decode("utf-8"))["activated"]) assert not json.loads(req.data.decode("utf-8"))["activated"]
# re-activate the participant # re-activate the participant
req = self.client.put( req = self.client.put(
@ -363,7 +362,7 @@ class APITestCase(IhatemoneyTestCase):
"/api/projects/raclette/members/1", headers=self.get_auth("raclette") "/api/projects/raclette/members/1", headers=self.get_auth("raclette")
) )
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual(True, json.loads(req.data.decode("utf-8"))["activated"]) assert json.loads(req.data.decode("utf-8"))["activated"]
# delete a member # delete a member
@ -379,7 +378,7 @@ class APITestCase(IhatemoneyTestCase):
) )
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual("[]\n", req.data.decode("utf-8")) assert "[]\n" == req.data.decode("utf-8")
def test_bills(self): def test_bills(self):
# create a project # create a project
@ -396,7 +395,7 @@ class APITestCase(IhatemoneyTestCase):
) )
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual("[]\n", req.data.decode("utf-8")) assert "[]\n" == req.data.decode("utf-8")
# add a bill # add a bill
req = self.client.post( req = self.client.post(
@ -414,7 +413,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id # should return the id
self.assertStatus(201, req) self.assertStatus(201, req)
self.assertEqual(req.data.decode("utf-8"), "1\n") assert req.data.decode("utf-8") == "1\n"
# get this bill details # get this bill details
req = self.client.get( req = self.client.get(
@ -439,19 +438,19 @@ class APITestCase(IhatemoneyTestCase):
} }
got = json.loads(req.data.decode("utf-8")) got = json.loads(req.data.decode("utf-8"))
self.assertEqual( assert (
datetime.date.today(), datetime.date.today()
datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(), == datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date()
) )
del got["creation_date"] del got["creation_date"]
self.assertDictEqual(expected, got) assert expected == got
# the list of bills should length 1 # the list of bills should length 1
req = self.client.get( req = self.client.get(
"/api/projects/raclette/bills", headers=self.get_auth("raclette") "/api/projects/raclette/bills", headers=self.get_auth("raclette")
) )
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual(1, len(json.loads(req.data.decode("utf-8")))) assert 1 == len(json.loads(req.data.decode("utf-8")))
# edit with errors should return an error # edit with errors should return an error
req = self.client.put( req = self.client.put(
@ -468,9 +467,7 @@ class APITestCase(IhatemoneyTestCase):
) )
self.assertStatus(400, req) self.assertStatus(400, req)
self.assertEqual( assert '{"date": ["This field is required."]}\n' == req.data.decode("utf-8")
'{"date": ["This field is required."]}\n', req.data.decode("utf-8")
)
# edit a bill # edit a bill
req = self.client.put( req = self.client.put(
@ -510,12 +507,12 @@ class APITestCase(IhatemoneyTestCase):
} }
got = json.loads(req.data.decode("utf-8")) got = json.loads(req.data.decode("utf-8"))
self.assertEqual( assert (
creation_date, creation_date
datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(), == datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date()
) )
del got["creation_date"] del got["creation_date"]
self.assertDictEqual(expected, got) assert expected == got
# delete a bill # delete a bill
req = self.client.delete( req = self.client.delete(
@ -562,7 +559,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id # should return the id
self.assertStatus(201, req) self.assertStatus(201, req)
self.assertEqual(req.data.decode("utf-8"), "{}\n".format(id)) assert req.data.decode("utf-8") == "{}\n".format(id)
# get this bill's details # get this bill's details
req = self.client.get( req = self.client.get(
@ -588,12 +585,12 @@ class APITestCase(IhatemoneyTestCase):
} }
got = json.loads(req.data.decode("utf-8")) got = json.loads(req.data.decode("utf-8"))
self.assertEqual( assert (
datetime.date.today(), datetime.date.today()
datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(), == datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date()
) )
del got["creation_date"] del got["creation_date"]
self.assertDictEqual(expected, got) assert expected == got
# should raise errors # should raise errors
erroneous_amounts = [ erroneous_amounts = [
@ -621,19 +618,19 @@ class APITestCase(IhatemoneyTestCase):
def test_currencies(self): def test_currencies(self):
# check /currencies for list of supported currencies # check /currencies for list of supported currencies
resp = self.client.get("/api/currencies") resp = self.client.get("/api/currencies")
self.assertEqual(200, resp.status_code) assert 200 == resp.status_code
self.assertIn("XXX", json.loads(resp.data.decode("utf-8"))) assert "XXX" in json.loads(resp.data.decode("utf-8"))
# create project with a default currency # create project with a default currency
resp = self.api_create("raclette", default_currency="EUR") resp = self.api_create("raclette", default_currency="EUR")
self.assertEqual(201, resp.status_code) assert 201 == resp.status_code
# get information about it # get information about it
resp = self.client.get( resp = self.client.get(
"/api/projects/raclette", headers=self.get_auth("raclette") "/api/projects/raclette", headers=self.get_auth("raclette")
) )
self.assertEqual(200, resp.status_code) assert 200 == resp.status_code
expected = { expected = {
"members": [], "members": [],
"name": "raclette", "name": "raclette",
@ -643,7 +640,7 @@ class APITestCase(IhatemoneyTestCase):
"logging_preference": 1, "logging_preference": 1,
} }
decoded_resp = json.loads(resp.data.decode("utf-8")) decoded_resp = json.loads(resp.data.decode("utf-8"))
self.assertDictEqual(decoded_resp, expected) assert decoded_resp == expected
# Add participants # Add participants
self.api_add_member("raclette", "zorglub") self.api_add_member("raclette", "zorglub")
@ -666,7 +663,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id # should return the id
self.assertStatus(201, req) self.assertStatus(201, req)
self.assertEqual(req.data.decode("utf-8"), "1\n") assert req.data.decode("utf-8") == "1\n"
# get this bill details # get this bill details
req = self.client.get( req = self.client.get(
@ -691,12 +688,12 @@ class APITestCase(IhatemoneyTestCase):
} }
got = json.loads(req.data.decode("utf-8")) got = json.loads(req.data.decode("utf-8"))
self.assertEqual( assert (
datetime.date.today(), datetime.date.today()
datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(), == datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date()
) )
del got["creation_date"] del got["creation_date"]
self.assertDictEqual(expected, got) assert expected == got
# Change bill amount and currency # Change bill amount and currency
req = self.client.put( req = self.client.put(
@ -737,7 +734,7 @@ class APITestCase(IhatemoneyTestCase):
got = json.loads(req.data.decode("utf-8")) got = json.loads(req.data.decode("utf-8"))
del got["creation_date"] del got["creation_date"]
self.assertDictEqual(expected, got) assert expected == got
# Add a bill with yet another currency # Add a bill with yet another currency
req = self.client.post( req = self.client.post(
@ -755,7 +752,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id # should return the id
self.assertStatus(201, req) self.assertStatus(201, req)
self.assertEqual(req.data.decode("utf-8"), "2\n") assert req.data.decode("utf-8") == "2\n"
# Try to remove default project currency, it should fail # Try to remove default project currency, it should fail
req = self.client.put( req = self.client.put(
@ -770,9 +767,9 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"), headers=self.get_auth("raclette"),
) )
self.assertStatus(400, req) self.assertStatus(400, req)
self.assertIn("This project cannot be set", req.data.decode("utf-8")) assert "This project cannot be set" in req.data.decode("utf-8")
self.assertIn( assert "because it contains bills in multiple currencies" in req.data.decode(
"because it contains bills in multiple currencies", req.data.decode("utf-8") "utf-8"
) )
def test_statistics(self): def test_statistics(self):
@ -801,33 +798,30 @@ class APITestCase(IhatemoneyTestCase):
"/api/projects/raclette/statistics", headers=self.get_auth("raclette") "/api/projects/raclette/statistics", headers=self.get_auth("raclette")
) )
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual( assert [
[ {
{ "balance": 12.5,
"balance": 12.5, "member": {
"member": { "activated": True,
"activated": True, "id": 1,
"id": 1, "name": "zorglub",
"name": "zorglub", "weight": 1.0,
"weight": 1.0,
},
"paid": 25.0,
"spent": 12.5,
}, },
{ "paid": 25.0,
"balance": -12.5, "spent": 12.5,
"member": { },
"activated": True, {
"id": 2, "balance": -12.5,
"name": "fred", "member": {
"weight": 1.0, "activated": True,
}, "id": 2,
"paid": 0, "name": "fred",
"spent": 12.5, "weight": 1.0,
}, },
], "paid": 0,
json.loads(req.data.decode("utf-8")), "spent": 12.5,
) },
] == json.loads(req.data.decode("utf-8"))
def test_username_xss(self): def test_username_xss(self):
# create a project # create a project
@ -839,7 +833,7 @@ class APITestCase(IhatemoneyTestCase):
self.api_add_member("raclette", "<script>") self.api_add_member("raclette", "<script>")
result = self.client.get("/raclette/") result = self.client.get("/raclette/")
self.assertNotIn("<script>", result.data.decode("utf-8")) assert "<script>" not in result.data.decode("utf-8")
def test_weighted_bills(self): def test_weighted_bills(self):
# create a project # create a project
@ -888,12 +882,12 @@ class APITestCase(IhatemoneyTestCase):
"original_currency": "XXX", "original_currency": "XXX",
} }
got = json.loads(req.data.decode("utf-8")) got = json.loads(req.data.decode("utf-8"))
self.assertEqual( assert (
creation_date, creation_date
datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(), == datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date()
) )
del got["creation_date"] del got["creation_date"]
self.assertDictEqual(expected, got) assert expected == got
# getting it should return a 404 # getting it should return a 404
req = self.client.get( req = self.client.get(
@ -933,7 +927,7 @@ class APITestCase(IhatemoneyTestCase):
self.assertStatus(200, req) self.assertStatus(200, req)
decoded_req = json.loads(req.data.decode("utf-8")) decoded_req = json.loads(req.data.decode("utf-8"))
self.assertDictEqual(decoded_req, expected) assert decoded_req == expected
def test_log_created_from_api_call(self): def test_log_created_from_api_call(self):
# create a project # create a project
@ -944,15 +938,13 @@ class APITestCase(IhatemoneyTestCase):
self.api_add_member("raclette", "zorglub") self.api_add_member("raclette", "zorglub")
resp = self.client.get("/raclette/history", follow_redirects=True) resp = self.client.get("/raclette/history", follow_redirects=True)
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn( assert f"Participant {em_surround('zorglub')} added" in resp.data.decode(
f"Participant {em_surround('zorglub')} added", resp.data.decode("utf-8") "utf-8"
) )
self.assertIn( assert f"Project {em_surround('raclette')} added" in resp.data.decode("utf-8")
f"Project {em_surround('raclette')} added", resp.data.decode("utf-8") assert resp.data.decode("utf-8").count("<td> -- </td>") == 2
) assert "127.0.0.1" not in resp.data.decode("utf-8")
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 2)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
def test_amount_is_null(self): def test_amount_is_null(self):
self.api_create("raclette") self.api_create("raclette")
@ -1000,7 +992,3 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"), headers=self.get_auth("raclette"),
) )
self.assertStatus(400, req) self.assertStatus(400, req)
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load diff

View file

@ -1,15 +1,13 @@
import os import os
from unittest.mock import MagicMock
from flask_testing import TestCase import pytest
from ihatemoney import models from ihatemoney import models
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.run import create_app, db
from ihatemoney.utils import generate_password_hash from ihatemoney.utils import generate_password_hash
class BaseTestCase(TestCase): @pytest.mark.usefixtures("client", "converter")
class BaseTestCase:
SECRET_KEY = "TEST SESSION" SECRET_KEY = "TEST SESSION"
SQLALCHEMY_DATABASE_URI = os.environ.get( SQLALCHEMY_DATABASE_URI = os.environ.get(
"TESTING_SQLALCHEMY_DATABASE_URI", "sqlite://" "TESTING_SQLALCHEMY_DATABASE_URI", "sqlite://"
@ -18,30 +16,6 @@ class BaseTestCase(TestCase):
PASSWORD_HASH_METHOD = "pbkdf2:sha1:1" PASSWORD_HASH_METHOD = "pbkdf2:sha1:1"
PASSWORD_HASH_SALT_LENGTH = 1 PASSWORD_HASH_SALT_LENGTH = 1
def create_app(self):
# Pass the test object as a configuration.
return create_app(self)
def setUp(self):
db.create_all()
# Add dummy data to CurrencyConverter for all tests (since it's a singleton)
mock_data = {
"USD": 1,
"EUR": 0.8,
"CAD": 1.2,
"PLN": 4,
CurrencyConverter.no_currency: 1,
}
converter = CurrencyConverter()
converter.get_rates = MagicMock(return_value=mock_data)
# Also add it to an attribute to make tests clearer
self.converter = converter
def tearDown(self):
# clean after testing
db.session.remove()
db.drop_all()
def login(self, project, password=None, test_client=None): def login(self, project, password=None, test_client=None):
password = password or project password = password or project
@ -83,7 +57,7 @@ class BaseTestCase(TestCase):
data=data, data=data,
# follow_redirects=True, # follow_redirects=True,
) )
self.assertEqual("/{id}/edit" in str(resp.response), not success) assert ("/{id}/edit" in str(resp.response)) == (not success)
def create_project(self, id, default_currency="XXX", name=None, password=None): def create_project(self, id, default_currency="XXX", name=None, password=None):
name = name or str(id) name = name or str(id)
@ -109,11 +83,9 @@ class IhatemoneyTestCase(BaseTestCase):
def assertStatus(self, expected, resp, url=None): def assertStatus(self, expected, resp, url=None):
if url is None: if url is None:
url = resp.request.path url = resp.request.path
return self.assertEqual( assert (
expected, expected == resp.status_code
resp.status_code, ), f"{url} expected {expected}, got {resp.status_code}"
f"{url} expected {expected}, got {resp.status_code}",
)
def enable_admin(self, password="adminpass"): def enable_admin(self, password="adminpass"):
self.app.config["ACTIVATE_ADMIN_DASHBOARD"] = True self.app.config["ACTIVATE_ADMIN_DASHBOARD"] = True

View file

@ -0,0 +1,49 @@
from unittest.mock import MagicMock
from flask import Flask
import pytest
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.run import create_app, db
@pytest.fixture
def app(request: pytest.FixtureRequest):
"""Create the Flask app with database"""
app = create_app(request.cls)
with app.app_context():
db.create_all()
request.cls.app = app
yield app
# clean after testing
db.session.remove()
db.drop_all()
@pytest.fixture
def client(app: Flask, request: pytest.FixtureRequest):
client = app.test_client()
request.cls.client = client
yield client
@pytest.fixture
def converter(request: pytest.FixtureRequest):
# Add dummy data to CurrencyConverter for all tests (since it's a singleton)
mock_data = {
"USD": 1,
"EUR": 0.8,
"CAD": 1.2,
"PLN": 4,
CurrencyConverter.no_currency: 1,
}
converter = CurrencyConverter()
converter.get_rates = MagicMock(return_value=mock_data)
# Also add it to an attribute to make tests clearer
request.cls.converter = converter
yield converter

View file

@ -1,4 +1,6 @@
import unittest import re
import pytest
from ihatemoney import history, models from ihatemoney import history, models
from ihatemoney.tests.common.help_functions import em_surround from ihatemoney.tests.common.help_functions import em_surround
@ -6,18 +8,33 @@ from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase
from ihatemoney.versioning import LoggingMode from ihatemoney.versioning import LoggingMode
class HistoryTestCase(IhatemoneyTestCase): @pytest.fixture
def setUp(self): def demo(client):
super().setUp() client.post(
self.post_project("demo") "/create",
self.login("demo") data={
"name": "demo",
"id": "demo",
"password": "demo",
"contact_email": "demo@notmyidea.org",
"default_currency": "XXX",
"project_history": True,
},
)
client.post(
"/authenticate",
data=dict(id="demo", password="demo"),
)
@pytest.mark.usefixtures("demo")
class TestHistory(IhatemoneyTestCase):
def test_simple_create_logentry_no_ip(self): def test_simple_create_logentry_no_ip(self):
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn(f"Project {em_surround('demo')} added", resp.data.decode("utf-8")) assert f"Project {em_surround('demo')} added" in resp.data.decode("utf-8")
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1) assert resp.data.decode("utf-8").count("<td> -- </td>") == 1
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8")) assert "127.0.0.1" not in resp.data.decode("utf-8")
def change_privacy_to(self, current_password, logging_preference): def change_privacy_to(self, current_password, logging_preference):
# Change only logging_preferences # Change only logging_preferences
@ -36,42 +53,38 @@ class HistoryTestCase(IhatemoneyTestCase):
# Disable History # Disable History
resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True) resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True)
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertNotIn("alert-danger", resp.data.decode("utf-8")) assert "alert-danger" not in resp.data.decode("utf-8")
resp = self.client.get("/demo/edit") resp = self.client.get("/demo/edit")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
if logging_preference == LoggingMode.DISABLED: if logging_preference == LoggingMode.DISABLED:
self.assertIn('<input id="project_history"', resp.data.decode("utf-8")) assert '<input id="project_history"' in resp.data.decode("utf-8")
else: else:
self.assertIn( assert '<input checked id="project_history"' in resp.data.decode("utf-8")
'<input checked id="project_history"', resp.data.decode("utf-8")
)
if logging_preference == LoggingMode.RECORD_IP: if logging_preference == LoggingMode.RECORD_IP:
self.assertIn('<input checked id="ip_recording"', resp.data.decode("utf-8")) assert '<input checked id="ip_recording"' in resp.data.decode("utf-8")
else: else:
self.assertIn('<input id="ip_recording"', resp.data.decode("utf-8")) assert '<input id="ip_recording"' in resp.data.decode("utf-8")
def assert_empty_history_logging_disabled(self): def assert_empty_history_logging_disabled(self):
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertIn( assert (
"This project has history disabled. New actions won't appear below.", "This project has history disabled. New actions won't appear below."
resp.data.decode("utf-8"), in resp.data.decode("utf-8")
) )
self.assertIn("Nothing to list", resp.data.decode("utf-8")) assert "Nothing to list" in resp.data.decode("utf-8")
self.assertNotIn( assert (
"The table below reflects actions recorded prior to disabling project history.", "The table below reflects actions recorded prior to disabling project history."
resp.data.decode("utf-8"), not in resp.data.decode("utf-8")
) )
self.assertNotIn( assert "Some entries below contain IP addresses," not in resp.data.decode(
"Some entries below contain IP addresses,", resp.data.decode("utf-8") "utf-8"
)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
self.assertNotIn("<td> -- </td>", resp.data.decode("utf-8"))
self.assertNotIn(
f"Project {em_surround('demo')} added", resp.data.decode("utf-8")
) )
assert "127.0.0.1" not in resp.data.decode("utf-8")
assert "<td> -- </td>" not in resp.data.decode("utf-8")
assert f"Project {em_surround('demo')} added" not in resp.data.decode("utf-8")
def test_project_edit(self): def test_project_edit(self):
new_data = { new_data = {
@ -84,90 +97,86 @@ class HistoryTestCase(IhatemoneyTestCase):
} }
resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True) resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True)
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn(f"Project {em_surround('demo')} added", resp.data.decode("utf-8")) assert f"Project {em_surround('demo')} added" in resp.data.decode("utf-8")
self.assertIn( assert (
f"Project contact email changed to {em_surround('demo2@notmyidea.org')}", f"Project contact email changed to {em_surround('demo2@notmyidea.org')}"
resp.data.decode("utf-8"), in resp.data.decode("utf-8")
) )
self.assertIn("Project private code changed", resp.data.decode("utf-8")) assert "Project private code changed" in resp.data.decode("utf-8")
self.assertIn( assert f"Project renamed to {em_surround('demo2')}" in resp.data.decode("utf-8")
f"Project renamed to {em_surround('demo2')}", resp.data.decode("utf-8") assert resp.data.decode("utf-8").index("Project renamed ") < resp.data.decode(
) "utf-8"
self.assertLess( ).index("Project contact email changed to ")
resp.data.decode("utf-8").index("Project renamed "), assert resp.data.decode("utf-8").index("Project renamed ") < resp.data.decode(
resp.data.decode("utf-8").index("Project contact email changed to "), "utf-8"
) ).index("Project private code changed")
self.assertLess( assert resp.data.decode("utf-8").count("<td> -- </td>") == 5
resp.data.decode("utf-8").index("Project renamed "), assert "127.0.0.1" not in resp.data.decode("utf-8")
resp.data.decode("utf-8").index("Project private code changed"),
)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 5)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
def test_project_privacy_edit(self): def test_project_privacy_edit(self):
resp = self.client.get("/demo/edit") resp = self.client.get("/demo/edit")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn( assert (
'<input checked id="project_history" name="project_history" type="checkbox" value="y">', '<input checked id="project_history" name="project_history" type="checkbox" value="y">'
resp.data.decode("utf-8"), in resp.data.decode("utf-8")
) )
self.change_privacy_to("demo", LoggingMode.DISABLED) self.change_privacy_to("demo", LoggingMode.DISABLED)
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn("Disabled Project History\n", resp.data.decode("utf-8")) assert "Disabled Project History\n" in resp.data.decode("utf-8")
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 2) assert resp.data.decode("utf-8").count("<td> -- </td>") == 2
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8")) assert "127.0.0.1" not in resp.data.decode("utf-8")
self.change_privacy_to("demo", LoggingMode.RECORD_IP) self.change_privacy_to("demo", LoggingMode.RECORD_IP)
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn( assert "Enabled Project History & IP Address Recording" in resp.data.decode(
"Enabled Project History & IP Address Recording", resp.data.decode("utf-8") "utf-8"
) )
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 2) assert resp.data.decode("utf-8").count("<td> -- </td>") == 2
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 1) assert resp.data.decode("utf-8").count("127.0.0.1") == 1
self.change_privacy_to("demo", LoggingMode.ENABLED) self.change_privacy_to("demo", LoggingMode.ENABLED)
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn("Disabled IP Address Recording\n", resp.data.decode("utf-8")) assert "Disabled IP Address Recording\n" in resp.data.decode("utf-8")
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 2) assert resp.data.decode("utf-8").count("<td> -- </td>") == 2
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 2) assert resp.data.decode("utf-8").count("127.0.0.1") == 2
def test_project_privacy_edit2(self): def test_project_privacy_edit2(self):
self.change_privacy_to("demo", LoggingMode.RECORD_IP) self.change_privacy_to("demo", LoggingMode.RECORD_IP)
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn("Enabled IP Address Recording\n", resp.data.decode("utf-8")) assert "Enabled IP Address Recording\n" in resp.data.decode("utf-8")
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1) assert resp.data.decode("utf-8").count("<td> -- </td>") == 1
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 1) assert resp.data.decode("utf-8").count("127.0.0.1") == 1
self.change_privacy_to("demo", LoggingMode.DISABLED) self.change_privacy_to("demo", LoggingMode.DISABLED)
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn( assert "Disabled Project History & IP Address Recording" in resp.data.decode(
"Disabled Project History & IP Address Recording", resp.data.decode("utf-8") "utf-8"
) )
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1) assert resp.data.decode("utf-8").count("<td> -- </td>") == 1
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 2) assert resp.data.decode("utf-8").count("127.0.0.1") == 2
self.change_privacy_to("demo", LoggingMode.ENABLED) self.change_privacy_to("demo", LoggingMode.ENABLED)
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn("Enabled Project History\n", resp.data.decode("utf-8")) assert "Enabled Project History\n" in resp.data.decode("utf-8")
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 2) assert resp.data.decode("utf-8").count("<td> -- </td>") == 2
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 2) assert resp.data.decode("utf-8").count("127.0.0.1") == 2
def do_misc_database_operations(self, logging_mode): def do_misc_database_operations(self, logging_mode):
new_data = { new_data = {
@ -185,13 +194,13 @@ class HistoryTestCase(IhatemoneyTestCase):
new_data["ip_recording"] = "y" new_data["ip_recording"] = "y"
resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True) resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True)
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
# adds a member to this project # adds a member to this project
resp = self.client.post( resp = self.client.post(
"/demo/members/add", data={"name": "zorglub"}, follow_redirects=True "/demo/members/add", data={"name": "zorglub"}, follow_redirects=True
) )
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
user_id = models.Person.query.one().id user_id = models.Person.query.one().id
@ -207,7 +216,7 @@ class HistoryTestCase(IhatemoneyTestCase):
}, },
follow_redirects=True, follow_redirects=True,
) )
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
bill_id = models.Bill.query.one().id bill_id = models.Bill.query.one().id
@ -223,16 +232,16 @@ class HistoryTestCase(IhatemoneyTestCase):
}, },
follow_redirects=True, follow_redirects=True,
) )
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
# delete the bill # delete the bill
resp = self.client.post(f"/demo/delete/{bill_id}", follow_redirects=True) resp = self.client.post(f"/demo/delete/{bill_id}", follow_redirects=True)
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
# delete user using POST method # delete user using POST method
resp = self.client.post( resp = self.client.post(
f"/demo/members/{user_id}/delete", follow_redirects=True f"/demo/members/{user_id}/delete", follow_redirects=True
) )
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
def test_disable_clear_no_new_records(self): def test_disable_clear_no_new_records(self):
# Disable logging # Disable logging
@ -240,27 +249,24 @@ class HistoryTestCase(IhatemoneyTestCase):
# Ensure we can't clear history with a GET or with a password-less POST # Ensure we can't clear history with a GET or with a password-less POST
resp = self.client.get("/demo/erase_history") resp = self.client.get("/demo/erase_history")
self.assertEqual(resp.status_code, 405) assert resp.status_code == 405
resp = self.client.post("/demo/erase_history", follow_redirects=True) resp = self.client.post("/demo/erase_history", follow_redirects=True)
self.assertIn( assert "Error deleting project history" in resp.data.decode("utf-8")
"Error deleting project history",
resp.data.decode("utf-8"),
)
# List history # List history
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn( assert (
"This project has history disabled. New actions won't appear below.", "This project has history disabled. New actions won't appear below."
resp.data.decode("utf-8"), in resp.data.decode("utf-8")
) )
self.assertIn( assert (
"The table below reflects actions recorded prior to disabling project history.", "The table below reflects actions recorded prior to disabling project history."
resp.data.decode("utf-8"), in resp.data.decode("utf-8")
) )
self.assertNotIn("Nothing to list", resp.data.decode("utf-8")) assert "Nothing to list" not in resp.data.decode("utf-8")
self.assertNotIn( assert "Some entries below contain IP addresses," not in resp.data.decode(
"Some entries below contain IP addresses,", resp.data.decode("utf-8") "utf-8"
) )
# Clear Existing Entries # Clear Existing Entries
@ -269,7 +275,7 @@ class HistoryTestCase(IhatemoneyTestCase):
data={"password": "demo"}, data={"password": "demo"},
follow_redirects=True, follow_redirects=True,
) )
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assert_empty_history_logging_disabled() self.assert_empty_history_logging_disabled()
# Do lots of database operations & check that there's still no history # Do lots of database operations & check that there's still no history
@ -288,43 +294,38 @@ class HistoryTestCase(IhatemoneyTestCase):
self.change_privacy_to("123456", LoggingMode.ENABLED) self.change_privacy_to("123456", LoggingMode.ENABLED)
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertNotIn( assert (
"This project has history disabled. New actions won't appear below.", "This project has history disabled. New actions won't appear below."
resp.data.decode("utf-8"), not in resp.data.decode("utf-8")
) )
self.assertNotIn( assert (
"The table below reflects actions recorded prior to disabling project history.", "The table below reflects actions recorded prior to disabling project history."
resp.data.decode("utf-8"), not in resp.data.decode("utf-8")
) )
self.assertNotIn("Nothing to list", resp.data.decode("utf-8")) assert "Nothing to list" not in resp.data.decode("utf-8")
self.assertIn( assert "Some entries below contain IP addresses," in resp.data.decode("utf-8")
"Some entries below contain IP addresses,", resp.data.decode("utf-8") assert resp.data.decode("utf-8").count("127.0.0.1") == 12
) assert resp.data.decode("utf-8").count("<td> -- </td>") == 1
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1)
# Generate more operations to confirm additional IP info isn't recorded # Generate more operations to confirm additional IP info isn't recorded
self.do_misc_database_operations(LoggingMode.ENABLED) self.do_misc_database_operations(LoggingMode.ENABLED)
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12) assert resp.data.decode("utf-8").count("127.0.0.1") == 12
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 7) assert resp.data.decode("utf-8").count("<td> -- </td>") == 7
# Ensure we can't clear IP data with a GET or with a password-less POST # Ensure we can't clear IP data with a GET or with a password-less POST
resp = self.client.get("/demo/strip_ip_addresses") resp = self.client.get("/demo/strip_ip_addresses")
self.assertEqual(resp.status_code, 405) assert resp.status_code == 405
resp = self.client.post("/demo/strip_ip_addresses", follow_redirects=True) resp = self.client.post("/demo/strip_ip_addresses", follow_redirects=True)
self.assertIn( assert "Error deleting recorded IP addresses" in resp.data.decode("utf-8")
"Error deleting recorded IP addresses",
resp.data.decode("utf-8"),
)
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12) assert resp.data.decode("utf-8").count("127.0.0.1") == 12
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 7) assert resp.data.decode("utf-8").count("<td> -- </td>") == 7
# Clear IP Data # Clear IP Data
resp = self.client.post( resp = self.client.post(
@ -333,33 +334,33 @@ class HistoryTestCase(IhatemoneyTestCase):
follow_redirects=True, follow_redirects=True,
) )
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertNotIn( assert (
"This project has history disabled. New actions won't appear below.", "This project has history disabled. New actions won't appear below."
resp.data.decode("utf-8"), not in resp.data.decode("utf-8")
) )
self.assertNotIn( assert (
"The table below reflects actions recorded prior to disabling project history.", "The table below reflects actions recorded prior to disabling project history."
resp.data.decode("utf-8"), not in resp.data.decode("utf-8")
) )
self.assertNotIn("Nothing to list", resp.data.decode("utf-8")) assert "Nothing to list" not in resp.data.decode("utf-8")
self.assertNotIn( assert "Some entries below contain IP addresses," not in resp.data.decode(
"Some entries below contain IP addresses,", resp.data.decode("utf-8") "utf-8"
) )
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 0) assert resp.data.decode("utf-8").count("127.0.0.1") == 0
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 19) assert resp.data.decode("utf-8").count("<td> -- </td>") == 19
def test_logs_for_common_actions(self): def test_logs_for_common_actions(self):
# adds a member to this project # adds a member to this project
resp = self.client.post( resp = self.client.post(
"/demo/members/add", data={"name": "zorglub"}, follow_redirects=True "/demo/members/add", data={"name": "zorglub"}, follow_redirects=True
) )
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn( assert f"Participant {em_surround('zorglub')} added" in resp.data.decode(
f"Participant {em_surround('zorglub')} added", resp.data.decode("utf-8") "utf-8"
) )
# create a bill # create a bill
@ -374,13 +375,12 @@ class HistoryTestCase(IhatemoneyTestCase):
}, },
follow_redirects=True, follow_redirects=True,
) )
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn( assert f"Bill {em_surround('fromage à raclette')} added" in resp.data.decode(
f"Bill {em_surround('fromage à raclette')} added", "utf-8"
resp.data.decode("utf-8"),
) )
# edit the bill # edit the bill
@ -395,44 +395,37 @@ class HistoryTestCase(IhatemoneyTestCase):
}, },
follow_redirects=True, follow_redirects=True,
) )
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn( assert f"Bill {em_surround('fromage à raclette')} added" in resp.data.decode(
f"Bill {em_surround('fromage à raclette')} added", "utf-8"
resp.data.decode("utf-8"),
) )
self.assertRegex( assert re.search(
resp.data.decode("utf-8"),
r"Bill %s:\s* Amount changed\s* from %s\s* to %s" r"Bill %s:\s* Amount changed\s* from %s\s* to %s"
% ( % (
em_surround("fromage à raclette", regex_escape=True), em_surround("fromage à raclette", regex_escape=True),
em_surround("25.0", regex_escape=True), em_surround("25.0", regex_escape=True),
em_surround("10.0", regex_escape=True), em_surround("10.0", regex_escape=True),
), ),
)
self.assertIn(
"Bill %s renamed to %s"
% (em_surround("fromage à raclette"), em_surround("new thing")),
resp.data.decode("utf-8"), resp.data.decode("utf-8"),
) )
self.assertLess( assert "Bill %s renamed to %s" % (
resp.data.decode("utf-8").index( em_surround("fromage à raclette"),
f"Bill {em_surround('fromage à raclette')} renamed to" em_surround("new thing"),
), ) in resp.data.decode("utf-8")
resp.data.decode("utf-8").index("Amount changed"), assert resp.data.decode("utf-8").index(
) f"Bill {em_surround('fromage à raclette')} renamed to"
) < resp.data.decode("utf-8").index("Amount changed")
# delete the bill # delete the bill
resp = self.client.post("/demo/delete/1", follow_redirects=True) resp = self.client.post("/demo/delete/1", follow_redirects=True)
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn( assert f"Bill {em_surround('new thing')} removed" in resp.data.decode("utf-8")
f"Bill {em_surround('new thing')} removed", resp.data.decode("utf-8")
)
# edit user # edit user
resp = self.client.post( resp = self.client.post(
@ -440,39 +433,35 @@ class HistoryTestCase(IhatemoneyTestCase):
data={"weight": 2, "name": "new name"}, data={"weight": 2, "name": "new name"},
follow_redirects=True, follow_redirects=True,
) )
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertRegex( assert re.search(
resp.data.decode("utf-8"),
r"Participant %s:\s* weight changed\s* from %s\s* to %s" r"Participant %s:\s* weight changed\s* from %s\s* to %s"
% ( % (
em_surround("zorglub", regex_escape=True), em_surround("zorglub", regex_escape=True),
em_surround("1.0", regex_escape=True), em_surround("1.0", regex_escape=True),
em_surround("2.0", regex_escape=True), em_surround("2.0", regex_escape=True),
), ),
)
self.assertIn(
"Participant %s renamed to %s"
% (em_surround("zorglub"), em_surround("new name")),
resp.data.decode("utf-8"), resp.data.decode("utf-8"),
) )
self.assertLess( assert "Participant %s renamed to %s" % (
resp.data.decode("utf-8").index( em_surround("zorglub"),
f"Participant {em_surround('zorglub')} renamed" em_surround("new name"),
), ) in resp.data.decode("utf-8")
resp.data.decode("utf-8").index("weight changed"), assert resp.data.decode("utf-8").index(
) f"Participant {em_surround('zorglub')} renamed"
) < resp.data.decode("utf-8").index("weight changed")
# delete user using POST method # delete user using POST method
resp = self.client.post("/demo/members/1/delete", follow_redirects=True) resp = self.client.post("/demo/members/1/delete", follow_redirects=True)
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertIn( assert f"Participant {em_surround('new name')} removed" in resp.data.decode(
f"Participant {em_surround('new name')} removed", resp.data.decode("utf-8") "utf-8"
) )
def test_double_bill_double_person_edit_second(self): def test_double_bill_double_person_edit_second(self):
@ -504,9 +493,9 @@ class HistoryTestCase(IhatemoneyTestCase):
# Should be 5 history entries at this point # Should be 5 history entries at this point
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 5) assert resp.data.decode("utf-8").count("<td> -- </td>") == 5
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8")) assert "127.0.0.1" not in resp.data.decode("utf-8")
# Edit ONLY the amount on the first bill # Edit ONLY the amount on the first bill
self.client.post( self.client.post(
@ -521,28 +510,27 @@ class HistoryTestCase(IhatemoneyTestCase):
) )
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertRegex( assert re.search(
resp.data.decode("utf-8"),
r"Bill {}:\s* Amount changed\s* from {}\s* to {}".format( r"Bill {}:\s* Amount changed\s* from {}\s* to {}".format(
em_surround("Bill 1", regex_escape=True), em_surround("Bill 1", regex_escape=True),
em_surround("25.0", regex_escape=True), em_surround("25.0", regex_escape=True),
em_surround("88.0", regex_escape=True), em_surround("88.0", regex_escape=True),
), ),
resp.data.decode("utf-8"),
) )
self.assertNotRegex( assert not re.search(
resp.data.decode("utf-8"),
r"Removed\s* {}\s* and\s* {}\s* from\s* owers list".format( r"Removed\s* {}\s* and\s* {}\s* from\s* owers list".format(
em_surround("User 1", regex_escape=True), em_surround("User 1", regex_escape=True),
em_surround("User 2", regex_escape=True), em_surround("User 2", regex_escape=True),
), ),
resp.data.decode("utf-8"), resp.data.decode("utf-8"),
) ), resp.data.decode("utf-8")
# Should be 6 history entries at this point # Should be 6 history entries at this point
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 6) assert resp.data.decode("utf-8").count("<td> -- </td>") == 6
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8")) assert "127.0.0.1" not in resp.data.decode("utf-8")
def test_bill_add_remove_add(self): def test_bill_add_remove_add(self):
# add two members # add two members
@ -565,13 +553,11 @@ class HistoryTestCase(IhatemoneyTestCase):
self.client.post("/demo/delete/1", follow_redirects=True) self.client.post("/demo/delete/1", follow_redirects=True)
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 5) assert resp.data.decode("utf-8").count("<td> -- </td>") == 5
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8")) assert "127.0.0.1" not in resp.data.decode("utf-8")
self.assertIn(f"Bill {em_surround('Bill 1')} added", resp.data.decode("utf-8")) assert f"Bill {em_surround('Bill 1')} added" in resp.data.decode("utf-8")
self.assertIn( assert f"Bill {em_surround('Bill 1')} removed" in resp.data.decode("utf-8")
f"Bill {em_surround('Bill 1')} removed", resp.data.decode("utf-8")
)
# Add a new bill # Add a new bill
self.client.post( self.client.post(
@ -586,17 +572,15 @@ class HistoryTestCase(IhatemoneyTestCase):
) )
resp = self.client.get("/demo/history") resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200) assert resp.status_code == 200
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 6) assert resp.data.decode("utf-8").count("<td> -- </td>") == 6
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8")) assert "127.0.0.1" not in resp.data.decode("utf-8")
self.assertIn(f"Bill {em_surround('Bill 1')} added", resp.data.decode("utf-8")) assert f"Bill {em_surround('Bill 1')} added" in resp.data.decode("utf-8")
self.assertEqual( assert (
resp.data.decode("utf-8").count(f"Bill {em_surround('Bill 1')} added"), 1 resp.data.decode("utf-8").count(f"Bill {em_surround('Bill 1')} added") == 1
)
self.assertIn(f"Bill {em_surround('Bill 2')} added", resp.data.decode("utf-8"))
self.assertIn(
f"Bill {em_surround('Bill 1')} removed", resp.data.decode("utf-8")
) )
assert f"Bill {em_surround('Bill 2')} added" in resp.data.decode("utf-8")
assert f"Bill {em_surround('Bill 1')} removed" in resp.data.decode("utf-8")
def test_double_bill_double_person_edit_second_no_web(self): def test_double_bill_double_person_edit_second_no_web(self):
u1 = models.Person(project_id="demo", name="User 1") u1 = models.Person(project_id="demo", name="User 1")
@ -617,7 +601,7 @@ class HistoryTestCase(IhatemoneyTestCase):
models.db.session.commit() models.db.session.commit()
history_list = history.get_history(self.get_project("demo")) history_list = history.get_history(self.get_project("demo"))
self.assertEqual(len(history_list), 5) assert len(history_list) == 5
# Change just the amount # Change just the amount
b1.amount = 5 b1.amount = 5
@ -626,8 +610,8 @@ class HistoryTestCase(IhatemoneyTestCase):
history_list = history.get_history(self.get_project("demo")) history_list = history.get_history(self.get_project("demo"))
for entry in history_list: for entry in history_list:
if "prop_changed" in entry: if "prop_changed" in entry:
self.assertNotIn("owers", entry["prop_changed"]) assert "owers" not in entry["prop_changed"]
self.assertEqual(len(history_list), 6) assert len(history_list) == 6
def test_delete_history_with_project(self): def test_delete_history_with_project(self):
self.post_project("raclette", password="party") self.post_project("raclette", password="party")
@ -659,8 +643,4 @@ class HistoryTestCase(IhatemoneyTestCase):
# History should be equal to project creation # History should be equal to project creation
history_list = history.get_history(self.get_project("raclette")) history_list = history.get_history(self.get_project("raclette"))
self.assertEqual(len(history_list), 1) assert len(history_list) == 1
if __name__ == "__main__":
unittest.main()

View file

@ -1,12 +1,46 @@
import copy import copy
import json import json
import unittest
import pytest
from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase
from ihatemoney.utils import list_of_dicts2csv, list_of_dicts2json from ihatemoney.utils import list_of_dicts2csv, list_of_dicts2json
@pytest.fixture
def import_data(request: pytest.FixtureRequest):
data = [
{
"date": "2017-01-01",
"what": "refund",
"amount": 13.33,
"payer_name": "tata",
"payer_weight": 1.0,
"owers": ["fred"],
},
{
"date": "2016-12-31",
"what": "red wine",
"amount": 200.0,
"payer_name": "fred",
"payer_weight": 1.0,
"owers": ["zorglub", "tata"],
},
{
"date": "2016-12-31",
"what": "fromage a raclette",
"amount": 10.0,
"payer_name": "zorglub",
"payer_weight": 2.0,
"owers": ["zorglub", "fred", "tata", "pepe"],
},
]
request.cls.data = data
yield data
class CommonTestCase(object): class CommonTestCase(object):
@pytest.mark.usefixtures("import_data")
class Import(IhatemoneyTestCase): class Import(IhatemoneyTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
@ -55,7 +89,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills() bills = project.get_pretty_bills()
# Check if all bills have been added # Check if all bills have been added
self.assertEqual(len(bills), len(self.data)) assert len(bills) == len(self.data)
# Check if name of bills are ok # Check if name of bills are ok
b = [e["what"] for e in bills] b = [e["what"] for e in bills]
@ -63,22 +97,22 @@ class CommonTestCase(object):
ref = [e["what"] for e in self.data] ref = [e["what"] for e in self.data]
ref.sort() ref.sort()
self.assertEqual(b, ref) assert b == ref
# Check if other informations in bill are ok # Check if other informations in bill are ok
for d in self.data: for d in self.data:
for b in bills: for b in bills:
if b["what"] == d["what"]: if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"]) assert b["payer_name"] == d["payer_name"]
self.assertEqual(b["amount"], d["amount"]) assert b["amount"] == d["amount"]
self.assertEqual(b["currency"], d["currency"]) assert b["currency"] == d["currency"]
self.assertEqual(b["payer_weight"], d["payer_weight"]) assert b["payer_weight"] == d["payer_weight"]
self.assertEqual(b["date"], d["date"]) assert b["date"] == d["date"]
list_project = [ower for ower in b["owers"]] list_project = [ower for ower in b["owers"]]
list_project.sort() list_project.sort()
list_json = [ower for ower in d["owers"]] list_json = [ower for ower in d["owers"]]
list_json.sort() list_json.sort()
self.assertEqual(list_project, list_json) assert list_project == list_json
def test_import_single_currency_in_empty_project_without_currency(self): def test_import_single_currency_in_empty_project_without_currency(self):
# Import JSON with a single currency in an empty project with no # Import JSON with a single currency in an empty project with no
@ -96,7 +130,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills() bills = project.get_pretty_bills()
# Check if all bills have been added # Check if all bills have been added
self.assertEqual(len(bills), len(self.data)) assert len(bills) == len(self.data)
# Check if name of bills are ok # Check if name of bills are ok
b = [e["what"] for e in bills] b = [e["what"] for e in bills]
@ -104,23 +138,23 @@ class CommonTestCase(object):
ref = [e["what"] for e in self.data] ref = [e["what"] for e in self.data]
ref.sort() ref.sort()
self.assertEqual(b, ref) assert b == ref
# Check if other informations in bill are ok # Check if other informations in bill are ok
for d in self.data: for d in self.data:
for b in bills: for b in bills:
if b["what"] == d["what"]: if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"]) assert b["payer_name"] == d["payer_name"]
self.assertEqual(b["amount"], d["amount"]) assert b["amount"] == d["amount"]
# Currency should have been stripped # Currency should have been stripped
self.assertEqual(b["currency"], "XXX") assert b["currency"] == "XXX"
self.assertEqual(b["payer_weight"], d["payer_weight"]) assert b["payer_weight"] == d["payer_weight"]
self.assertEqual(b["date"], d["date"]) assert b["date"] == d["date"]
list_project = [ower for ower in b["owers"]] list_project = [ower for ower in b["owers"]]
list_project.sort() list_project.sort()
list_json = [ower for ower in d["owers"]] list_json = [ower for ower in d["owers"]]
list_json.sort() list_json.sort()
self.assertEqual(list_project, list_json) assert list_project == list_json
def test_import_multiple_currencies_in_empty_project_without_currency(self): def test_import_multiple_currencies_in_empty_project_without_currency(self):
# Import JSON with multiple currencies in an empty project with no # Import JSON with multiple currencies in an empty project with no
@ -138,7 +172,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills() bills = project.get_pretty_bills()
# Check that there are no bills # Check that there are no bills
self.assertEqual(len(bills), 0) assert len(bills) == 0
def test_import_no_currency_in_empty_project_with_currency(self): def test_import_no_currency_in_empty_project_with_currency(self):
# Import JSON without currencies (from ihatemoney < 5) in an empty # Import JSON without currencies (from ihatemoney < 5) in an empty
@ -154,7 +188,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills() bills = project.get_pretty_bills()
# Check if all bills have been added # Check if all bills have been added
self.assertEqual(len(bills), len(self.data)) assert len(bills) == len(self.data)
# Check if name of bills are ok # Check if name of bills are ok
b = [e["what"] for e in bills] b = [e["what"] for e in bills]
@ -162,23 +196,23 @@ class CommonTestCase(object):
ref = [e["what"] for e in self.data] ref = [e["what"] for e in self.data]
ref.sort() ref.sort()
self.assertEqual(b, ref) assert b == ref
# Check if other informations in bill are ok # Check if other informations in bill are ok
for d in self.data: for d in self.data:
for b in bills: for b in bills:
if b["what"] == d["what"]: if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"]) assert b["payer_name"] == d["payer_name"]
self.assertEqual(b["amount"], d["amount"]) assert b["amount"] == d["amount"]
# All bills are converted to default project currency # All bills are converted to default project currency
self.assertEqual(b["currency"], "EUR") assert b["currency"] == "EUR"
self.assertEqual(b["payer_weight"], d["payer_weight"]) assert b["payer_weight"] == d["payer_weight"]
self.assertEqual(b["date"], d["date"]) assert b["date"] == d["date"]
list_project = [ower for ower in b["owers"]] list_project = [ower for ower in b["owers"]]
list_project.sort() list_project.sort()
list_json = [ower for ower in d["owers"]] list_json = [ower for ower in d["owers"]]
list_json.sort() list_json.sort()
self.assertEqual(list_project, list_json) assert list_project == list_json
def test_import_no_currency_in_empty_project_without_currency(self): def test_import_no_currency_in_empty_project_without_currency(self):
# Import JSON without currencies (from ihatemoney < 5) in an empty # Import JSON without currencies (from ihatemoney < 5) in an empty
@ -194,7 +228,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills() bills = project.get_pretty_bills()
# Check if all bills have been added # Check if all bills have been added
self.assertEqual(len(bills), len(self.data)) assert len(bills) == len(self.data)
# Check if name of bills are ok # Check if name of bills are ok
b = [e["what"] for e in bills] b = [e["what"] for e in bills]
@ -202,22 +236,22 @@ class CommonTestCase(object):
ref = [e["what"] for e in self.data] ref = [e["what"] for e in self.data]
ref.sort() ref.sort()
self.assertEqual(b, ref) assert b == ref
# Check if other informations in bill are ok # Check if other informations in bill are ok
for d in self.data: for d in self.data:
for b in bills: for b in bills:
if b["what"] == d["what"]: if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"]) assert b["payer_name"] == d["payer_name"]
self.assertEqual(b["amount"], d["amount"]) assert b["amount"] == d["amount"]
self.assertEqual(b["currency"], "XXX") assert b["currency"] == "XXX"
self.assertEqual(b["payer_weight"], d["payer_weight"]) assert b["payer_weight"] == d["payer_weight"]
self.assertEqual(b["date"], d["date"]) assert b["date"] == d["date"]
list_project = [ower for ower in b["owers"]] list_project = [ower for ower in b["owers"]]
list_project.sort() list_project.sort()
list_json = [ower for ower in d["owers"]] list_json = [ower for ower in d["owers"]]
list_json.sort() list_json.sort()
self.assertEqual(list_project, list_json) assert list_project == list_json
def test_import_partial_project(self): def test_import_partial_project(self):
# Import a JSON in a project with already existing data # Import a JSON in a project with already existing data
@ -250,7 +284,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills() bills = project.get_pretty_bills()
# Check if all bills have been added # Check if all bills have been added
self.assertEqual(len(bills), len(self.data)) assert len(bills) == len(self.data)
# Check if name of bills are ok # Check if name of bills are ok
b = [e["what"] for e in bills] b = [e["what"] for e in bills]
@ -258,22 +292,22 @@ class CommonTestCase(object):
ref = [e["what"] for e in self.data] ref = [e["what"] for e in self.data]
ref.sort() ref.sort()
self.assertEqual(b, ref) assert b == ref
# Check if other informations in bill are ok # Check if other informations in bill are ok
for d in self.data: for d in self.data:
for b in bills: for b in bills:
if b["what"] == d["what"]: if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"]) assert b["payer_name"] == d["payer_name"]
self.assertEqual(b["amount"], d["amount"]) assert b["amount"] == d["amount"]
self.assertEqual(b["currency"], d["currency"]) assert b["currency"] == d["currency"]
self.assertEqual(b["payer_weight"], d["payer_weight"]) assert b["payer_weight"] == d["payer_weight"]
self.assertEqual(b["date"], d["date"]) assert b["date"] == d["date"]
list_project = [ower for ower in b["owers"]] list_project = [ower for ower in b["owers"]]
list_project.sort() list_project.sort()
list_json = [ower for ower in d["owers"]] list_json = [ower for ower in d["owers"]]
list_json.sort() list_json.sort()
self.assertEqual(list_project, list_json) assert list_project == list_json
def test_import_wrong_data(self): def test_import_wrong_data(self):
self.post_project("raclette") self.post_project("raclette")
@ -302,7 +336,7 @@ class CommonTestCase(object):
self.import_project("raclette", self.generate_form_data(data), 400) self.import_project("raclette", self.generate_form_data(data), 400)
class ExportTestCase(IhatemoneyTestCase): class TestExport(IhatemoneyTestCase):
def test_export(self): def test_export(self):
# Export a simple project without currencies # Export a simple project without currencies
@ -379,7 +413,7 @@ class ExportTestCase(IhatemoneyTestCase):
"owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"], "owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"],
}, },
] ]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) assert json.loads(resp.data.decode("utf-8")) == expected
# generate csv export of bills # generate csv export of bills
resp = self.client.get("/raclette/export/bills.csv") resp = self.client.get("/raclette/export/bills.csv")
@ -392,9 +426,7 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n") received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected): for i, line in enumerate(expected):
self.assertEqual( assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
# generate json export of transactions # generate json export of transactions
resp = self.client.get("/raclette/export/transactions.json") resp = self.client.get("/raclette/export/transactions.json")
@ -414,7 +446,7 @@ class ExportTestCase(IhatemoneyTestCase):
}, },
] ]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) assert json.loads(resp.data.decode("utf-8")) == expected
# generate csv export of transactions # generate csv export of transactions
resp = self.client.get("/raclette/export/transactions.csv") resp = self.client.get("/raclette/export/transactions.csv")
@ -428,13 +460,11 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n") received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected): for i, line in enumerate(expected):
self.assertEqual( assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
# wrong export_format should return a 404 # wrong export_format should return a 404
resp = self.client.get("/raclette/export/transactions.wrong") resp = self.client.get("/raclette/export/transactions.wrong")
self.assertEqual(resp.status_code, 404) assert resp.status_code == 404
def test_export_with_currencies(self): def test_export_with_currencies(self):
self.post_project("raclette", default_currency="EUR") self.post_project("raclette", default_currency="EUR")
@ -513,7 +543,7 @@ class ExportTestCase(IhatemoneyTestCase):
"owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"], "owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"],
}, },
] ]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) assert json.loads(resp.data.decode("utf-8")) == expected
# generate csv export of bills # generate csv export of bills
resp = self.client.get("/raclette/export/bills.csv") resp = self.client.get("/raclette/export/bills.csv")
@ -526,9 +556,7 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n") received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected): for i, line in enumerate(expected):
self.assertEqual( assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
# generate json export of transactions (in EUR!) # generate json export of transactions (in EUR!)
resp = self.client.get("/raclette/export/transactions.json") resp = self.client.get("/raclette/export/transactions.json")
@ -543,7 +571,7 @@ class ExportTestCase(IhatemoneyTestCase):
{"amount": 38.45, "currency": "EUR", "receiver": "fred", "ower": "zorglub"}, {"amount": 38.45, "currency": "EUR", "receiver": "fred", "ower": "zorglub"},
] ]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) assert json.loads(resp.data.decode("utf-8")) == expected
# generate csv export of transactions # generate csv export of transactions
resp = self.client.get("/raclette/export/transactions.csv") resp = self.client.get("/raclette/export/transactions.csv")
@ -557,9 +585,7 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n") received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected): for i, line in enumerate(expected):
self.assertEqual( assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
# Change project currency to CAD # Change project currency to CAD
project = self.get_project("raclette") project = self.get_project("raclette")
@ -578,7 +604,7 @@ class ExportTestCase(IhatemoneyTestCase):
{"amount": 57.67, "currency": "CAD", "receiver": "fred", "ower": "zorglub"}, {"amount": 57.67, "currency": "CAD", "receiver": "fred", "ower": "zorglub"},
] ]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) assert json.loads(resp.data.decode("utf-8")) == expected
# generate csv export of transactions # generate csv export of transactions
resp = self.client.get("/raclette/export/transactions.csv") resp = self.client.get("/raclette/export/transactions.csv")
@ -592,9 +618,7 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n") received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected): for i, line in enumerate(expected):
self.assertEqual( assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
def test_export_escape_formulae(self): def test_export_escape_formulae(self):
self.post_project("raclette", default_currency="EUR") self.post_project("raclette", default_currency="EUR")
@ -624,23 +648,17 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n") received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected): for i, line in enumerate(expected):
self.assertEqual( assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
class ImportTestCaseJSON(CommonTestCase.Import): class TestImportJSON(CommonTestCase.Import):
def generate_form_data(self, data): def generate_form_data(self, data):
return {"file": (list_of_dicts2json(data), "test.json")} return {"file": (list_of_dicts2json(data), "test.json")}
class ImportTestCaseCSV(CommonTestCase.Import): class TestImportCSV(CommonTestCase.Import):
def generate_form_data(self, data): def generate_form_data(self, data):
formatted_data = copy.deepcopy(data) formatted_data = copy.deepcopy(data)
for d in formatted_data: for d in formatted_data:
d["owers"] = ", ".join([o for o in d.get("owers", [])]) d["owers"] = ", ".join([o for o in d.get("owers", [])])
return {"file": (list_of_dicts2csv(formatted_data), "test.csv")} return {"file": (list_of_dicts2csv(formatted_data), "test.csv")}
if __name__ == "__main__":
unittest.main()

View file

@ -1,9 +1,9 @@
import os import os
import smtplib import smtplib
import socket import socket
import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import orm from sqlalchemy import orm
from werkzeug.security import check_password_hash from werkzeug.security import check_password_hash
@ -19,19 +19,18 @@ os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None)
__HERE__ = os.path.dirname(os.path.abspath(__file__)) __HERE__ = os.path.dirname(os.path.abspath(__file__))
class ConfigurationTestCase(BaseTestCase): class TestConfiguration(BaseTestCase):
def test_default_configuration(self): def test_default_configuration(self):
"""Test that default settings are loaded when no other configuration file is specified""" """Test that default settings are loaded when no other configuration file is specified"""
self.assertFalse(self.app.config["DEBUG"]) assert not self.app.config["DEBUG"]
self.assertFalse(self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]) assert not self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]
self.assertEqual( assert self.app.config["MAIL_DEFAULT_SENDER"] == (
self.app.config["MAIL_DEFAULT_SENDER"], "Budget manager <admin@example.com>"
("Budget manager <admin@example.com>"),
) )
self.assertTrue(self.app.config["ACTIVATE_DEMO_PROJECT"]) assert self.app.config["ACTIVATE_DEMO_PROJECT"]
self.assertTrue(self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"]) assert self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"]
self.assertFalse(self.app.config["ACTIVATE_ADMIN_DASHBOARD"]) assert not self.app.config["ACTIVATE_ADMIN_DASHBOARD"]
self.assertFalse(self.app.config["ENABLE_CAPTCHA"]) assert not self.app.config["ENABLE_CAPTCHA"]
def test_env_var_configuration_file(self): def test_env_var_configuration_file(self):
"""Test that settings are loaded from a configuration file specified """Test that settings are loaded from a configuration file specified
@ -40,7 +39,7 @@ class ConfigurationTestCase(BaseTestCase):
__HERE__, "ihatemoney_envvar.cfg" __HERE__, "ihatemoney_envvar.cfg"
) )
load_configuration(self.app) load_configuration(self.app)
self.assertEqual(self.app.config["SECRET_KEY"], "lalatra") assert self.app.config["SECRET_KEY"] == "lalatra"
# Test that the specified configuration file is loaded # Test that the specified configuration file is loaded
# even if the default configuration file ihatemoney.cfg exists # even if the default configuration file ihatemoney.cfg exists
@ -50,7 +49,7 @@ class ConfigurationTestCase(BaseTestCase):
) )
self.app.config.root_path = __HERE__ self.app.config.root_path = __HERE__
load_configuration(self.app) load_configuration(self.app)
self.assertEqual(self.app.config["SECRET_KEY"], "lalatra") assert self.app.config["SECRET_KEY"] == "lalatra"
os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None) os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None)
@ -59,10 +58,10 @@ class ConfigurationTestCase(BaseTestCase):
in the current directory.""" in the current directory."""
self.app.config.root_path = __HERE__ self.app.config.root_path = __HERE__
load_configuration(self.app) load_configuration(self.app)
self.assertEqual(self.app.config["SECRET_KEY"], "supersecret") assert self.app.config["SECRET_KEY"] == "supersecret"
class ServerTestCase(IhatemoneyTestCase): class TestServer(IhatemoneyTestCase):
def test_homepage(self): def test_homepage(self):
# See https://github.com/spiral-project/ihatemoney/pull/358 # See https://github.com/spiral-project/ihatemoney/pull/358
self.app.config["APPLICATION_ROOT"] = "/" self.app.config["APPLICATION_ROOT"] = "/"
@ -80,7 +79,7 @@ class ServerTestCase(IhatemoneyTestCase):
self.assertStatus(200, req) self.assertStatus(200, req)
class CommandTestCase(BaseTestCase): class TestCommand(BaseTestCase):
def test_generate_config(self): def test_generate_config(self):
"""Simply checks that all config file generation """Simply checks that all config file generation
- raise no exception - raise no exception
@ -89,25 +88,25 @@ class CommandTestCase(BaseTestCase):
runner = self.app.test_cli_runner() runner = self.app.test_cli_runner()
for config_file in generate_config.params[0].type.choices: for config_file in generate_config.params[0].type.choices:
result = runner.invoke(generate_config, config_file) result = runner.invoke(generate_config, config_file)
self.assertNotEqual(len(result.output.strip()), 0) assert len(result.output.strip()) != 0
def test_generate_password_hash(self): def test_generate_password_hash(self):
runner = self.app.test_cli_runner() runner = self.app.test_cli_runner()
with patch("getpass.getpass", new=lambda prompt: "secret"): with patch("getpass.getpass", new=lambda prompt: "secret"):
result = runner.invoke(password_hash) result = runner.invoke(password_hash)
self.assertTrue(check_password_hash(result.output.strip(), "secret")) assert check_password_hash(result.output.strip(), "secret")
def test_demo_project_deletion(self): def test_demo_project_deletion(self):
self.create_project("demo") self.create_project("demo")
self.assertEqual(self.get_project("demo").name, "demo") assert self.get_project("demo").name == "demo"
runner = self.app.test_cli_runner() runner = self.app.test_cli_runner()
runner.invoke(delete_project, "demo") runner.invoke(delete_project, "demo")
self.assertEqual(len(models.Project.query.all()), 0) assert len(models.Project.query.all()) == 0
class ModelsTestCase(IhatemoneyTestCase): class TestModels(IhatemoneyTestCase):
def test_weighted_bills(self): def test_weighted_bills(self):
"""Test the SQL request that fetch all bills and weights""" """Test the SQL request that fetch all bills and weights"""
self.post_project("raclette") self.post_project("raclette")
@ -156,13 +155,13 @@ class ModelsTestCase(IhatemoneyTestCase):
for weight, bill in project.get_bill_weights().all(): for weight, bill in project.get_bill_weights().all():
if bill.what == "red wine": if bill.what == "red wine":
pay_each_expected = 20 / 2 pay_each_expected = 20 / 2
self.assertEqual(bill.amount / weight, pay_each_expected) assert bill.amount / weight == pay_each_expected
if bill.what == "fromage à raclette": if bill.what == "fromage à raclette":
pay_each_expected = 10 / 4 pay_each_expected = 10 / 4
self.assertEqual(bill.amount / weight, pay_each_expected) assert bill.amount / weight == pay_each_expected
if bill.what == "delicatessen": if bill.what == "delicatessen":
pay_each_expected = 10 / 3 pay_each_expected = 10 / 3
self.assertEqual(bill.amount / weight, pay_each_expected) assert bill.amount / weight == pay_each_expected
def test_bill_pay_each(self): def test_bill_pay_each(self):
self.post_project("raclette") self.post_project("raclette")
@ -216,16 +215,16 @@ class ModelsTestCase(IhatemoneyTestCase):
for bill in zorglub_bills.all(): for bill in zorglub_bills.all():
if bill.what == "red wine": if bill.what == "red wine":
pay_each_expected = 20 / 2 pay_each_expected = 20 / 2
self.assertEqual(bill.pay_each(), pay_each_expected) assert bill.pay_each() == pay_each_expected
if bill.what == "fromage à raclette": if bill.what == "fromage à raclette":
pay_each_expected = 10 / 4 pay_each_expected = 10 / 4
self.assertEqual(bill.pay_each(), pay_each_expected) assert bill.pay_each() == pay_each_expected
if bill.what == "delicatessen": if bill.what == "delicatessen":
pay_each_expected = 10 / 3 pay_each_expected = 10 / 3
self.assertEqual(bill.pay_each(), pay_each_expected) assert bill.pay_each() == pay_each_expected
class EmailFailureTestCase(IhatemoneyTestCase): class TestEmailFailure(IhatemoneyTestCase):
def test_creation_email_failure_smtp(self): def test_creation_email_failure_smtp(self):
self.login("raclette") self.login("raclette")
with patch.object( with patch.object(
@ -233,14 +232,14 @@ class EmailFailureTestCase(IhatemoneyTestCase):
): ):
resp = self.post_project("raclette") resp = self.post_project("raclette")
# Check that an error message is displayed # Check that an error message is displayed
self.assertIn( assert (
"We tried to send you an reminder email, but there was an error", "We tried to send you an reminder email, but there was an error"
resp.data.decode("utf-8"), in resp.data.decode("utf-8")
) )
# Check that we were redirected to the home page anyway # Check that we were redirected to the home page anyway
self.assertIn( assert (
'<a href="/raclette/members/add">Add the first participant', '<a href="/raclette/members/add">Add the first participant'
resp.data.decode("utf-8"), in resp.data.decode("utf-8")
) )
def test_creation_email_failure_socket(self): def test_creation_email_failure_socket(self):
@ -248,14 +247,14 @@ class EmailFailureTestCase(IhatemoneyTestCase):
with patch.object(self.app.mail, "send", MagicMock(side_effect=socket.error)): with patch.object(self.app.mail, "send", MagicMock(side_effect=socket.error)):
resp = self.post_project("raclette") resp = self.post_project("raclette")
# Check that an error message is displayed # Check that an error message is displayed
self.assertIn( assert (
"We tried to send you an reminder email, but there was an error", "We tried to send you an reminder email, but there was an error"
resp.data.decode("utf-8"), in resp.data.decode("utf-8")
) )
# Check that we were redirected to the home page anyway # Check that we were redirected to the home page anyway
self.assertIn( assert (
'<a href="/raclette/members/add">Add the first participant', '<a href="/raclette/members/add">Add the first participant'
resp.data.decode("utf-8"), in resp.data.decode("utf-8")
) )
def test_password_reset_email_failure(self): def test_password_reset_email_failure(self):
@ -266,14 +265,13 @@ class EmailFailureTestCase(IhatemoneyTestCase):
"/password-reminder", data={"id": "raclette"}, follow_redirects=True "/password-reminder", data={"id": "raclette"}, follow_redirects=True
) )
# Check that an error message is displayed # Check that an error message is displayed
self.assertIn( assert "there was an error while sending you an email" in resp.data.decode(
"there was an error while sending you an email", "utf-8"
resp.data.decode("utf-8"),
) )
# Check that we were not redirected to the success page # Check that we were not redirected to the success page
self.assertNotIn( assert (
"A link to reset your password has been sent to you", "A link to reset your password has been sent to you"
resp.data.decode("utf-8"), not in resp.data.decode("utf-8")
) )
def test_invitation_email_failure(self): def test_invitation_email_failure(self):
@ -287,17 +285,15 @@ class EmailFailureTestCase(IhatemoneyTestCase):
follow_redirects=True, follow_redirects=True,
) )
# Check that an error message is displayed # Check that an error message is displayed
self.assertIn( assert (
"there was an error while trying to send the invitation emails", "there was an error while trying to send the invitation emails"
resp.data.decode("utf-8"), in resp.data.decode("utf-8")
) )
# Check that we are still on the same page (no redirection) # Check that we are still on the same page (no redirection)
self.assertIn( assert "Invite people to join this project" in resp.data.decode("utf-8")
"Invite people to join this project", resp.data.decode("utf-8")
)
class CaptchaTestCase(IhatemoneyTestCase): class TestCaptcha(IhatemoneyTestCase):
ENABLE_CAPTCHA = True ENABLE_CAPTCHA = True
def test_project_creation_with_captcha_case_insensitive(self): def test_project_creation_with_captcha_case_insensitive(self):
@ -315,7 +311,7 @@ class CaptchaTestCase(IhatemoneyTestCase):
"captcha": "éùüß", "captcha": "éùüß",
}, },
) )
self.assertEqual(len(models.Project.query.all()), 1) assert len(models.Project.query.all()) == 1
def test_project_creation_with_captcha(self): def test_project_creation_with_captcha(self):
with self.client as c: with self.client as c:
@ -329,7 +325,7 @@ class CaptchaTestCase(IhatemoneyTestCase):
"default_currency": "USD", "default_currency": "USD",
}, },
) )
self.assertEqual(len(models.Project.query.all()), 0) assert len(models.Project.query.all()) == 0
c.post( c.post(
"/create", "/create",
@ -342,7 +338,7 @@ class CaptchaTestCase(IhatemoneyTestCase):
"captcha": "nope", "captcha": "nope",
}, },
) )
self.assertEqual(len(models.Project.query.all()), 0) assert len(models.Project.query.all()) == 0
c.post( c.post(
"/create", "/create",
@ -355,7 +351,7 @@ class CaptchaTestCase(IhatemoneyTestCase):
"captcha": "euro", "captcha": "euro",
}, },
) )
self.assertEqual(len(models.Project.query.all()), 1) assert len(models.Project.query.all()) == 1
def test_api_project_creation_does_not_need_captcha(self): def test_api_project_creation_does_not_need_captcha(self):
self.client.get("/") self.client.get("/")
@ -368,11 +364,11 @@ class CaptchaTestCase(IhatemoneyTestCase):
"contact_email": "raclette@notmyidea.org", "contact_email": "raclette@notmyidea.org",
}, },
) )
self.assertTrue(resp.status, 201) assert resp.status_code == 201
self.assertEqual(len(models.Project.query.all()), 1) assert len(models.Project.query.all()) == 1
class TestCurrencyConverter(unittest.TestCase): class TestCurrencyConverter:
converter = CurrencyConverter() converter = CurrencyConverter()
mock_data = { mock_data = {
"USD": 1, "USD": 1,
@ -386,28 +382,23 @@ class TestCurrencyConverter(unittest.TestCase):
def test_only_one_instance(self): def test_only_one_instance(self):
one = id(CurrencyConverter()) one = id(CurrencyConverter())
two = id(CurrencyConverter()) two = id(CurrencyConverter())
self.assertEqual(one, two) assert one == two
def test_get_currencies(self): def test_get_currencies(self):
self.assertCountEqual( assert set(self.converter.get_currencies()) == set(
self.converter.get_currencies(), ["USD", "EUR", "CAD", "PLN", CurrencyConverter.no_currency]
["USD", "EUR", "CAD", "PLN", CurrencyConverter.no_currency],
) )
def test_exchange_currency(self): def test_exchange_currency(self):
result = self.converter.exchange_currency(100, "USD", "EUR") result = self.converter.exchange_currency(100, "USD", "EUR")
self.assertEqual(result, 80.0) assert result == 80.0
def test_failing_remote(self): def test_failing_remote(self):
rates = {} rates = {}
with patch("requests.Response.json", new=lambda _: {}), self.assertWarns( with patch("requests.Response.json", new=lambda _: {}), pytest.warns(
UserWarning UserWarning
): ):
# we need a non-patched converter, but it seems that MagickMock # we need a non-patched converter, but it seems that MagickMock
# is mocking EVERY instance of the class method. Too bad. # is mocking EVERY instance of the class method. Too bad.
rates = CurrencyConverter.get_rates(self.converter) rates = CurrencyConverter.get_rates(self.converter)
self.assertDictEqual(rates, {CurrencyConverter.no_currency: 1}) assert rates == {CurrencyConverter.no_currency: 1}
if __name__ == "__main__":
unittest.main()

View file

@ -59,8 +59,8 @@ dev =
flake8==5.0.4 flake8==5.0.4
isort==5.11.5 isort==5.11.5
vermin==1.5.2 vermin==1.5.2
Flask-Testing>=0.8.1
pytest>=6.2.5 pytest>=6.2.5
pytest-flask>=1.2.0
pytest-libfaketime>=0.1.2 pytest-libfaketime>=0.1.2
tox>=3.14.6 tox>=3.14.6
zest.releaser>=6.20.1 zest.releaser>=6.20.1