tests: migrate to pytest

- replace setUp/tearDown with pytest fixtures
- rename test classes to use the pytest convention
- use pytest assertions

Co-authored-by: Glandos <bugs-github@antipoul.fr>
This commit is contained in:
Éloi Rivard 2023-08-12 13:09:28 +02:00 committed by zorun
parent 2ce76158d2
commit 21408f8bc9
9 changed files with 741 additions and 766 deletions

3
.gitignore vendored
View file

@ -16,4 +16,5 @@ ihatemoney/budget.db
.DS_Store
.idea
.python-version
.coverage*
prof

View file

@ -1,13 +1,12 @@
import base64
import datetime
import json
import unittest
from ihatemoney.tests.common.help_functions import em_surround
from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase
class APITestCase(IhatemoneyTestCase):
class TestAPI(IhatemoneyTestCase):
"""Tests the API"""
@ -57,7 +56,7 @@ class APITestCase(IhatemoneyTestCase):
resp = self.client.options(
"/api/projects/raclette", headers=self.get_auth("raclette")
)
self.assertEqual(resp.headers["Access-Control-Allow-Origin"], "*")
assert resp.headers["Access-Control-Allow-Origin"] == "*"
def test_basic_auth(self):
# create a project
@ -94,32 +93,32 @@ class APITestCase(IhatemoneyTestCase):
},
)
self.assertEqual(400, resp.status_code)
self.assertEqual(
'{"contact_email": ["Invalid email address."]}\n', resp.data.decode("utf-8")
assert 400 == resp.status_code
assert '{"contact_email": ["Invalid email address."]}\n' == resp.data.decode(
"utf-8"
)
# create it
with self.app.mail.record_messages() as outbox:
resp = self.api_create("raclette")
self.assertEqual(201, resp.status_code)
assert 201 == resp.status_code
# Check that email messages have been sent.
self.assertEqual(len(outbox), 1)
self.assertEqual(outbox[0].recipients, ["raclette@notmyidea.org"])
assert len(outbox) == 1
assert outbox[0].recipients == ["raclette@notmyidea.org"]
# create it twice should return a 400
resp = self.api_create("raclette")
self.assertEqual(400, resp.status_code)
self.assertIn("id", json.loads(resp.data.decode("utf-8")))
assert 400 == resp.status_code
assert "id" in json.loads(resp.data.decode("utf-8"))
# get information about it
resp = self.client.get(
"/api/projects/raclette", headers=self.get_auth("raclette")
)
self.assertEqual(200, resp.status_code)
assert 200 == resp.status_code
expected = {
"members": [],
"name": "raclette",
@ -129,7 +128,7 @@ class APITestCase(IhatemoneyTestCase):
"logging_preference": 1,
}
decoded_resp = json.loads(resp.data.decode("utf-8"))
self.assertDictEqual(decoded_resp, expected)
assert decoded_resp == expected
# edit should fail if we don't provide the current private code
resp = self.client.put(
@ -143,7 +142,7 @@ class APITestCase(IhatemoneyTestCase):
},
headers=self.get_auth("raclette"),
)
self.assertEqual(400, resp.status_code)
assert 400 == resp.status_code
# edit should fail if we provide the wrong private code
resp = self.client.put(
@ -158,7 +157,7 @@ class APITestCase(IhatemoneyTestCase):
},
headers=self.get_auth("raclette"),
)
self.assertEqual(400, resp.status_code)
assert 400 == resp.status_code
# edit with the correct private code should work
resp = self.client.put(
@ -173,13 +172,13 @@ class APITestCase(IhatemoneyTestCase):
},
headers=self.get_auth("raclette"),
)
self.assertEqual(200, resp.status_code)
assert 200 == resp.status_code
resp = self.client.get(
"/api/projects/raclette", headers=self.get_auth("raclette")
)
self.assertEqual(200, resp.status_code)
assert 200 == resp.status_code
expected = {
"name": "The raclette party",
"contact_email": "yeah@notmyidea.org",
@ -189,7 +188,7 @@ class APITestCase(IhatemoneyTestCase):
"logging_preference": 1,
}
decoded_resp = json.loads(resp.data.decode("utf-8"))
self.assertDictEqual(decoded_resp, expected)
assert decoded_resp == expected
# password change is possible via API
resp = self.client.put(
@ -204,12 +203,12 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"),
)
self.assertEqual(200, resp.status_code)
assert 200 == resp.status_code
resp = self.client.get(
"/api/projects/raclette", headers=self.get_auth("raclette", "tartiflette")
)
self.assertEqual(200, resp.status_code)
assert 200 == resp.status_code
# delete should work
resp = self.client.delete(
@ -220,21 +219,21 @@ class APITestCase(IhatemoneyTestCase):
resp = self.client.get(
"/api/projects/raclette", headers=self.get_auth("raclette")
)
self.assertEqual(401, resp.status_code)
assert 401 == resp.status_code
def test_token_creation(self):
"""Test that token of project is generated"""
# Create project
resp = self.api_create("raclette")
self.assertEqual(201, resp.status_code)
assert 201 == resp.status_code
# Get token
resp = self.client.get(
"/api/projects/raclette/token", headers=self.get_auth("raclette")
)
self.assertEqual(200, resp.status_code)
assert 200 == resp.status_code
decoded_resp = json.loads(resp.data.decode("utf-8"))
@ -243,7 +242,7 @@ class APITestCase(IhatemoneyTestCase):
"/api/projects/raclette/token",
headers={"Authorization": f"Basic {decoded_resp['token']}"},
)
self.assertEqual(200, resp.status_code)
assert 200 == resp.status_code
# We shouldn't be able to edit project without private code
resp = self.client.put(
@ -256,9 +255,9 @@ class APITestCase(IhatemoneyTestCase):
},
headers={"Authorization": f"Basic {decoded_resp['token']}"},
)
self.assertEqual(400, resp.status_code)
assert 400 == resp.status_code
expected_resp = {"current_password": ["This field is required."]}
self.assertEqual(expected_resp, json.loads(resp.data.decode("utf-8")))
assert expected_resp == json.loads(resp.data.decode("utf-8"))
def test_token_login(self):
resp = self.api_create("raclette")
@ -269,7 +268,7 @@ class APITestCase(IhatemoneyTestCase):
decoded_resp = json.loads(resp.data.decode("utf-8"))
resp = self.client.get(f"/raclette/join/{decoded_resp['token']}")
# Test that we are redirected.
self.assertEqual(302, resp.status_code)
assert 302 == resp.status_code
def test_member(self):
# create a project
@ -281,7 +280,7 @@ class APITestCase(IhatemoneyTestCase):
)
self.assertStatus(200, req)
self.assertEqual("[]\n", req.data.decode("utf-8"))
assert "[]\n" == req.data.decode("utf-8")
# add a member
req = self.client.post(
@ -292,7 +291,7 @@ class APITestCase(IhatemoneyTestCase):
# the id of the new member should be returned
self.assertStatus(201, req)
self.assertEqual("1\n", req.data.decode("utf-8"))
assert "1\n" == req.data.decode("utf-8")
# the list of participants should contain one member
req = self.client.get(
@ -300,7 +299,7 @@ class APITestCase(IhatemoneyTestCase):
)
self.assertStatus(200, req)
self.assertEqual(len(json.loads(req.data.decode("utf-8"))), 1)
assert len(json.loads(req.data.decode("utf-8"))) == 1
# Try to add another member with the same name.
req = self.client.post(
@ -325,8 +324,8 @@ class APITestCase(IhatemoneyTestCase):
)
self.assertStatus(200, req)
self.assertEqual("Fred", json.loads(req.data.decode("utf-8"))["name"])
self.assertEqual(2, json.loads(req.data.decode("utf-8"))["weight"])
assert "Fred" == json.loads(req.data.decode("utf-8"))["name"]
assert 2 == json.loads(req.data.decode("utf-8"))["weight"]
# edit this member with same information
# (test PUT idemopotence)
@ -350,7 +349,7 @@ class APITestCase(IhatemoneyTestCase):
"/api/projects/raclette/members/1", headers=self.get_auth("raclette")
)
self.assertStatus(200, req)
self.assertEqual(False, json.loads(req.data.decode("utf-8"))["activated"])
assert not json.loads(req.data.decode("utf-8"))["activated"]
# re-activate the participant
req = self.client.put(
@ -363,7 +362,7 @@ class APITestCase(IhatemoneyTestCase):
"/api/projects/raclette/members/1", headers=self.get_auth("raclette")
)
self.assertStatus(200, req)
self.assertEqual(True, json.loads(req.data.decode("utf-8"))["activated"])
assert json.loads(req.data.decode("utf-8"))["activated"]
# delete a member
@ -379,7 +378,7 @@ class APITestCase(IhatemoneyTestCase):
)
self.assertStatus(200, req)
self.assertEqual("[]\n", req.data.decode("utf-8"))
assert "[]\n" == req.data.decode("utf-8")
def test_bills(self):
# create a project
@ -396,7 +395,7 @@ class APITestCase(IhatemoneyTestCase):
)
self.assertStatus(200, req)
self.assertEqual("[]\n", req.data.decode("utf-8"))
assert "[]\n" == req.data.decode("utf-8")
# add a bill
req = self.client.post(
@ -414,7 +413,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id
self.assertStatus(201, req)
self.assertEqual(req.data.decode("utf-8"), "1\n")
assert req.data.decode("utf-8") == "1\n"
# get this bill details
req = self.client.get(
@ -439,19 +438,19 @@ class APITestCase(IhatemoneyTestCase):
}
got = json.loads(req.data.decode("utf-8"))
self.assertEqual(
datetime.date.today(),
datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(),
assert (
datetime.date.today()
== datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date()
)
del got["creation_date"]
self.assertDictEqual(expected, got)
assert expected == got
# the list of bills should length 1
req = self.client.get(
"/api/projects/raclette/bills", headers=self.get_auth("raclette")
)
self.assertStatus(200, req)
self.assertEqual(1, len(json.loads(req.data.decode("utf-8"))))
assert 1 == len(json.loads(req.data.decode("utf-8")))
# edit with errors should return an error
req = self.client.put(
@ -468,9 +467,7 @@ class APITestCase(IhatemoneyTestCase):
)
self.assertStatus(400, req)
self.assertEqual(
'{"date": ["This field is required."]}\n', req.data.decode("utf-8")
)
assert '{"date": ["This field is required."]}\n' == req.data.decode("utf-8")
# edit a bill
req = self.client.put(
@ -510,12 +507,12 @@ class APITestCase(IhatemoneyTestCase):
}
got = json.loads(req.data.decode("utf-8"))
self.assertEqual(
creation_date,
datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(),
assert (
creation_date
== datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date()
)
del got["creation_date"]
self.assertDictEqual(expected, got)
assert expected == got
# delete a bill
req = self.client.delete(
@ -562,7 +559,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id
self.assertStatus(201, req)
self.assertEqual(req.data.decode("utf-8"), "{}\n".format(id))
assert req.data.decode("utf-8") == "{}\n".format(id)
# get this bill's details
req = self.client.get(
@ -588,12 +585,12 @@ class APITestCase(IhatemoneyTestCase):
}
got = json.loads(req.data.decode("utf-8"))
self.assertEqual(
datetime.date.today(),
datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(),
assert (
datetime.date.today()
== datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date()
)
del got["creation_date"]
self.assertDictEqual(expected, got)
assert expected == got
# should raise errors
erroneous_amounts = [
@ -621,19 +618,19 @@ class APITestCase(IhatemoneyTestCase):
def test_currencies(self):
# check /currencies for list of supported currencies
resp = self.client.get("/api/currencies")
self.assertEqual(200, resp.status_code)
self.assertIn("XXX", json.loads(resp.data.decode("utf-8")))
assert 200 == resp.status_code
assert "XXX" in json.loads(resp.data.decode("utf-8"))
# create project with a default currency
resp = self.api_create("raclette", default_currency="EUR")
self.assertEqual(201, resp.status_code)
assert 201 == resp.status_code
# get information about it
resp = self.client.get(
"/api/projects/raclette", headers=self.get_auth("raclette")
)
self.assertEqual(200, resp.status_code)
assert 200 == resp.status_code
expected = {
"members": [],
"name": "raclette",
@ -643,7 +640,7 @@ class APITestCase(IhatemoneyTestCase):
"logging_preference": 1,
}
decoded_resp = json.loads(resp.data.decode("utf-8"))
self.assertDictEqual(decoded_resp, expected)
assert decoded_resp == expected
# Add participants
self.api_add_member("raclette", "zorglub")
@ -666,7 +663,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id
self.assertStatus(201, req)
self.assertEqual(req.data.decode("utf-8"), "1\n")
assert req.data.decode("utf-8") == "1\n"
# get this bill details
req = self.client.get(
@ -691,12 +688,12 @@ class APITestCase(IhatemoneyTestCase):
}
got = json.loads(req.data.decode("utf-8"))
self.assertEqual(
datetime.date.today(),
datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(),
assert (
datetime.date.today()
== datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date()
)
del got["creation_date"]
self.assertDictEqual(expected, got)
assert expected == got
# Change bill amount and currency
req = self.client.put(
@ -737,7 +734,7 @@ class APITestCase(IhatemoneyTestCase):
got = json.loads(req.data.decode("utf-8"))
del got["creation_date"]
self.assertDictEqual(expected, got)
assert expected == got
# Add a bill with yet another currency
req = self.client.post(
@ -755,7 +752,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id
self.assertStatus(201, req)
self.assertEqual(req.data.decode("utf-8"), "2\n")
assert req.data.decode("utf-8") == "2\n"
# Try to remove default project currency, it should fail
req = self.client.put(
@ -770,9 +767,9 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"),
)
self.assertStatus(400, req)
self.assertIn("This project cannot be set", req.data.decode("utf-8"))
self.assertIn(
"because it contains bills in multiple currencies", req.data.decode("utf-8")
assert "This project cannot be set" in req.data.decode("utf-8")
assert "because it contains bills in multiple currencies" in req.data.decode(
"utf-8"
)
def test_statistics(self):
@ -801,33 +798,30 @@ class APITestCase(IhatemoneyTestCase):
"/api/projects/raclette/statistics", headers=self.get_auth("raclette")
)
self.assertStatus(200, req)
self.assertEqual(
[
{
"balance": 12.5,
"member": {
"activated": True,
"id": 1,
"name": "zorglub",
"weight": 1.0,
},
"paid": 25.0,
"spent": 12.5,
assert [
{
"balance": 12.5,
"member": {
"activated": True,
"id": 1,
"name": "zorglub",
"weight": 1.0,
},
{
"balance": -12.5,
"member": {
"activated": True,
"id": 2,
"name": "fred",
"weight": 1.0,
},
"paid": 0,
"spent": 12.5,
"paid": 25.0,
"spent": 12.5,
},
{
"balance": -12.5,
"member": {
"activated": True,
"id": 2,
"name": "fred",
"weight": 1.0,
},
],
json.loads(req.data.decode("utf-8")),
)
"paid": 0,
"spent": 12.5,
},
] == json.loads(req.data.decode("utf-8"))
def test_username_xss(self):
# create a project
@ -839,7 +833,7 @@ class APITestCase(IhatemoneyTestCase):
self.api_add_member("raclette", "<script>")
result = self.client.get("/raclette/")
self.assertNotIn("<script>", result.data.decode("utf-8"))
assert "<script>" not in result.data.decode("utf-8")
def test_weighted_bills(self):
# create a project
@ -888,12 +882,12 @@ class APITestCase(IhatemoneyTestCase):
"original_currency": "XXX",
}
got = json.loads(req.data.decode("utf-8"))
self.assertEqual(
creation_date,
datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(),
assert (
creation_date
== datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date()
)
del got["creation_date"]
self.assertDictEqual(expected, got)
assert expected == got
# getting it should return a 404
req = self.client.get(
@ -933,7 +927,7 @@ class APITestCase(IhatemoneyTestCase):
self.assertStatus(200, req)
decoded_req = json.loads(req.data.decode("utf-8"))
self.assertDictEqual(decoded_req, expected)
assert decoded_req == expected
def test_log_created_from_api_call(self):
# create a project
@ -944,15 +938,13 @@ class APITestCase(IhatemoneyTestCase):
self.api_add_member("raclette", "zorglub")
resp = self.client.get("/raclette/history", follow_redirects=True)
self.assertEqual(resp.status_code, 200)
self.assertIn(
f"Participant {em_surround('zorglub')} added", resp.data.decode("utf-8")
assert resp.status_code == 200
assert f"Participant {em_surround('zorglub')} added" in resp.data.decode(
"utf-8"
)
self.assertIn(
f"Project {em_surround('raclette')} added", resp.data.decode("utf-8")
)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 2)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
assert f"Project {em_surround('raclette')} added" in resp.data.decode("utf-8")
assert resp.data.decode("utf-8").count("<td> -- </td>") == 2
assert "127.0.0.1" not in resp.data.decode("utf-8")
def test_amount_is_null(self):
self.api_create("raclette")
@ -1000,7 +992,3 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"),
)
self.assertStatus(400, req)
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load diff

View file

@ -1,15 +1,13 @@
import os
from unittest.mock import MagicMock
from flask_testing import TestCase
import pytest
from ihatemoney import models
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.run import create_app, db
from ihatemoney.utils import generate_password_hash
class BaseTestCase(TestCase):
@pytest.mark.usefixtures("client", "converter")
class BaseTestCase:
SECRET_KEY = "TEST SESSION"
SQLALCHEMY_DATABASE_URI = os.environ.get(
"TESTING_SQLALCHEMY_DATABASE_URI", "sqlite://"
@ -18,30 +16,6 @@ class BaseTestCase(TestCase):
PASSWORD_HASH_METHOD = "pbkdf2:sha1:1"
PASSWORD_HASH_SALT_LENGTH = 1
def create_app(self):
# Pass the test object as a configuration.
return create_app(self)
def setUp(self):
db.create_all()
# Add dummy data to CurrencyConverter for all tests (since it's a singleton)
mock_data = {
"USD": 1,
"EUR": 0.8,
"CAD": 1.2,
"PLN": 4,
CurrencyConverter.no_currency: 1,
}
converter = CurrencyConverter()
converter.get_rates = MagicMock(return_value=mock_data)
# Also add it to an attribute to make tests clearer
self.converter = converter
def tearDown(self):
# clean after testing
db.session.remove()
db.drop_all()
def login(self, project, password=None, test_client=None):
password = password or project
@ -83,7 +57,7 @@ class BaseTestCase(TestCase):
data=data,
# follow_redirects=True,
)
self.assertEqual("/{id}/edit" in str(resp.response), not success)
assert ("/{id}/edit" in str(resp.response)) == (not success)
def create_project(self, id, default_currency="XXX", name=None, password=None):
name = name or str(id)
@ -109,11 +83,9 @@ class IhatemoneyTestCase(BaseTestCase):
def assertStatus(self, expected, resp, url=None):
if url is None:
url = resp.request.path
return self.assertEqual(
expected,
resp.status_code,
f"{url} expected {expected}, got {resp.status_code}",
)
assert (
expected == resp.status_code
), f"{url} expected {expected}, got {resp.status_code}"
def enable_admin(self, password="adminpass"):
self.app.config["ACTIVATE_ADMIN_DASHBOARD"] = True

View file

@ -0,0 +1,49 @@
from unittest.mock import MagicMock
from flask import Flask
import pytest
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.run import create_app, db
@pytest.fixture
def app(request: pytest.FixtureRequest):
"""Create the Flask app with database"""
app = create_app(request.cls)
with app.app_context():
db.create_all()
request.cls.app = app
yield app
# clean after testing
db.session.remove()
db.drop_all()
@pytest.fixture
def client(app: Flask, request: pytest.FixtureRequest):
client = app.test_client()
request.cls.client = client
yield client
@pytest.fixture
def converter(request: pytest.FixtureRequest):
# Add dummy data to CurrencyConverter for all tests (since it's a singleton)
mock_data = {
"USD": 1,
"EUR": 0.8,
"CAD": 1.2,
"PLN": 4,
CurrencyConverter.no_currency: 1,
}
converter = CurrencyConverter()
converter.get_rates = MagicMock(return_value=mock_data)
# Also add it to an attribute to make tests clearer
request.cls.converter = converter
yield converter

View file

@ -1,4 +1,6 @@
import unittest
import re
import pytest
from ihatemoney import history, models
from ihatemoney.tests.common.help_functions import em_surround
@ -6,18 +8,33 @@ from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase
from ihatemoney.versioning import LoggingMode
class HistoryTestCase(IhatemoneyTestCase):
def setUp(self):
super().setUp()
self.post_project("demo")
self.login("demo")
@pytest.fixture
def demo(client):
client.post(
"/create",
data={
"name": "demo",
"id": "demo",
"password": "demo",
"contact_email": "demo@notmyidea.org",
"default_currency": "XXX",
"project_history": True,
},
)
client.post(
"/authenticate",
data=dict(id="demo", password="demo"),
)
@pytest.mark.usefixtures("demo")
class TestHistory(IhatemoneyTestCase):
def test_simple_create_logentry_no_ip(self):
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn(f"Project {em_surround('demo')} added", resp.data.decode("utf-8"))
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
assert resp.status_code == 200
assert f"Project {em_surround('demo')} added" in resp.data.decode("utf-8")
assert resp.data.decode("utf-8").count("<td> -- </td>") == 1
assert "127.0.0.1" not in resp.data.decode("utf-8")
def change_privacy_to(self, current_password, logging_preference):
# Change only logging_preferences
@ -36,42 +53,38 @@ class HistoryTestCase(IhatemoneyTestCase):
# Disable History
resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True)
self.assertEqual(resp.status_code, 200)
self.assertNotIn("alert-danger", resp.data.decode("utf-8"))
assert resp.status_code == 200
assert "alert-danger" not in resp.data.decode("utf-8")
resp = self.client.get("/demo/edit")
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
if logging_preference == LoggingMode.DISABLED:
self.assertIn('<input id="project_history"', resp.data.decode("utf-8"))
assert '<input id="project_history"' in resp.data.decode("utf-8")
else:
self.assertIn(
'<input checked id="project_history"', resp.data.decode("utf-8")
)
assert '<input checked id="project_history"' in resp.data.decode("utf-8")
if logging_preference == LoggingMode.RECORD_IP:
self.assertIn('<input checked id="ip_recording"', resp.data.decode("utf-8"))
assert '<input checked id="ip_recording"' in resp.data.decode("utf-8")
else:
self.assertIn('<input id="ip_recording"', resp.data.decode("utf-8"))
assert '<input id="ip_recording"' in resp.data.decode("utf-8")
def assert_empty_history_logging_disabled(self):
resp = self.client.get("/demo/history")
self.assertIn(
"This project has history disabled. New actions won't appear below.",
resp.data.decode("utf-8"),
assert (
"This project has history disabled. New actions won't appear below."
in resp.data.decode("utf-8")
)
self.assertIn("Nothing to list", resp.data.decode("utf-8"))
self.assertNotIn(
"The table below reflects actions recorded prior to disabling project history.",
resp.data.decode("utf-8"),
assert "Nothing to list" in resp.data.decode("utf-8")
assert (
"The table below reflects actions recorded prior to disabling project history."
not in resp.data.decode("utf-8")
)
self.assertNotIn(
"Some entries below contain IP addresses,", resp.data.decode("utf-8")
)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
self.assertNotIn("<td> -- </td>", resp.data.decode("utf-8"))
self.assertNotIn(
f"Project {em_surround('demo')} added", resp.data.decode("utf-8")
assert "Some entries below contain IP addresses," not in resp.data.decode(
"utf-8"
)
assert "127.0.0.1" not in resp.data.decode("utf-8")
assert "<td> -- </td>" not in resp.data.decode("utf-8")
assert f"Project {em_surround('demo')} added" not in resp.data.decode("utf-8")
def test_project_edit(self):
new_data = {
@ -84,90 +97,86 @@ class HistoryTestCase(IhatemoneyTestCase):
}
resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn(f"Project {em_surround('demo')} added", resp.data.decode("utf-8"))
self.assertIn(
f"Project contact email changed to {em_surround('demo2@notmyidea.org')}",
resp.data.decode("utf-8"),
assert resp.status_code == 200
assert f"Project {em_surround('demo')} added" in resp.data.decode("utf-8")
assert (
f"Project contact email changed to {em_surround('demo2@notmyidea.org')}"
in resp.data.decode("utf-8")
)
self.assertIn("Project private code changed", resp.data.decode("utf-8"))
self.assertIn(
f"Project renamed to {em_surround('demo2')}", resp.data.decode("utf-8")
)
self.assertLess(
resp.data.decode("utf-8").index("Project renamed "),
resp.data.decode("utf-8").index("Project contact email changed to "),
)
self.assertLess(
resp.data.decode("utf-8").index("Project renamed "),
resp.data.decode("utf-8").index("Project private code changed"),
)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 5)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
assert "Project private code changed" in resp.data.decode("utf-8")
assert f"Project renamed to {em_surround('demo2')}" in resp.data.decode("utf-8")
assert resp.data.decode("utf-8").index("Project renamed ") < resp.data.decode(
"utf-8"
).index("Project contact email changed to ")
assert resp.data.decode("utf-8").index("Project renamed ") < resp.data.decode(
"utf-8"
).index("Project private code changed")
assert resp.data.decode("utf-8").count("<td> -- </td>") == 5
assert "127.0.0.1" not in resp.data.decode("utf-8")
def test_project_privacy_edit(self):
resp = self.client.get("/demo/edit")
self.assertEqual(resp.status_code, 200)
self.assertIn(
'<input checked id="project_history" name="project_history" type="checkbox" value="y">',
resp.data.decode("utf-8"),
assert resp.status_code == 200
assert (
'<input checked id="project_history" name="project_history" type="checkbox" value="y">'
in resp.data.decode("utf-8")
)
self.change_privacy_to("demo", LoggingMode.DISABLED)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn("Disabled Project History\n", resp.data.decode("utf-8"))
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 2)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
assert resp.status_code == 200
assert "Disabled Project History\n" in resp.data.decode("utf-8")
assert resp.data.decode("utf-8").count("<td> -- </td>") == 2
assert "127.0.0.1" not in resp.data.decode("utf-8")
self.change_privacy_to("demo", LoggingMode.RECORD_IP)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn(
"Enabled Project History & IP Address Recording", resp.data.decode("utf-8")
assert resp.status_code == 200
assert "Enabled Project History & IP Address Recording" in resp.data.decode(
"utf-8"
)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 2)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 1)
assert resp.data.decode("utf-8").count("<td> -- </td>") == 2
assert resp.data.decode("utf-8").count("127.0.0.1") == 1
self.change_privacy_to("demo", LoggingMode.ENABLED)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn("Disabled IP Address Recording\n", resp.data.decode("utf-8"))
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 2)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 2)
assert resp.status_code == 200
assert "Disabled IP Address Recording\n" in resp.data.decode("utf-8")
assert resp.data.decode("utf-8").count("<td> -- </td>") == 2
assert resp.data.decode("utf-8").count("127.0.0.1") == 2
def test_project_privacy_edit2(self):
self.change_privacy_to("demo", LoggingMode.RECORD_IP)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn("Enabled IP Address Recording\n", resp.data.decode("utf-8"))
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 1)
assert resp.status_code == 200
assert "Enabled IP Address Recording\n" in resp.data.decode("utf-8")
assert resp.data.decode("utf-8").count("<td> -- </td>") == 1
assert resp.data.decode("utf-8").count("127.0.0.1") == 1
self.change_privacy_to("demo", LoggingMode.DISABLED)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn(
"Disabled Project History & IP Address Recording", resp.data.decode("utf-8")
assert resp.status_code == 200
assert "Disabled Project History & IP Address Recording" in resp.data.decode(
"utf-8"
)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 2)
assert resp.data.decode("utf-8").count("<td> -- </td>") == 1
assert resp.data.decode("utf-8").count("127.0.0.1") == 2
self.change_privacy_to("demo", LoggingMode.ENABLED)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn("Enabled Project History\n", resp.data.decode("utf-8"))
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 2)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 2)
assert resp.status_code == 200
assert "Enabled Project History\n" in resp.data.decode("utf-8")
assert resp.data.decode("utf-8").count("<td> -- </td>") == 2
assert resp.data.decode("utf-8").count("127.0.0.1") == 2
def do_misc_database_operations(self, logging_mode):
new_data = {
@ -185,13 +194,13 @@ class HistoryTestCase(IhatemoneyTestCase):
new_data["ip_recording"] = "y"
resp = self.client.post("/demo/edit", data=new_data, follow_redirects=True)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
# adds a member to this project
resp = self.client.post(
"/demo/members/add", data={"name": "zorglub"}, follow_redirects=True
)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
user_id = models.Person.query.one().id
@ -207,7 +216,7 @@ class HistoryTestCase(IhatemoneyTestCase):
},
follow_redirects=True,
)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
bill_id = models.Bill.query.one().id
@ -223,16 +232,16 @@ class HistoryTestCase(IhatemoneyTestCase):
},
follow_redirects=True,
)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
# delete the bill
resp = self.client.post(f"/demo/delete/{bill_id}", follow_redirects=True)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
# delete user using POST method
resp = self.client.post(
f"/demo/members/{user_id}/delete", follow_redirects=True
)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
def test_disable_clear_no_new_records(self):
# Disable logging
@ -240,27 +249,24 @@ class HistoryTestCase(IhatemoneyTestCase):
# Ensure we can't clear history with a GET or with a password-less POST
resp = self.client.get("/demo/erase_history")
self.assertEqual(resp.status_code, 405)
assert resp.status_code == 405
resp = self.client.post("/demo/erase_history", follow_redirects=True)
self.assertIn(
"Error deleting project history",
resp.data.decode("utf-8"),
)
assert "Error deleting project history" in resp.data.decode("utf-8")
# List history
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn(
"This project has history disabled. New actions won't appear below.",
resp.data.decode("utf-8"),
assert resp.status_code == 200
assert (
"This project has history disabled. New actions won't appear below."
in resp.data.decode("utf-8")
)
self.assertIn(
"The table below reflects actions recorded prior to disabling project history.",
resp.data.decode("utf-8"),
assert (
"The table below reflects actions recorded prior to disabling project history."
in resp.data.decode("utf-8")
)
self.assertNotIn("Nothing to list", resp.data.decode("utf-8"))
self.assertNotIn(
"Some entries below contain IP addresses,", resp.data.decode("utf-8")
assert "Nothing to list" not in resp.data.decode("utf-8")
assert "Some entries below contain IP addresses," not in resp.data.decode(
"utf-8"
)
# Clear Existing Entries
@ -269,7 +275,7 @@ class HistoryTestCase(IhatemoneyTestCase):
data={"password": "demo"},
follow_redirects=True,
)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
self.assert_empty_history_logging_disabled()
# Do lots of database operations & check that there's still no history
@ -288,43 +294,38 @@ class HistoryTestCase(IhatemoneyTestCase):
self.change_privacy_to("123456", LoggingMode.ENABLED)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertNotIn(
"This project has history disabled. New actions won't appear below.",
resp.data.decode("utf-8"),
assert resp.status_code == 200
assert (
"This project has history disabled. New actions won't appear below."
not in resp.data.decode("utf-8")
)
self.assertNotIn(
"The table below reflects actions recorded prior to disabling project history.",
resp.data.decode("utf-8"),
assert (
"The table below reflects actions recorded prior to disabling project history."
not in resp.data.decode("utf-8")
)
self.assertNotIn("Nothing to list", resp.data.decode("utf-8"))
self.assertIn(
"Some entries below contain IP addresses,", resp.data.decode("utf-8")
)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 1)
assert "Nothing to list" not in resp.data.decode("utf-8")
assert "Some entries below contain IP addresses," in resp.data.decode("utf-8")
assert resp.data.decode("utf-8").count("127.0.0.1") == 12
assert resp.data.decode("utf-8").count("<td> -- </td>") == 1
# Generate more operations to confirm additional IP info isn't recorded
self.do_misc_database_operations(LoggingMode.ENABLED)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 7)
assert resp.status_code == 200
assert resp.data.decode("utf-8").count("127.0.0.1") == 12
assert resp.data.decode("utf-8").count("<td> -- </td>") == 7
# Ensure we can't clear IP data with a GET or with a password-less POST
resp = self.client.get("/demo/strip_ip_addresses")
self.assertEqual(resp.status_code, 405)
assert resp.status_code == 405
resp = self.client.post("/demo/strip_ip_addresses", follow_redirects=True)
self.assertIn(
"Error deleting recorded IP addresses",
resp.data.decode("utf-8"),
)
assert "Error deleting recorded IP addresses" in resp.data.decode("utf-8")
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 12)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 7)
assert resp.status_code == 200
assert resp.data.decode("utf-8").count("127.0.0.1") == 12
assert resp.data.decode("utf-8").count("<td> -- </td>") == 7
# Clear IP Data
resp = self.client.post(
@ -333,33 +334,33 @@ class HistoryTestCase(IhatemoneyTestCase):
follow_redirects=True,
)
self.assertEqual(resp.status_code, 200)
self.assertNotIn(
"This project has history disabled. New actions won't appear below.",
resp.data.decode("utf-8"),
assert resp.status_code == 200
assert (
"This project has history disabled. New actions won't appear below."
not in resp.data.decode("utf-8")
)
self.assertNotIn(
"The table below reflects actions recorded prior to disabling project history.",
resp.data.decode("utf-8"),
assert (
"The table below reflects actions recorded prior to disabling project history."
not in resp.data.decode("utf-8")
)
self.assertNotIn("Nothing to list", resp.data.decode("utf-8"))
self.assertNotIn(
"Some entries below contain IP addresses,", resp.data.decode("utf-8")
assert "Nothing to list" not in resp.data.decode("utf-8")
assert "Some entries below contain IP addresses," not in resp.data.decode(
"utf-8"
)
self.assertEqual(resp.data.decode("utf-8").count("127.0.0.1"), 0)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 19)
assert resp.data.decode("utf-8").count("127.0.0.1") == 0
assert resp.data.decode("utf-8").count("<td> -- </td>") == 19
def test_logs_for_common_actions(self):
# adds a member to this project
resp = self.client.post(
"/demo/members/add", data={"name": "zorglub"}, follow_redirects=True
)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn(
f"Participant {em_surround('zorglub')} added", resp.data.decode("utf-8")
assert resp.status_code == 200
assert f"Participant {em_surround('zorglub')} added" in resp.data.decode(
"utf-8"
)
# create a bill
@ -374,13 +375,12 @@ class HistoryTestCase(IhatemoneyTestCase):
},
follow_redirects=True,
)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn(
f"Bill {em_surround('fromage à raclette')} added",
resp.data.decode("utf-8"),
assert resp.status_code == 200
assert f"Bill {em_surround('fromage à raclette')} added" in resp.data.decode(
"utf-8"
)
# edit the bill
@ -395,44 +395,37 @@ class HistoryTestCase(IhatemoneyTestCase):
},
follow_redirects=True,
)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn(
f"Bill {em_surround('fromage à raclette')} added",
resp.data.decode("utf-8"),
assert resp.status_code == 200
assert f"Bill {em_surround('fromage à raclette')} added" in resp.data.decode(
"utf-8"
)
self.assertRegex(
resp.data.decode("utf-8"),
assert re.search(
r"Bill %s:\s* Amount changed\s* from %s\s* to %s"
% (
em_surround("fromage à raclette", regex_escape=True),
em_surround("25.0", regex_escape=True),
em_surround("10.0", regex_escape=True),
),
)
self.assertIn(
"Bill %s renamed to %s"
% (em_surround("fromage à raclette"), em_surround("new thing")),
resp.data.decode("utf-8"),
)
self.assertLess(
resp.data.decode("utf-8").index(
f"Bill {em_surround('fromage à raclette')} renamed to"
),
resp.data.decode("utf-8").index("Amount changed"),
)
assert "Bill %s renamed to %s" % (
em_surround("fromage à raclette"),
em_surround("new thing"),
) in resp.data.decode("utf-8")
assert resp.data.decode("utf-8").index(
f"Bill {em_surround('fromage à raclette')} renamed to"
) < resp.data.decode("utf-8").index("Amount changed")
# delete the bill
resp = self.client.post("/demo/delete/1", follow_redirects=True)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn(
f"Bill {em_surround('new thing')} removed", resp.data.decode("utf-8")
)
assert resp.status_code == 200
assert f"Bill {em_surround('new thing')} removed" in resp.data.decode("utf-8")
# edit user
resp = self.client.post(
@ -440,39 +433,35 @@ class HistoryTestCase(IhatemoneyTestCase):
data={"weight": 2, "name": "new name"},
follow_redirects=True,
)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertRegex(
resp.data.decode("utf-8"),
assert resp.status_code == 200
assert re.search(
r"Participant %s:\s* weight changed\s* from %s\s* to %s"
% (
em_surround("zorglub", regex_escape=True),
em_surround("1.0", regex_escape=True),
em_surround("2.0", regex_escape=True),
),
)
self.assertIn(
"Participant %s renamed to %s"
% (em_surround("zorglub"), em_surround("new name")),
resp.data.decode("utf-8"),
)
self.assertLess(
resp.data.decode("utf-8").index(
f"Participant {em_surround('zorglub')} renamed"
),
resp.data.decode("utf-8").index("weight changed"),
)
assert "Participant %s renamed to %s" % (
em_surround("zorglub"),
em_surround("new name"),
) in resp.data.decode("utf-8")
assert resp.data.decode("utf-8").index(
f"Participant {em_surround('zorglub')} renamed"
) < resp.data.decode("utf-8").index("weight changed")
# delete user using POST method
resp = self.client.post("/demo/members/1/delete", follow_redirects=True)
self.assertEqual(resp.status_code, 200)
assert resp.status_code == 200
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertIn(
f"Participant {em_surround('new name')} removed", resp.data.decode("utf-8")
assert resp.status_code == 200
assert f"Participant {em_surround('new name')} removed" in resp.data.decode(
"utf-8"
)
def test_double_bill_double_person_edit_second(self):
@ -504,9 +493,9 @@ class HistoryTestCase(IhatemoneyTestCase):
# Should be 5 history entries at this point
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 5)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
assert resp.status_code == 200
assert resp.data.decode("utf-8").count("<td> -- </td>") == 5
assert "127.0.0.1" not in resp.data.decode("utf-8")
# Edit ONLY the amount on the first bill
self.client.post(
@ -521,28 +510,27 @@ class HistoryTestCase(IhatemoneyTestCase):
)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertRegex(
resp.data.decode("utf-8"),
assert resp.status_code == 200
assert re.search(
r"Bill {}:\s* Amount changed\s* from {}\s* to {}".format(
em_surround("Bill 1", regex_escape=True),
em_surround("25.0", regex_escape=True),
em_surround("88.0", regex_escape=True),
),
resp.data.decode("utf-8"),
)
self.assertNotRegex(
resp.data.decode("utf-8"),
assert not re.search(
r"Removed\s* {}\s* and\s* {}\s* from\s* owers list".format(
em_surround("User 1", regex_escape=True),
em_surround("User 2", regex_escape=True),
),
resp.data.decode("utf-8"),
)
), resp.data.decode("utf-8")
# Should be 6 history entries at this point
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 6)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
assert resp.data.decode("utf-8").count("<td> -- </td>") == 6
assert "127.0.0.1" not in resp.data.decode("utf-8")
def test_bill_add_remove_add(self):
# add two members
@ -565,13 +553,11 @@ class HistoryTestCase(IhatemoneyTestCase):
self.client.post("/demo/delete/1", follow_redirects=True)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 5)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
self.assertIn(f"Bill {em_surround('Bill 1')} added", resp.data.decode("utf-8"))
self.assertIn(
f"Bill {em_surround('Bill 1')} removed", resp.data.decode("utf-8")
)
assert resp.status_code == 200
assert resp.data.decode("utf-8").count("<td> -- </td>") == 5
assert "127.0.0.1" not in resp.data.decode("utf-8")
assert f"Bill {em_surround('Bill 1')} added" in resp.data.decode("utf-8")
assert f"Bill {em_surround('Bill 1')} removed" in resp.data.decode("utf-8")
# Add a new bill
self.client.post(
@ -586,17 +572,15 @@ class HistoryTestCase(IhatemoneyTestCase):
)
resp = self.client.get("/demo/history")
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.data.decode("utf-8").count("<td> -- </td>"), 6)
self.assertNotIn("127.0.0.1", resp.data.decode("utf-8"))
self.assertIn(f"Bill {em_surround('Bill 1')} added", resp.data.decode("utf-8"))
self.assertEqual(
resp.data.decode("utf-8").count(f"Bill {em_surround('Bill 1')} added"), 1
)
self.assertIn(f"Bill {em_surround('Bill 2')} added", resp.data.decode("utf-8"))
self.assertIn(
f"Bill {em_surround('Bill 1')} removed", resp.data.decode("utf-8")
assert resp.status_code == 200
assert resp.data.decode("utf-8").count("<td> -- </td>") == 6
assert "127.0.0.1" not in resp.data.decode("utf-8")
assert f"Bill {em_surround('Bill 1')} added" in resp.data.decode("utf-8")
assert (
resp.data.decode("utf-8").count(f"Bill {em_surround('Bill 1')} added") == 1
)
assert f"Bill {em_surround('Bill 2')} added" in resp.data.decode("utf-8")
assert f"Bill {em_surround('Bill 1')} removed" in resp.data.decode("utf-8")
def test_double_bill_double_person_edit_second_no_web(self):
u1 = models.Person(project_id="demo", name="User 1")
@ -617,7 +601,7 @@ class HistoryTestCase(IhatemoneyTestCase):
models.db.session.commit()
history_list = history.get_history(self.get_project("demo"))
self.assertEqual(len(history_list), 5)
assert len(history_list) == 5
# Change just the amount
b1.amount = 5
@ -626,8 +610,8 @@ class HistoryTestCase(IhatemoneyTestCase):
history_list = history.get_history(self.get_project("demo"))
for entry in history_list:
if "prop_changed" in entry:
self.assertNotIn("owers", entry["prop_changed"])
self.assertEqual(len(history_list), 6)
assert "owers" not in entry["prop_changed"]
assert len(history_list) == 6
def test_delete_history_with_project(self):
self.post_project("raclette", password="party")
@ -659,8 +643,4 @@ class HistoryTestCase(IhatemoneyTestCase):
# History should be equal to project creation
history_list = history.get_history(self.get_project("raclette"))
self.assertEqual(len(history_list), 1)
if __name__ == "__main__":
unittest.main()
assert len(history_list) == 1

View file

@ -1,12 +1,46 @@
import copy
import json
import unittest
import pytest
from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase
from ihatemoney.utils import list_of_dicts2csv, list_of_dicts2json
@pytest.fixture
def import_data(request: pytest.FixtureRequest):
data = [
{
"date": "2017-01-01",
"what": "refund",
"amount": 13.33,
"payer_name": "tata",
"payer_weight": 1.0,
"owers": ["fred"],
},
{
"date": "2016-12-31",
"what": "red wine",
"amount": 200.0,
"payer_name": "fred",
"payer_weight": 1.0,
"owers": ["zorglub", "tata"],
},
{
"date": "2016-12-31",
"what": "fromage a raclette",
"amount": 10.0,
"payer_name": "zorglub",
"payer_weight": 2.0,
"owers": ["zorglub", "fred", "tata", "pepe"],
},
]
request.cls.data = data
yield data
class CommonTestCase(object):
@pytest.mark.usefixtures("import_data")
class Import(IhatemoneyTestCase):
def setUp(self):
super().setUp()
@ -55,7 +89,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills()
# Check if all bills have been added
self.assertEqual(len(bills), len(self.data))
assert len(bills) == len(self.data)
# Check if name of bills are ok
b = [e["what"] for e in bills]
@ -63,22 +97,22 @@ class CommonTestCase(object):
ref = [e["what"] for e in self.data]
ref.sort()
self.assertEqual(b, ref)
assert b == ref
# Check if other informations in bill are ok
for d in self.data:
for b in bills:
if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"])
self.assertEqual(b["amount"], d["amount"])
self.assertEqual(b["currency"], d["currency"])
self.assertEqual(b["payer_weight"], d["payer_weight"])
self.assertEqual(b["date"], d["date"])
assert b["payer_name"] == d["payer_name"]
assert b["amount"] == d["amount"]
assert b["currency"] == d["currency"]
assert b["payer_weight"] == d["payer_weight"]
assert b["date"] == d["date"]
list_project = [ower for ower in b["owers"]]
list_project.sort()
list_json = [ower for ower in d["owers"]]
list_json.sort()
self.assertEqual(list_project, list_json)
assert list_project == list_json
def test_import_single_currency_in_empty_project_without_currency(self):
# Import JSON with a single currency in an empty project with no
@ -96,7 +130,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills()
# Check if all bills have been added
self.assertEqual(len(bills), len(self.data))
assert len(bills) == len(self.data)
# Check if name of bills are ok
b = [e["what"] for e in bills]
@ -104,23 +138,23 @@ class CommonTestCase(object):
ref = [e["what"] for e in self.data]
ref.sort()
self.assertEqual(b, ref)
assert b == ref
# Check if other informations in bill are ok
for d in self.data:
for b in bills:
if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"])
self.assertEqual(b["amount"], d["amount"])
assert b["payer_name"] == d["payer_name"]
assert b["amount"] == d["amount"]
# Currency should have been stripped
self.assertEqual(b["currency"], "XXX")
self.assertEqual(b["payer_weight"], d["payer_weight"])
self.assertEqual(b["date"], d["date"])
assert b["currency"] == "XXX"
assert b["payer_weight"] == d["payer_weight"]
assert b["date"] == d["date"]
list_project = [ower for ower in b["owers"]]
list_project.sort()
list_json = [ower for ower in d["owers"]]
list_json.sort()
self.assertEqual(list_project, list_json)
assert list_project == list_json
def test_import_multiple_currencies_in_empty_project_without_currency(self):
# Import JSON with multiple currencies in an empty project with no
@ -138,7 +172,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills()
# Check that there are no bills
self.assertEqual(len(bills), 0)
assert len(bills) == 0
def test_import_no_currency_in_empty_project_with_currency(self):
# Import JSON without currencies (from ihatemoney < 5) in an empty
@ -154,7 +188,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills()
# Check if all bills have been added
self.assertEqual(len(bills), len(self.data))
assert len(bills) == len(self.data)
# Check if name of bills are ok
b = [e["what"] for e in bills]
@ -162,23 +196,23 @@ class CommonTestCase(object):
ref = [e["what"] for e in self.data]
ref.sort()
self.assertEqual(b, ref)
assert b == ref
# Check if other informations in bill are ok
for d in self.data:
for b in bills:
if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"])
self.assertEqual(b["amount"], d["amount"])
assert b["payer_name"] == d["payer_name"]
assert b["amount"] == d["amount"]
# All bills are converted to default project currency
self.assertEqual(b["currency"], "EUR")
self.assertEqual(b["payer_weight"], d["payer_weight"])
self.assertEqual(b["date"], d["date"])
assert b["currency"] == "EUR"
assert b["payer_weight"] == d["payer_weight"]
assert b["date"] == d["date"]
list_project = [ower for ower in b["owers"]]
list_project.sort()
list_json = [ower for ower in d["owers"]]
list_json.sort()
self.assertEqual(list_project, list_json)
assert list_project == list_json
def test_import_no_currency_in_empty_project_without_currency(self):
# Import JSON without currencies (from ihatemoney < 5) in an empty
@ -194,7 +228,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills()
# Check if all bills have been added
self.assertEqual(len(bills), len(self.data))
assert len(bills) == len(self.data)
# Check if name of bills are ok
b = [e["what"] for e in bills]
@ -202,22 +236,22 @@ class CommonTestCase(object):
ref = [e["what"] for e in self.data]
ref.sort()
self.assertEqual(b, ref)
assert b == ref
# Check if other informations in bill are ok
for d in self.data:
for b in bills:
if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"])
self.assertEqual(b["amount"], d["amount"])
self.assertEqual(b["currency"], "XXX")
self.assertEqual(b["payer_weight"], d["payer_weight"])
self.assertEqual(b["date"], d["date"])
assert b["payer_name"] == d["payer_name"]
assert b["amount"] == d["amount"]
assert b["currency"] == "XXX"
assert b["payer_weight"] == d["payer_weight"]
assert b["date"] == d["date"]
list_project = [ower for ower in b["owers"]]
list_project.sort()
list_json = [ower for ower in d["owers"]]
list_json.sort()
self.assertEqual(list_project, list_json)
assert list_project == list_json
def test_import_partial_project(self):
# Import a JSON in a project with already existing data
@ -250,7 +284,7 @@ class CommonTestCase(object):
bills = project.get_pretty_bills()
# Check if all bills have been added
self.assertEqual(len(bills), len(self.data))
assert len(bills) == len(self.data)
# Check if name of bills are ok
b = [e["what"] for e in bills]
@ -258,22 +292,22 @@ class CommonTestCase(object):
ref = [e["what"] for e in self.data]
ref.sort()
self.assertEqual(b, ref)
assert b == ref
# Check if other informations in bill are ok
for d in self.data:
for b in bills:
if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"])
self.assertEqual(b["amount"], d["amount"])
self.assertEqual(b["currency"], d["currency"])
self.assertEqual(b["payer_weight"], d["payer_weight"])
self.assertEqual(b["date"], d["date"])
assert b["payer_name"] == d["payer_name"]
assert b["amount"] == d["amount"]
assert b["currency"] == d["currency"]
assert b["payer_weight"] == d["payer_weight"]
assert b["date"] == d["date"]
list_project = [ower for ower in b["owers"]]
list_project.sort()
list_json = [ower for ower in d["owers"]]
list_json.sort()
self.assertEqual(list_project, list_json)
assert list_project == list_json
def test_import_wrong_data(self):
self.post_project("raclette")
@ -302,7 +336,7 @@ class CommonTestCase(object):
self.import_project("raclette", self.generate_form_data(data), 400)
class ExportTestCase(IhatemoneyTestCase):
class TestExport(IhatemoneyTestCase):
def test_export(self):
# Export a simple project without currencies
@ -379,7 +413,7 @@ class ExportTestCase(IhatemoneyTestCase):
"owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"],
},
]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
assert json.loads(resp.data.decode("utf-8")) == expected
# generate csv export of bills
resp = self.client.get("/raclette/export/bills.csv")
@ -392,9 +426,7 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
# generate json export of transactions
resp = self.client.get("/raclette/export/transactions.json")
@ -414,7 +446,7 @@ class ExportTestCase(IhatemoneyTestCase):
},
]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
assert json.loads(resp.data.decode("utf-8")) == expected
# generate csv export of transactions
resp = self.client.get("/raclette/export/transactions.csv")
@ -428,13 +460,11 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
# wrong export_format should return a 404
resp = self.client.get("/raclette/export/transactions.wrong")
self.assertEqual(resp.status_code, 404)
assert resp.status_code == 404
def test_export_with_currencies(self):
self.post_project("raclette", default_currency="EUR")
@ -513,7 +543,7 @@ class ExportTestCase(IhatemoneyTestCase):
"owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"],
},
]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
assert json.loads(resp.data.decode("utf-8")) == expected
# generate csv export of bills
resp = self.client.get("/raclette/export/bills.csv")
@ -526,9 +556,7 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
# generate json export of transactions (in EUR!)
resp = self.client.get("/raclette/export/transactions.json")
@ -543,7 +571,7 @@ class ExportTestCase(IhatemoneyTestCase):
{"amount": 38.45, "currency": "EUR", "receiver": "fred", "ower": "zorglub"},
]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
assert json.loads(resp.data.decode("utf-8")) == expected
# generate csv export of transactions
resp = self.client.get("/raclette/export/transactions.csv")
@ -557,9 +585,7 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
# Change project currency to CAD
project = self.get_project("raclette")
@ -578,7 +604,7 @@ class ExportTestCase(IhatemoneyTestCase):
{"amount": 57.67, "currency": "CAD", "receiver": "fred", "ower": "zorglub"},
]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
assert json.loads(resp.data.decode("utf-8")) == expected
# generate csv export of transactions
resp = self.client.get("/raclette/export/transactions.csv")
@ -592,9 +618,7 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
def test_export_escape_formulae(self):
self.post_project("raclette", default_currency="EUR")
@ -624,23 +648,17 @@ class ExportTestCase(IhatemoneyTestCase):
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
assert set(line.split(",")) == set(received_lines[i].strip("\r").split(","))
class ImportTestCaseJSON(CommonTestCase.Import):
class TestImportJSON(CommonTestCase.Import):
def generate_form_data(self, data):
return {"file": (list_of_dicts2json(data), "test.json")}
class ImportTestCaseCSV(CommonTestCase.Import):
class TestImportCSV(CommonTestCase.Import):
def generate_form_data(self, data):
formatted_data = copy.deepcopy(data)
for d in formatted_data:
d["owers"] = ", ".join([o for o in d.get("owers", [])])
return {"file": (list_of_dicts2csv(formatted_data), "test.csv")}
if __name__ == "__main__":
unittest.main()

View file

@ -1,9 +1,9 @@
import os
import smtplib
import socket
import unittest
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import orm
from werkzeug.security import check_password_hash
@ -19,19 +19,18 @@ os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None)
__HERE__ = os.path.dirname(os.path.abspath(__file__))
class ConfigurationTestCase(BaseTestCase):
class TestConfiguration(BaseTestCase):
def test_default_configuration(self):
"""Test that default settings are loaded when no other configuration file is specified"""
self.assertFalse(self.app.config["DEBUG"])
self.assertFalse(self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"])
self.assertEqual(
self.app.config["MAIL_DEFAULT_SENDER"],
("Budget manager <admin@example.com>"),
assert not self.app.config["DEBUG"]
assert not self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]
assert self.app.config["MAIL_DEFAULT_SENDER"] == (
"Budget manager <admin@example.com>"
)
self.assertTrue(self.app.config["ACTIVATE_DEMO_PROJECT"])
self.assertTrue(self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"])
self.assertFalse(self.app.config["ACTIVATE_ADMIN_DASHBOARD"])
self.assertFalse(self.app.config["ENABLE_CAPTCHA"])
assert self.app.config["ACTIVATE_DEMO_PROJECT"]
assert self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"]
assert not self.app.config["ACTIVATE_ADMIN_DASHBOARD"]
assert not self.app.config["ENABLE_CAPTCHA"]
def test_env_var_configuration_file(self):
"""Test that settings are loaded from a configuration file specified
@ -40,7 +39,7 @@ class ConfigurationTestCase(BaseTestCase):
__HERE__, "ihatemoney_envvar.cfg"
)
load_configuration(self.app)
self.assertEqual(self.app.config["SECRET_KEY"], "lalatra")
assert self.app.config["SECRET_KEY"] == "lalatra"
# Test that the specified configuration file is loaded
# even if the default configuration file ihatemoney.cfg exists
@ -50,7 +49,7 @@ class ConfigurationTestCase(BaseTestCase):
)
self.app.config.root_path = __HERE__
load_configuration(self.app)
self.assertEqual(self.app.config["SECRET_KEY"], "lalatra")
assert self.app.config["SECRET_KEY"] == "lalatra"
os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None)
@ -59,10 +58,10 @@ class ConfigurationTestCase(BaseTestCase):
in the current directory."""
self.app.config.root_path = __HERE__
load_configuration(self.app)
self.assertEqual(self.app.config["SECRET_KEY"], "supersecret")
assert self.app.config["SECRET_KEY"] == "supersecret"
class ServerTestCase(IhatemoneyTestCase):
class TestServer(IhatemoneyTestCase):
def test_homepage(self):
# See https://github.com/spiral-project/ihatemoney/pull/358
self.app.config["APPLICATION_ROOT"] = "/"
@ -80,7 +79,7 @@ class ServerTestCase(IhatemoneyTestCase):
self.assertStatus(200, req)
class CommandTestCase(BaseTestCase):
class TestCommand(BaseTestCase):
def test_generate_config(self):
"""Simply checks that all config file generation
- raise no exception
@ -89,25 +88,25 @@ class CommandTestCase(BaseTestCase):
runner = self.app.test_cli_runner()
for config_file in generate_config.params[0].type.choices:
result = runner.invoke(generate_config, config_file)
self.assertNotEqual(len(result.output.strip()), 0)
assert len(result.output.strip()) != 0
def test_generate_password_hash(self):
runner = self.app.test_cli_runner()
with patch("getpass.getpass", new=lambda prompt: "secret"):
result = runner.invoke(password_hash)
self.assertTrue(check_password_hash(result.output.strip(), "secret"))
assert check_password_hash(result.output.strip(), "secret")
def test_demo_project_deletion(self):
self.create_project("demo")
self.assertEqual(self.get_project("demo").name, "demo")
assert self.get_project("demo").name == "demo"
runner = self.app.test_cli_runner()
runner.invoke(delete_project, "demo")
self.assertEqual(len(models.Project.query.all()), 0)
assert len(models.Project.query.all()) == 0
class ModelsTestCase(IhatemoneyTestCase):
class TestModels(IhatemoneyTestCase):
def test_weighted_bills(self):
"""Test the SQL request that fetch all bills and weights"""
self.post_project("raclette")
@ -156,13 +155,13 @@ class ModelsTestCase(IhatemoneyTestCase):
for weight, bill in project.get_bill_weights().all():
if bill.what == "red wine":
pay_each_expected = 20 / 2
self.assertEqual(bill.amount / weight, pay_each_expected)
assert bill.amount / weight == pay_each_expected
if bill.what == "fromage à raclette":
pay_each_expected = 10 / 4
self.assertEqual(bill.amount / weight, pay_each_expected)
assert bill.amount / weight == pay_each_expected
if bill.what == "delicatessen":
pay_each_expected = 10 / 3
self.assertEqual(bill.amount / weight, pay_each_expected)
assert bill.amount / weight == pay_each_expected
def test_bill_pay_each(self):
self.post_project("raclette")
@ -216,16 +215,16 @@ class ModelsTestCase(IhatemoneyTestCase):
for bill in zorglub_bills.all():
if bill.what == "red wine":
pay_each_expected = 20 / 2
self.assertEqual(bill.pay_each(), pay_each_expected)
assert bill.pay_each() == pay_each_expected
if bill.what == "fromage à raclette":
pay_each_expected = 10 / 4
self.assertEqual(bill.pay_each(), pay_each_expected)
assert bill.pay_each() == pay_each_expected
if bill.what == "delicatessen":
pay_each_expected = 10 / 3
self.assertEqual(bill.pay_each(), pay_each_expected)
assert bill.pay_each() == pay_each_expected
class EmailFailureTestCase(IhatemoneyTestCase):
class TestEmailFailure(IhatemoneyTestCase):
def test_creation_email_failure_smtp(self):
self.login("raclette")
with patch.object(
@ -233,14 +232,14 @@ class EmailFailureTestCase(IhatemoneyTestCase):
):
resp = self.post_project("raclette")
# Check that an error message is displayed
self.assertIn(
"We tried to send you an reminder email, but there was an error",
resp.data.decode("utf-8"),
assert (
"We tried to send you an reminder email, but there was an error"
in resp.data.decode("utf-8")
)
# Check that we were redirected to the home page anyway
self.assertIn(
'<a href="/raclette/members/add">Add the first participant',
resp.data.decode("utf-8"),
assert (
'<a href="/raclette/members/add">Add the first participant'
in resp.data.decode("utf-8")
)
def test_creation_email_failure_socket(self):
@ -248,14 +247,14 @@ class EmailFailureTestCase(IhatemoneyTestCase):
with patch.object(self.app.mail, "send", MagicMock(side_effect=socket.error)):
resp = self.post_project("raclette")
# Check that an error message is displayed
self.assertIn(
"We tried to send you an reminder email, but there was an error",
resp.data.decode("utf-8"),
assert (
"We tried to send you an reminder email, but there was an error"
in resp.data.decode("utf-8")
)
# Check that we were redirected to the home page anyway
self.assertIn(
'<a href="/raclette/members/add">Add the first participant',
resp.data.decode("utf-8"),
assert (
'<a href="/raclette/members/add">Add the first participant'
in resp.data.decode("utf-8")
)
def test_password_reset_email_failure(self):
@ -266,14 +265,13 @@ class EmailFailureTestCase(IhatemoneyTestCase):
"/password-reminder", data={"id": "raclette"}, follow_redirects=True
)
# Check that an error message is displayed
self.assertIn(
"there was an error while sending you an email",
resp.data.decode("utf-8"),
assert "there was an error while sending you an email" in resp.data.decode(
"utf-8"
)
# Check that we were not redirected to the success page
self.assertNotIn(
"A link to reset your password has been sent to you",
resp.data.decode("utf-8"),
assert (
"A link to reset your password has been sent to you"
not in resp.data.decode("utf-8")
)
def test_invitation_email_failure(self):
@ -287,17 +285,15 @@ class EmailFailureTestCase(IhatemoneyTestCase):
follow_redirects=True,
)
# Check that an error message is displayed
self.assertIn(
"there was an error while trying to send the invitation emails",
resp.data.decode("utf-8"),
assert (
"there was an error while trying to send the invitation emails"
in resp.data.decode("utf-8")
)
# Check that we are still on the same page (no redirection)
self.assertIn(
"Invite people to join this project", resp.data.decode("utf-8")
)
assert "Invite people to join this project" in resp.data.decode("utf-8")
class CaptchaTestCase(IhatemoneyTestCase):
class TestCaptcha(IhatemoneyTestCase):
ENABLE_CAPTCHA = True
def test_project_creation_with_captcha_case_insensitive(self):
@ -315,7 +311,7 @@ class CaptchaTestCase(IhatemoneyTestCase):
"captcha": "éùüß",
},
)
self.assertEqual(len(models.Project.query.all()), 1)
assert len(models.Project.query.all()) == 1
def test_project_creation_with_captcha(self):
with self.client as c:
@ -329,7 +325,7 @@ class CaptchaTestCase(IhatemoneyTestCase):
"default_currency": "USD",
},
)
self.assertEqual(len(models.Project.query.all()), 0)
assert len(models.Project.query.all()) == 0
c.post(
"/create",
@ -342,7 +338,7 @@ class CaptchaTestCase(IhatemoneyTestCase):
"captcha": "nope",
},
)
self.assertEqual(len(models.Project.query.all()), 0)
assert len(models.Project.query.all()) == 0
c.post(
"/create",
@ -355,7 +351,7 @@ class CaptchaTestCase(IhatemoneyTestCase):
"captcha": "euro",
},
)
self.assertEqual(len(models.Project.query.all()), 1)
assert len(models.Project.query.all()) == 1
def test_api_project_creation_does_not_need_captcha(self):
self.client.get("/")
@ -368,11 +364,11 @@ class CaptchaTestCase(IhatemoneyTestCase):
"contact_email": "raclette@notmyidea.org",
},
)
self.assertTrue(resp.status, 201)
self.assertEqual(len(models.Project.query.all()), 1)
assert resp.status_code == 201
assert len(models.Project.query.all()) == 1
class TestCurrencyConverter(unittest.TestCase):
class TestCurrencyConverter:
converter = CurrencyConverter()
mock_data = {
"USD": 1,
@ -386,28 +382,23 @@ class TestCurrencyConverter(unittest.TestCase):
def test_only_one_instance(self):
one = id(CurrencyConverter())
two = id(CurrencyConverter())
self.assertEqual(one, two)
assert one == two
def test_get_currencies(self):
self.assertCountEqual(
self.converter.get_currencies(),
["USD", "EUR", "CAD", "PLN", CurrencyConverter.no_currency],
assert set(self.converter.get_currencies()) == set(
["USD", "EUR", "CAD", "PLN", CurrencyConverter.no_currency]
)
def test_exchange_currency(self):
result = self.converter.exchange_currency(100, "USD", "EUR")
self.assertEqual(result, 80.0)
assert result == 80.0
def test_failing_remote(self):
rates = {}
with patch("requests.Response.json", new=lambda _: {}), self.assertWarns(
with patch("requests.Response.json", new=lambda _: {}), pytest.warns(
UserWarning
):
# we need a non-patched converter, but it seems that MagickMock
# is mocking EVERY instance of the class method. Too bad.
rates = CurrencyConverter.get_rates(self.converter)
self.assertDictEqual(rates, {CurrencyConverter.no_currency: 1})
if __name__ == "__main__":
unittest.main()
assert rates == {CurrencyConverter.no_currency: 1}

View file

@ -59,8 +59,8 @@ dev =
flake8==5.0.4
isort==5.11.5
vermin==1.5.2
Flask-Testing>=0.8.1
pytest>=6.2.5
pytest-flask>=1.2.0
pytest-libfaketime>=0.1.2
tox>=3.14.6
zest.releaser>=6.20.1