diff --git a/ihatemoney/api/common.py b/ihatemoney/api/common.py index 1cfc34c5..44727f5a 100644 --- a/ihatemoney/api/common.py +++ b/ihatemoney/api/common.py @@ -159,8 +159,7 @@ class BillsHandler(Resource): def post(self, project): form = get_billform_for(project, True, meta={"csrf": False}) if form.validate(): - bill = Bill() - form.save(bill, project) + bill = form.export(project) db.session.add(bill) db.session.commit() return bill.id, 201 diff --git a/ihatemoney/forms.py b/ihatemoney/forms.py index 4e241c86..fe966778 100644 --- a/ihatemoney/forms.py +++ b/ihatemoney/forms.py @@ -23,7 +23,7 @@ from wtforms.validators import ( ) from ihatemoney.currency_convertor import CurrencyConverter -from ihatemoney.models import LoggingMode, Person, Project +from ihatemoney.models import Bill, LoggingMode, Person, Project from ihatemoney.utils import ( eval_arithmetic_expression, render_localized_currency, @@ -182,13 +182,15 @@ class EditProjectForm(FlaskForm): return project -class UploadForm(FlaskForm): +class ImportProjectForm(FlaskForm): file = FileField( - "JSON", - validators=[FileRequired(), FileAllowed(["json", "JSON"], "JSON only!")], - description=_("Import previously exported JSON file"), + "File", + validators=[ + FileRequired(), + FileAllowed(["json", "JSON", "csv", "CSV"], "Incorrect file format"), + ], + description=_("Compatible with Cospend"), ) - submit = SubmitField(_("Import")) class ProjectForm(EditProjectForm): @@ -319,33 +321,31 @@ class BillForm(FlaskForm): submit = SubmitField(_("Submit")) submit2 = SubmitField(_("Submit and add a new one")) + def export(self, project): + return Bill( + amount=float(self.amount.data), + date=self.date.data, + external_link=self.external_link.data, + original_currency=str(self.original_currency.data), + owers=Person.query.get_by_ids(self.payed_for.data, project), + payer_id=self.payer.data, + project_default_currency=project.default_currency, + what=self.what.data, + ) + def save(self, bill, project): bill.payer_id = self.payer.data bill.amount = self.amount.data bill.what = self.what.data bill.external_link = self.external_link.data bill.date = self.date.data - bill.owers = [Person.query.get(ower, project) for ower in self.payed_for.data] + bill.owers = Person.query.get_by_ids(self.payed_for.data, project) bill.original_currency = self.original_currency.data bill.converted_amount = self.currency_helper.exchange_currency( bill.amount, bill.original_currency, project.default_currency ) return bill - def fake_form(self, bill, project): - bill.payer_id = self.payer - bill.amount = self.amount - bill.what = self.what - bill.external_link = "" - bill.date = self.date - bill.owers = [Person.query.get(ower, project) for ower in self.payed_for] - bill.original_currency = self.original_currency - bill.converted_amount = self.currency_helper.exchange_currency( - bill.amount, bill.original_currency, project.default_currency - ) - - return bill - def fill(self, bill, project): self.payer.data = bill.payer_id self.amount.data = bill.amount diff --git a/ihatemoney/models.py b/ihatemoney/models.py index 473e7c0b..7877b410 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -1,6 +1,7 @@ from collections import defaultdict from datetime import datetime +from dateutil.parser import parse from debts import settle from flask import current_app, g from flask_sqlalchemy import BaseQuery, SQLAlchemy @@ -19,6 +20,7 @@ from werkzeug.security import generate_password_hash from ihatemoney.currency_convertor import CurrencyConverter from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder +from ihatemoney.utils import get_members, same_bill from ihatemoney.versioning import ( ConditionalVersioningManager, LoggingMode, @@ -320,6 +322,44 @@ class Project(db.Model): db.session.add(self) db.session.commit() + def import_bills(self, bills: list): + """Import bills from a list of dictionaries""" + # Add members not already in the project + project_members = [str(m) for m in self.members] + new_members = [ + m for m in get_members(bills) if str(m[0]) not in project_members + ] + for m in new_members: + Person(name=m[0], project=self, weight=m[1]) + db.session.commit() + + # Import bills not already in the project + project_bills = self.get_pretty_bills() + id_dict = {m.name: m.id for m in self.members} + for b in bills: + same = False + for p_b in project_bills: + if same_bill(p_b, b): + same = True + break + if not same: + # Create bills + try: + new_bill = Bill( + amount=b["amount"], + date=parse(b["date"]), + external_link="", + original_currency=b["currency"], + owers=Person.query.get_by_names(b["owers"], self), + payer_id=id_dict[b["payer_name"]], + project_default_currency=self.default_currency, + what=b["what"], + ) + except Exception as e: + raise ValueError(f"Unable to import csv data: {repr(e)}") + db.session.add(new_bill) + db.session.commit() + def remove_member(self, member_id): """Remove a member from the project. @@ -435,16 +475,17 @@ class Project(db.Model): ("Alice", 20, ("Amina", "Alice"), "Beer !"), ("Amina", 50, ("Amina", "Alice", "Georg"), "AMAP"), ) - for (payer, amount, owers, subject) in operations: - bill = Bill() - bill.payer_id = members[payer].id - bill.what = subject - bill.owers = [members[name] for name in owers] - bill.amount = amount - bill.original_currency = "XXX" - bill.converted_amount = amount - - db.session.add(bill) + for (payer, amount, owers, what) in operations: + db.session.add( + Bill( + amount=amount, + original_currency=project.default_currency, + owers=[members[name] for name in owers], + payer_id=members[payer].id, + project_default_currency=project.default_currency, + what=what, + ) + ) db.session.commit() return project @@ -459,6 +500,13 @@ class Person(db.Model): .one_or_none() ) + def get_by_names(self, names, project): + return ( + Person.query.filter(Person.name.in_(names)) + .filter(Person.project_id == project.id) + .all() + ) + def get(self, id, project=None): if not project: project = g.project @@ -468,6 +516,15 @@ class Person(db.Model): .one_or_none() ) + def get_by_ids(self, ids, project=None): + if not project: + project = g.project + return ( + Person.query.filter(Person.id.in_(ids)) + .filter(Person.project_id == project.id) + .all() + ) + query_class = PersonQuery # Direct SQLAlchemy-Continuum to track changes to this model @@ -561,6 +618,31 @@ class Bill(db.Model): archive = db.Column(db.Integer, db.ForeignKey("archive.id")) + currency_helper = CurrencyConverter() + + def __init__( + self, + amount: float, + date: datetime = None, + external_link: str = "", + original_currency: str = "", + owers: list = [], + payer_id: int = None, + project_default_currency: str = "", + what: str = "", + ): + super().__init__() + self.amount = amount + self.date = date + self.external_link = external_link + self.original_currency = original_currency + self.owers = owers + self.payer_id = payer_id + self.what = what + self.converted_amount = self.currency_helper.exchange_currency( + self.amount, self.original_currency, project_default_currency + ) + @property def _to_serialize(self): return { diff --git a/ihatemoney/templates/edit_project.html b/ihatemoney/templates/edit_project.html index 7fcc725e..7ea47d9f 100644 --- a/ihatemoney/templates/edit_project.html +++ b/ihatemoney/templates/edit_project.html @@ -29,7 +29,7 @@

{{ _("Edit project") }}

-
+ {{ forms.edit_project(edit_form) }}
@@ -38,23 +38,10 @@ {{ forms.delete_project(delete_form) }} -

{{ _("Import JSON") }}

-
- {{ import_form.hidden_tag() }} -
-
- {{ import_form.file(class="custom-file-input") }} - - {{ import_form.file.description }} - -
- -
- -
- {{ import_form.submit(class="btn btn-primary") }} -
+

{{ _("Import project") }}

+ + {{ forms.import_project(import_form) }}

{{ _("Download project's data") }}

diff --git a/ihatemoney/templates/forms.html b/ihatemoney/templates/forms.html index e6662b76..f93cfc6d 100644 --- a/ihatemoney/templates/forms.html +++ b/ihatemoney/templates/forms.html @@ -118,13 +118,27 @@ {% endmacro %} -{% macro upload_json(form) %} +{% macro import_project(form) %} + {% include "display_errors.html" %} {{ form.hidden_tag() }} - {{ form.file }} -
- + +

{{ _("Import previously exported project") }}

+ +
+
+ {{ form.file(class="custom-file-input", accept=".json,.csv") }} + + {{ form.file.description }} + +
+
+ +
+ +
+ {% endmacro %} {% macro delete_project_history(form) %} diff --git a/ihatemoney/tests/budget_test.py b/ihatemoney/tests/budget_test.py index 92618594..b9586bfa 100644 --- a/ihatemoney/tests/budget_test.py +++ b/ihatemoney/tests/budget_test.py @@ -1,6 +1,4 @@ from collections import defaultdict -import io -import json import re from time import sleep import unittest @@ -191,8 +189,7 @@ class BudgetTestCase(IhatemoneyTestCase): self.assertIn("Invalid token", resp.data.decode("utf-8")) def test_project_creation(self): - with self.app.test_client() as c: - + with self.client as c: with self.app.mail.record_messages() as outbox: # add a valid project resp = c.post( @@ -222,7 +219,7 @@ class BudgetTestCase(IhatemoneyTestCase): self.assertEqual(len(models.Project.query.all()), 1) # Add a second project with the same id - models.Project.query.get("raclette") + self.get_project("raclette") c.post( "/create", @@ -240,7 +237,7 @@ class BudgetTestCase(IhatemoneyTestCase): def test_project_creation_without_public_permissions(self): self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"] = False - with self.app.test_client() as c: + with self.client as c: # add a valid project c.post( "/create", @@ -261,7 +258,7 @@ class BudgetTestCase(IhatemoneyTestCase): def test_project_creation_with_public_permissions(self): self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"] = True - with self.app.test_client() as c: + with self.client as c: # add a valid project c.post( "/create", @@ -282,7 +279,7 @@ class BudgetTestCase(IhatemoneyTestCase): def test_project_deletion(self): - with self.app.test_client() as c: + with self.client as c: c.post( "/create", data={ @@ -340,17 +337,17 @@ class BudgetTestCase(IhatemoneyTestCase): # adds a member to this project self.client.post("/raclette/members/add", data={"name": "zorglub"}) - self.assertEqual(len(models.Project.query.get("raclette").members), 1) + self.assertEqual(len(self.get_project("raclette").members), 1) # adds him twice result = self.client.post("/raclette/members/add", data={"name": "zorglub"}) # should not accept him - self.assertEqual(len(models.Project.query.get("raclette").members), 1) + self.assertEqual(len(self.get_project("raclette").members), 1) # add fred self.client.post("/raclette/members/add", data={"name": "fred"}) - self.assertEqual(len(models.Project.query.get("raclette").members), 2) + self.assertEqual(len(self.get_project("raclette").members), 2) # check fred is present in the bills page result = self.client.get("/raclette/") @@ -358,16 +355,15 @@ class BudgetTestCase(IhatemoneyTestCase): # remove fred self.client.post( - "/raclette/members/%s/delete" - % models.Project.query.get("raclette").members[-1].id + "/raclette/members/%s/delete" % self.get_project("raclette").members[-1].id ) # as fred is not bound to any bill, he is removed - self.assertEqual(len(models.Project.query.get("raclette").members), 1) + self.assertEqual(len(self.get_project("raclette").members), 1) # add fred again self.client.post("/raclette/members/add", data={"name": "fred"}) - fred_id = models.Project.query.get("raclette").members[-1].id + fred_id = self.get_project("raclette").members[-1].id # bound him to a bill result = self.client.post( @@ -385,8 +381,8 @@ class BudgetTestCase(IhatemoneyTestCase): self.client.post(f"/raclette/members/{fred_id}/delete") # he is still in the database, but is deactivated - self.assertEqual(len(models.Project.query.get("raclette").members), 2) - self.assertEqual(len(models.Project.query.get("raclette").active_members), 1) + self.assertEqual(len(self.get_project("raclette").members), 2) + self.assertEqual(len(self.get_project("raclette").active_members), 1) # as fred is now deactivated, check that he is not listed when adding # a bill or displaying the balance @@ -400,14 +396,14 @@ class BudgetTestCase(IhatemoneyTestCase): # adding him again should reactivate him self.client.post("/raclette/members/add", data={"name": "fred"}) - self.assertEqual(len(models.Project.query.get("raclette").active_members), 2) + self.assertEqual(len(self.get_project("raclette").active_members), 2) # adding an user with the same name as another user from a different # project should not cause any troubles self.post_project("randomid") self.login("randomid") self.client.post("/randomid/members/add", data={"name": "fred"}) - self.assertEqual(len(models.Project.query.get("randomid").active_members), 1) + self.assertEqual(len(self.get_project("randomid").active_members), 1) def test_person_model(self): self.post_project("raclette") @@ -415,7 +411,7 @@ class BudgetTestCase(IhatemoneyTestCase): # adds a member to this project self.client.post("/raclette/members/add", data={"name": "zorglub"}) - zorglub = models.Project.query.get("raclette").members[-1] + zorglub = self.get_project("raclette").members[-1] # should not have any bills self.assertFalse(zorglub.has_bills()) @@ -433,7 +429,7 @@ class BudgetTestCase(IhatemoneyTestCase): ) # should have a bill now - zorglub = models.Project.query.get("raclette").members[-1] + zorglub = self.get_project("raclette").members[-1] self.assertTrue(zorglub.has_bills()) def test_member_delete_method(self): @@ -449,7 +445,7 @@ class BudgetTestCase(IhatemoneyTestCase): # delete user using POST method self.client.post("/raclette/members/1/delete") - self.assertEqual(len(models.Project.query.get("raclette").active_members), 0) + self.assertEqual(len(self.get_project("raclette").active_members), 0) # try to delete an user already deleted self.client.post("/raclette/members/1/delete") @@ -457,7 +453,7 @@ class BudgetTestCase(IhatemoneyTestCase): # test that a demo project is created if none is defined self.assertEqual([], models.Project.query.all()) self.client.get("/demo") - demo = models.Project.query.get("demo") + demo = self.get_project("demo") self.assertTrue(demo is not None) self.assertEqual(["Amina", "Georg", "Alice"], [m.name for m in demo.members]) @@ -485,14 +481,14 @@ class BudgetTestCase(IhatemoneyTestCase): self.assertIn("Authentication", resp.data.decode("utf-8")) # try to connect with wrong credentials should not work - with self.app.test_client() as c: + with self.client as c: resp = c.post("/authenticate", data={"id": "raclette", "password": "nope"}) self.assertIn("Authentication", resp.data.decode("utf-8")) self.assertNotIn("raclette", session) # try to connect with the right credentials should work - with self.app.test_client() as c: + with self.client as c: resp = c.post( "/authenticate", data={"id": "raclette", "password": "raclette"} ) @@ -507,7 +503,7 @@ class BudgetTestCase(IhatemoneyTestCase): # test that with admin credentials, one can access every project self.app.config["ADMIN_PASSWORD"] = generate_password_hash("pass") - with self.app.test_client() as c: + with self.client as c: resp = c.post("/admin?goto=%2Fraclette", data={"admin_password": "pass"}) self.assertNotIn("Authentication", resp.data.decode("utf-8")) self.assertTrue(session["is_admin"]) @@ -516,7 +512,7 @@ class BudgetTestCase(IhatemoneyTestCase): self.post_project("Raclette") # try to connect with the right credentials should work - with self.app.test_client() as c: + with self.client as c: resp = c.post( "/authenticate", data={"id": "Raclette", "password": "Raclette"} ) @@ -586,7 +582,7 @@ class BudgetTestCase(IhatemoneyTestCase): self.client.post("/raclette/members/add", data={"name": "zorglub"}) self.client.post("/raclette/members/add", data={"name": "fred"}) - members_ids = [m.id for m in models.Project.query.get("raclette").members] + members_ids = [m.id for m in self.get_project("raclette").members] # create a bill self.client.post( @@ -599,7 +595,7 @@ class BudgetTestCase(IhatemoneyTestCase): "amount": "25", }, ) - models.Project.query.get("raclette") + self.get_project("raclette") bill = models.Bill.query.one() self.assertEqual(bill.amount, 25) @@ -660,7 +656,7 @@ class BudgetTestCase(IhatemoneyTestCase): }, ) - balance = models.Project.query.get("raclette").balance + balance = self.get_project("raclette").balance self.assertEqual(set(balance.values()), set([19.0, -19.0])) # Bill with negative amount @@ -729,7 +725,7 @@ class BudgetTestCase(IhatemoneyTestCase): "/raclette/members/add", data={"name": "freddy familly", "weight": 4} ) - members_ids = [m.id for m in models.Project.query.get("raclette").members] + members_ids = [m.id for m in self.get_project("raclette").members] # test balance self.client.post( @@ -754,7 +750,7 @@ class BudgetTestCase(IhatemoneyTestCase): }, ) - balance = models.Project.query.get("raclette").balance + balance = self.get_project("raclette").balance self.assertEqual(set(balance.values()), set([6, -6])) def test_trimmed_members(self): @@ -763,7 +759,7 @@ class BudgetTestCase(IhatemoneyTestCase): # Add two times the same person (with a space at the end). self.client.post("/raclette/members/add", data={"name": "zorglub"}) self.client.post("/raclette/members/add", data={"name": "zorglub "}) - members = models.Project.query.get("raclette").members + members = self.get_project("raclette").members self.assertEqual(len(members), 1) @@ -795,8 +791,8 @@ class BudgetTestCase(IhatemoneyTestCase): # An error should be generated, and its weight should still be 1. self.assertIn('

', resp.data.decode("utf-8")) - self.assertEqual(len(models.Project.query.get("raclette").members), 1) - self.assertEqual(models.Project.query.get("raclette").members[0].weight, 1) + self.assertEqual(len(self.get_project("raclette").members), 1) + self.assertEqual(self.get_project("raclette").members[0].weight, 1) def test_rounding(self): self.post_project("raclette") @@ -840,11 +836,11 @@ class BudgetTestCase(IhatemoneyTestCase): }, ) - balance = models.Project.query.get("raclette").balance + balance = self.get_project("raclette").balance result = {} - result[models.Project.query.get("raclette").members[0].id] = 8.12 - result[models.Project.query.get("raclette").members[1].id] = 0.0 - result[models.Project.query.get("raclette").members[2].id] = -8.12 + result[self.get_project("raclette").members[0].id] = 8.12 + result[self.get_project("raclette").members[1].id] = 0.0 + result[self.get_project("raclette").members[2].id] = -8.12 # Since we're using floating point to store currency, we can have some # rounding issues that prevent test from working. # However, we should obtain the same values as the theoretical ones if we @@ -866,7 +862,7 @@ class BudgetTestCase(IhatemoneyTestCase): resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=True) self.assertEqual(resp.status_code, 200) - project = models.Project.query.get("raclette") + project = self.get_project("raclette") self.assertEqual(project.name, new_data["name"]) self.assertEqual(project.contact_email, new_data["contact_email"]) @@ -1030,14 +1026,14 @@ class BudgetTestCase(IhatemoneyTestCase): "amount": "10", }, ) - project = models.Project.query.get("raclette") + project = self.get_project("raclette") transactions = project.get_transactions_to_settle_bill() members = defaultdict(int) # We should have the same values between transactions and project balances for t in transactions: members[t["ower"]] -= t["amount"] members[t["receiver"]] += t["amount"] - balance = models.Project.query.get("raclette").balance + balance = self.get_project("raclette").balance for m, a in members.items(): assert abs(a - balance[m.id]) < 0.01 return @@ -1083,7 +1079,7 @@ class BudgetTestCase(IhatemoneyTestCase): "amount": "13.33", }, ) - project = models.Project.query.get("raclette") + project = self.get_project("raclette") transactions = project.get_transactions_to_settle_bill() # There should not be any zero-amount transfer after rounding @@ -1095,768 +1091,6 @@ class BudgetTestCase(IhatemoneyTestCase): msg=f"{t['amount']} is equal to zero after rounding", ) - def test_export(self): - # Export a simple project without currencies - - self.post_project("raclette") - - # add participants - self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) - self.client.post("/raclette/members/add", data={"name": "fred"}) - self.client.post("/raclette/members/add", data={"name": "tata"}) - self.client.post("/raclette/members/add", data={"name": "pépé"}) - - # create bills - self.client.post( - "/raclette/add", - data={ - "date": "2016-12-31", - "what": "fromage à raclette", - "payer": 1, - "payed_for": [1, 2, 3, 4], - "amount": "10.0", - }, - ) - - self.client.post( - "/raclette/add", - data={ - "date": "2016-12-31", - "what": "red wine", - "payer": 2, - "payed_for": [1, 3], - "amount": "200", - }, - ) - - self.client.post( - "/raclette/add", - data={ - "date": "2017-01-01", - "what": "refund", - "payer": 3, - "payed_for": [2], - "amount": "13.33", - }, - ) - - # generate json export of bills - resp = self.client.get("/raclette/export/bills.json") - expected = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "XXX", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "red wine", - "amount": 200.0, - "currency": "XXX", - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage \xe0 raclette", - "amount": 10.0, - "currency": "XXX", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"], - }, - ] - self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) - - # generate csv export of bills - resp = self.client.get("/raclette/export/bills.csv") - expected = [ - "date,what,amount,currency,payer_name,payer_weight,owers", - "2017-01-01,refund,XXX,13.33,tata,1.0,fred", - '2016-12-31,red wine,XXX,200.0,fred,1.0,"zorglub, tata"', - '2016-12-31,fromage à raclette,10.0,XXX,zorglub,2.0,"zorglub, fred, tata, pépé"', - ] - received_lines = resp.data.decode("utf-8").split("\n") - - for i, line in enumerate(expected): - self.assertEqual( - set(line.split(",")), set(received_lines[i].strip("\r").split(",")) - ) - - # generate json export of transactions - resp = self.client.get("/raclette/export/transactions.json") - expected = [ - { - "amount": 2.00, - "currency": "XXX", - "receiver": "fred", - "ower": "p\xe9p\xe9", - }, - {"amount": 55.34, "currency": "XXX", "receiver": "fred", "ower": "tata"}, - { - "amount": 127.33, - "currency": "XXX", - "receiver": "fred", - "ower": "zorglub", - }, - ] - - self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) - - # generate csv export of transactions - resp = self.client.get("/raclette/export/transactions.csv") - - expected = [ - "amount,currency,receiver,ower", - "2.0,XXX,fred,pépé", - "55.34,XXX,fred,tata", - "127.33,XXX,fred,zorglub", - ] - received_lines = resp.data.decode("utf-8").split("\n") - - for i, line in enumerate(expected): - self.assertEqual( - set(line.split(",")), set(received_lines[i].strip("\r").split(",")) - ) - - # wrong export_format should return a 404 - resp = self.client.get("/raclette/export/transactions.wrong") - self.assertEqual(resp.status_code, 404) - - def test_export_with_currencies(self): - self.post_project("raclette", default_currency="EUR") - - # add participants - self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) - self.client.post("/raclette/members/add", data={"name": "fred"}) - self.client.post("/raclette/members/add", data={"name": "tata"}) - self.client.post("/raclette/members/add", data={"name": "pépé"}) - - # create bills - self.client.post( - "/raclette/add", - data={ - "date": "2016-12-31", - "what": "fromage à raclette", - "payer": 1, - "payed_for": [1, 2, 3, 4], - "amount": "10.0", - "original_currency": "EUR", - }, - ) - - self.client.post( - "/raclette/add", - data={ - "date": "2016-12-31", - "what": "poutine from Québec", - "payer": 2, - "payed_for": [1, 3], - "amount": "100", - "original_currency": "CAD", - }, - ) - - self.client.post( - "/raclette/add", - data={ - "date": "2017-01-01", - "what": "refund", - "payer": 3, - "payed_for": [2], - "amount": "13.33", - "original_currency": "EUR", - }, - ) - - # generate json export of bills - resp = self.client.get("/raclette/export/bills.json") - expected = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "EUR", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "poutine from Qu\xe9bec", - "amount": 100.0, - "currency": "CAD", - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage \xe0 raclette", - "amount": 10.0, - "currency": "EUR", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"], - }, - ] - self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) - - # generate csv export of bills - resp = self.client.get("/raclette/export/bills.csv") - expected = [ - "date,what,amount,currency,payer_name,payer_weight,owers", - "2017-01-01,refund,13.33,EUR,tata,1.0,fred", - '2016-12-31,poutine from Québec,100.0,CAD,fred,1.0,"zorglub, tata"', - '2016-12-31,fromage à raclette,10.0,EUR,zorglub,2.0,"zorglub, fred, tata, pépé"', - ] - received_lines = resp.data.decode("utf-8").split("\n") - - for i, line in enumerate(expected): - self.assertEqual( - set(line.split(",")), set(received_lines[i].strip("\r").split(",")) - ) - - # generate json export of transactions (in EUR!) - resp = self.client.get("/raclette/export/transactions.json") - expected = [ - { - "amount": 2.00, - "currency": "EUR", - "receiver": "fred", - "ower": "p\xe9p\xe9", - }, - {"amount": 10.89, "currency": "EUR", "receiver": "fred", "ower": "tata"}, - {"amount": 38.45, "currency": "EUR", "receiver": "fred", "ower": "zorglub"}, - ] - - self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) - - # generate csv export of transactions - resp = self.client.get("/raclette/export/transactions.csv") - - expected = [ - "amount,currency,receiver,ower", - "2.0,EUR,fred,pépé", - "10.89,EUR,fred,tata", - "38.45,EUR,fred,zorglub", - ] - received_lines = resp.data.decode("utf-8").split("\n") - - for i, line in enumerate(expected): - self.assertEqual( - set(line.split(",")), set(received_lines[i].strip("\r").split(",")) - ) - - # Change project currency to CAD - project = models.Project.query.get("raclette") - project.switch_currency("CAD") - - # generate json export of transactions (now in CAD!) - resp = self.client.get("/raclette/export/transactions.json") - expected = [ - { - "amount": 3.00, - "currency": "CAD", - "receiver": "fred", - "ower": "p\xe9p\xe9", - }, - {"amount": 16.34, "currency": "CAD", "receiver": "fred", "ower": "tata"}, - {"amount": 57.67, "currency": "CAD", "receiver": "fred", "ower": "zorglub"}, - ] - - self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) - - # generate csv export of transactions - resp = self.client.get("/raclette/export/transactions.csv") - - expected = [ - "amount,currency,receiver,ower", - "3.0,CAD,fred,pépé", - "16.34,CAD,fred,tata", - "57.67,CAD,fred,zorglub", - ] - received_lines = resp.data.decode("utf-8").split("\n") - - for i, line in enumerate(expected): - self.assertEqual( - set(line.split(",")), set(received_lines[i].strip("\r").split(",")) - ) - - def test_import_currencies_in_empty_project_with_currency(self): - # Import JSON with currencies in an empty project with a default currency - - self.post_project("raclette", default_currency="EUR") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "EUR", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "poutine from québec", - "amount": 50.0, - "currency": "CAD", - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "currency": "EUR", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check if all bills have been added - self.assertEqual(len(bills), len(json_to_import)) - - # Check if name of bills are ok - b = [e["what"] for e in bills] - b.sort() - ref = [e["what"] for e in json_to_import] - ref.sort() - - self.assertEqual(b, ref) - - # Check if other informations in bill are ok - for i in json_to_import: - for j in bills: - if j["what"] == i["what"]: - self.assertEqual(j["payer_name"], i["payer_name"]) - self.assertEqual(j["amount"], i["amount"]) - self.assertEqual(j["currency"], i["currency"]) - self.assertEqual(j["payer_weight"], i["payer_weight"]) - self.assertEqual(j["date"], i["date"]) - - list_project = [ower for ower in j["owers"]] - list_project.sort() - list_json = [ower for ower in i["owers"]] - list_json.sort() - - self.assertEqual(list_project, list_json) - - def test_import_single_currency_in_empty_project_without_currency(self): - # Import JSON with a single currency in an empty project with no - # default currency. It should work by stripping the currency from - # bills. - - self.post_project("raclette") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "EUR", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "currency": "EUR", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check if all bills have been added - self.assertEqual(len(bills), len(json_to_import)) - - # Check if name of bills are ok - b = [e["what"] for e in bills] - b.sort() - ref = [e["what"] for e in json_to_import] - ref.sort() - - self.assertEqual(b, ref) - - # Check if other informations in bill are ok - for i in json_to_import: - for j in bills: - if j["what"] == i["what"]: - self.assertEqual(j["payer_name"], i["payer_name"]) - self.assertEqual(j["amount"], i["amount"]) - # Currency should have been stripped - self.assertEqual(j["currency"], "XXX") - self.assertEqual(j["payer_weight"], i["payer_weight"]) - self.assertEqual(j["date"], i["date"]) - - list_project = [ower for ower in j["owers"]] - list_project.sort() - list_json = [ower for ower in i["owers"]] - list_json.sort() - - self.assertEqual(list_project, list_json) - - def test_import_multiple_currencies_in_empty_project_without_currency(self): - # Import JSON with multiple currencies in an empty project with no - # default currency. It should fail. - - self.post_project("raclette") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "EUR", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "poutine from québec", - "amount": 50.0, - "currency": "CAD", - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "currency": "EUR", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - # Import should fail - with pytest.raises(ValueError): - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check that there are no bills - self.assertEqual(len(bills), 0) - - def test_import_no_currency_in_empty_project_with_currency(self): - # Import JSON without currencies (from ihatemoney < 5) in an empty - # project with a default currency. - - self.post_project("raclette", default_currency="EUR") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "red wine", - "amount": 200.0, - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check if all bills have been added - self.assertEqual(len(bills), len(json_to_import)) - - # Check if name of bills are ok - b = [e["what"] for e in bills] - b.sort() - ref = [e["what"] for e in json_to_import] - ref.sort() - - self.assertEqual(b, ref) - - # Check if other informations in bill are ok - for i in json_to_import: - for j in bills: - if j["what"] == i["what"]: - self.assertEqual(j["payer_name"], i["payer_name"]) - self.assertEqual(j["amount"], i["amount"]) - # All bills are converted to default project currency - self.assertEqual(j["currency"], "EUR") - self.assertEqual(j["payer_weight"], i["payer_weight"]) - self.assertEqual(j["date"], i["date"]) - - list_project = [ower for ower in j["owers"]] - list_project.sort() - list_json = [ower for ower in i["owers"]] - list_json.sort() - - self.assertEqual(list_project, list_json) - - def test_import_no_currency_in_empty_project_without_currency(self): - # Import JSON without currencies (from ihatemoney < 5) in an empty - # project with no default currency. - - self.post_project("raclette") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "red wine", - "amount": 200.0, - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check if all bills have been added - self.assertEqual(len(bills), len(json_to_import)) - - # Check if name of bills are ok - b = [e["what"] for e in bills] - b.sort() - ref = [e["what"] for e in json_to_import] - ref.sort() - - self.assertEqual(b, ref) - - # Check if other informations in bill are ok - for i in json_to_import: - for j in bills: - if j["what"] == i["what"]: - self.assertEqual(j["payer_name"], i["payer_name"]) - self.assertEqual(j["amount"], i["amount"]) - self.assertEqual(j["currency"], "XXX") - self.assertEqual(j["payer_weight"], i["payer_weight"]) - self.assertEqual(j["date"], i["date"]) - - list_project = [ower for ower in j["owers"]] - list_project.sort() - list_json = [ower for ower in i["owers"]] - list_json.sort() - - self.assertEqual(list_project, list_json) - - def test_import_partial_project(self): - # Import a JSON in a project with already existing data - - self.post_project("raclette") - self.login("raclette") - - project = models.Project.query.get("raclette") - - self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) - self.client.post("/raclette/members/add", data={"name": "fred"}) - self.client.post("/raclette/members/add", data={"name": "tata"}) - self.client.post( - "/raclette/add", - data={ - "date": "2016-12-31", - "what": "red wine", - "payer": 2, - "payed_for": [1, 3], - "amount": "200", - }, - ) - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "XXX", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { # This expense does not have to be present twice. - "date": "2016-12-31", - "what": "red wine", - "amount": 200.0, - "currency": "XXX", - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "currency": "XXX", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check if all bills have been added - self.assertEqual(len(bills), len(json_to_import)) - - # Check if name of bills are ok - b = [e["what"] for e in bills] - b.sort() - ref = [e["what"] for e in json_to_import] - ref.sort() - - self.assertEqual(b, ref) - - # Check if other informations in bill are ok - for i in json_to_import: - for j in bills: - if j["what"] == i["what"]: - self.assertEqual(j["payer_name"], i["payer_name"]) - self.assertEqual(j["amount"], i["amount"]) - self.assertEqual(j["currency"], i["currency"]) - self.assertEqual(j["payer_weight"], i["payer_weight"]) - self.assertEqual(j["date"], i["date"]) - - list_project = [ower for ower in j["owers"]] - list_project.sort() - list_json = [ower for ower in i["owers"]] - list_json.sort() - - self.assertEqual(list_project, list_json) - - def test_import_wrong_json(self): - self.post_project("raclette") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_1 = [ - { # wrong keys - "checked": False, - "dimensions": {"width": 5, "height": 10}, - "id": 1, - "name": "A green door", - "price": 12.5, - "tags": ["home", "green"], - } - ] - - json_2 = [ - { # amount missing - "date": "2017-01-01", - "what": "refund", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - } - ] - - from ihatemoney.web import import_project - - for data in [json_1, json_2]: - file = io.StringIO() - json.dump(data, file) - file.seek(0) - with pytest.raises(ValueError): - import_project(file, project) - def test_access_other_projects(self): """Test that accessing or editing bills and participants from another project fails""" # Create project @@ -1880,7 +1114,7 @@ class BudgetTestCase(IhatemoneyTestCase): }, ) # Ensure it has been created - raclette = models.Project.query.get("raclette") + raclette = self.get_project("raclette") self.assertEqual(raclette.get_bills().count(), 1) # Log out @@ -2025,7 +1259,7 @@ class BudgetTestCase(IhatemoneyTestCase): }, ) - project = models.Project.query.get("raclette") + project = self.get_project("raclette") # First all converted_amount should be the same as amount, with no currency for bill in project.get_bills(): @@ -2106,7 +1340,7 @@ class BudgetTestCase(IhatemoneyTestCase): # A user displayed error should be generated, and its currency should be the same. self.assertStatus(200, resp) self.assertIn('

', resp.data.decode("utf-8")) - self.assertEqual(models.Project.query.get("raclette").default_currency, "USD") + self.assertEqual(self.get_project("raclette").default_currency, "USD") def test_currency_switch_to_bill_currency(self): @@ -2130,7 +1364,7 @@ class BudgetTestCase(IhatemoneyTestCase): }, ) - project = models.Project.query.get("raclette") + project = self.get_project("raclette") bill = project.get_bills().first() assert bill.converted_amount == self.converter.exchange_currency( @@ -2176,7 +1410,7 @@ class BudgetTestCase(IhatemoneyTestCase): }, ) - project = models.Project.query.get("raclette") + project = self.get_project("raclette") for bill in project.get_bills_unordered(): assert bill.converted_amount == self.converter.exchange_currency( diff --git a/ihatemoney/tests/common/ihatemoney_testcase.py b/ihatemoney/tests/common/ihatemoney_testcase.py index 5ad0a56a..eda86ab7 100644 --- a/ihatemoney/tests/common/ihatemoney_testcase.py +++ b/ihatemoney/tests/common/ihatemoney_testcase.py @@ -74,6 +74,14 @@ class BaseTestCase(TestCase): follow_redirects=follow_redirects, ) + def import_project(self, id, data, success=True): + resp = self.client.post( + f"/{id}/import", + data=data, + # follow_redirects=True, + ) + self.assertEqual("/{id}/edit" in str(resp.response), not success) + def create_project(self, id, default_currency="XXX", name=None, password=None): name = name or str(id) password = password or id @@ -87,6 +95,9 @@ class BaseTestCase(TestCase): models.db.session.add(project) models.db.session.commit() + def get_project(self, id) -> models.Project: + return models.Project.query.get(id) + class IhatemoneyTestCase(BaseTestCase): TESTING = True diff --git a/ihatemoney/tests/history_test.py b/ihatemoney/tests/history_test.py index 38a3740e..b742a8c8 100644 --- a/ihatemoney/tests/history_test.py +++ b/ihatemoney/tests/history_test.py @@ -614,14 +614,14 @@ class HistoryTestCase(IhatemoneyTestCase): models.db.session.add(b2) models.db.session.commit() - history_list = history.get_history(models.Project.query.get("demo")) + history_list = history.get_history(self.get_project("demo")) self.assertEqual(len(history_list), 5) # Change just the amount b1.amount = 5 models.db.session.commit() - history_list = history.get_history(models.Project.query.get("demo")) + history_list = history.get_history(self.get_project("demo")) for entry in history_list: if "prop_changed" in entry: self.assertNotIn("owers", entry["prop_changed"]) diff --git a/ihatemoney/tests/import_test.py b/ihatemoney/tests/import_test.py new file mode 100644 index 00000000..f1ee1d07 --- /dev/null +++ b/ihatemoney/tests/import_test.py @@ -0,0 +1,614 @@ +import copy +import json +import unittest + +from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase +from ihatemoney.utils import list_of_dicts2csv, list_of_dicts2json + + +class CommonTestCase(object): + class Import(IhatemoneyTestCase): + def setUp(self): + super().setUp() + self.data = [ + { + "date": "2017-01-01", + "what": "refund", + "amount": 13.33, + "payer_name": "tata", + "payer_weight": 1.0, + "owers": ["fred"], + }, + { + "date": "2016-12-31", + "what": "red wine", + "amount": 200.0, + "payer_name": "fred", + "payer_weight": 1.0, + "owers": ["zorglub", "tata"], + }, + { + "date": "2016-12-31", + "what": "fromage a raclette", + "amount": 10.0, + "payer_name": "zorglub", + "payer_weight": 2.0, + "owers": ["zorglub", "fred", "tata", "pepe"], + }, + ] + + def populate_data_with_currencies(self, currencies): + for d in range(len(self.data)): + self.data[d]["currency"] = currencies[d] + + def test_import_currencies_in_empty_project_with_currency(self): + # Import JSON with currencies in an empty project with a default currency + + self.post_project("raclette", default_currency="EUR") + self.login("raclette") + + project = self.get_project("raclette") + + self.populate_data_with_currencies(["EUR", "CAD", "EUR"]) + self.import_project("raclette", self.generate_form_data(self.data)) + + bills = project.get_pretty_bills() + + # Check if all bills have been added + self.assertEqual(len(bills), len(self.data)) + + # Check if name of bills are ok + b = [e["what"] for e in bills] + b.sort() + ref = [e["what"] for e in self.data] + ref.sort() + + self.assertEqual(b, ref) + + # Check if other informations in bill are ok + for d in self.data: + for b in bills: + if b["what"] == d["what"]: + self.assertEqual(b["payer_name"], d["payer_name"]) + self.assertEqual(b["amount"], d["amount"]) + self.assertEqual(b["currency"], d["currency"]) + self.assertEqual(b["payer_weight"], d["payer_weight"]) + self.assertEqual(b["date"], d["date"]) + list_project = [ower for ower in b["owers"]] + list_project.sort() + list_json = [ower for ower in d["owers"]] + list_json.sort() + self.assertEqual(list_project, list_json) + + def test_import_single_currency_in_empty_project_without_currency(self): + # Import JSON with a single currency in an empty project with no + # default currency. It should work by stripping the currency from + # bills. + + self.post_project("raclette") + self.login("raclette") + + project = self.get_project("raclette") + + self.populate_data_with_currencies(["EUR", "EUR", "EUR"]) + self.import_project("raclette", self.generate_form_data(self.data)) + + bills = project.get_pretty_bills() + + # Check if all bills have been added + self.assertEqual(len(bills), len(self.data)) + + # Check if name of bills are ok + b = [e["what"] for e in bills] + b.sort() + ref = [e["what"] for e in self.data] + ref.sort() + + self.assertEqual(b, ref) + + # Check if other informations in bill are ok + for d in self.data: + for b in bills: + if b["what"] == d["what"]: + self.assertEqual(b["payer_name"], d["payer_name"]) + self.assertEqual(b["amount"], d["amount"]) + # Currency should have been stripped + self.assertEqual(b["currency"], "XXX") + self.assertEqual(b["payer_weight"], d["payer_weight"]) + self.assertEqual(b["date"], d["date"]) + list_project = [ower for ower in b["owers"]] + list_project.sort() + list_json = [ower for ower in d["owers"]] + list_json.sort() + self.assertEqual(list_project, list_json) + + def test_import_multiple_currencies_in_empty_project_without_currency(self): + # Import JSON with multiple currencies in an empty project with no + # default currency. It should fail. + + self.post_project("raclette") + self.login("raclette") + + project = self.get_project("raclette") + + self.populate_data_with_currencies(["EUR", "CAD", "EUR"]) + # Import should fail + self.import_project("raclette", self.generate_form_data(self.data), 400) + + bills = project.get_pretty_bills() + + # Check that there are no bills + self.assertEqual(len(bills), 0) + + def test_import_no_currency_in_empty_project_with_currency(self): + # Import JSON without currencies (from ihatemoney < 5) in an empty + # project with a default currency. + + self.post_project("raclette", default_currency="EUR") + self.login("raclette") + + project = self.get_project("raclette") + + self.import_project("raclette", self.generate_form_data(self.data)) + + bills = project.get_pretty_bills() + + # Check if all bills have been added + self.assertEqual(len(bills), len(self.data)) + + # Check if name of bills are ok + b = [e["what"] for e in bills] + b.sort() + ref = [e["what"] for e in self.data] + ref.sort() + + self.assertEqual(b, ref) + + # Check if other informations in bill are ok + for d in self.data: + for b in bills: + if b["what"] == d["what"]: + self.assertEqual(b["payer_name"], d["payer_name"]) + self.assertEqual(b["amount"], d["amount"]) + # All bills are converted to default project currency + self.assertEqual(b["currency"], "EUR") + self.assertEqual(b["payer_weight"], d["payer_weight"]) + self.assertEqual(b["date"], d["date"]) + list_project = [ower for ower in b["owers"]] + list_project.sort() + list_json = [ower for ower in d["owers"]] + list_json.sort() + self.assertEqual(list_project, list_json) + + def test_import_no_currency_in_empty_project_without_currency(self): + # Import JSON without currencies (from ihatemoney < 5) in an empty + # project with no default currency. + + self.post_project("raclette") + self.login("raclette") + + project = self.get_project("raclette") + + self.import_project("raclette", self.generate_form_data(self.data)) + + bills = project.get_pretty_bills() + + # Check if all bills have been added + self.assertEqual(len(bills), len(self.data)) + + # Check if name of bills are ok + b = [e["what"] for e in bills] + b.sort() + ref = [e["what"] for e in self.data] + ref.sort() + + self.assertEqual(b, ref) + + # Check if other informations in bill are ok + for d in self.data: + for b in bills: + if b["what"] == d["what"]: + self.assertEqual(b["payer_name"], d["payer_name"]) + self.assertEqual(b["amount"], d["amount"]) + self.assertEqual(b["currency"], "XXX") + self.assertEqual(b["payer_weight"], d["payer_weight"]) + self.assertEqual(b["date"], d["date"]) + list_project = [ower for ower in b["owers"]] + list_project.sort() + list_json = [ower for ower in d["owers"]] + list_json.sort() + self.assertEqual(list_project, list_json) + + def test_import_partial_project(self): + # Import a JSON in a project with already existing data + + self.post_project("raclette") + self.login("raclette") + + project = self.get_project("raclette") + + self.client.post( + "/raclette/members/add", data={"name": "zorglub", "weight": 2} + ) + self.client.post("/raclette/members/add", data={"name": "fred"}) + self.client.post("/raclette/members/add", data={"name": "tata"}) + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "red wine", + "payer": 2, + "payed_for": [1, 3], + "amount": "200", + }, + ) + + self.populate_data_with_currencies(["XXX", "XXX", "XXX"]) + + self.import_project("raclette", self.generate_form_data(self.data)) + + bills = project.get_pretty_bills() + + # Check if all bills have been added + self.assertEqual(len(bills), len(self.data)) + + # Check if name of bills are ok + b = [e["what"] for e in bills] + b.sort() + ref = [e["what"] for e in self.data] + ref.sort() + + self.assertEqual(b, ref) + + # Check if other informations in bill are ok + for d in self.data: + for b in bills: + if b["what"] == d["what"]: + self.assertEqual(b["payer_name"], d["payer_name"]) + self.assertEqual(b["amount"], d["amount"]) + self.assertEqual(b["currency"], d["currency"]) + self.assertEqual(b["payer_weight"], d["payer_weight"]) + self.assertEqual(b["date"], d["date"]) + list_project = [ower for ower in b["owers"]] + list_project.sort() + list_json = [ower for ower in d["owers"]] + list_json.sort() + self.assertEqual(list_project, list_json) + + def test_import_wrong_data(self): + self.post_project("raclette") + self.login("raclette") + data_wrong_keys = [ + { + "checked": False, + "dimensions": {"width": 5, "height": 10}, + "id": 1, + "name": "A green door", + "price": 12.5, + "tags": ["home", "green"], + } + ] + data_amount_missing = [ + { + "date": "2017-01-01", + "what": "refund", + "payer_name": "tata", + "payer_weight": 1.0, + "owers": ["fred"], + } + ] + for data in [data_wrong_keys, data_amount_missing]: + # Import should fail + self.import_project("raclette", self.generate_form_data(data), 400) + + +class ExportTestCase(IhatemoneyTestCase): + def test_export(self): + # Export a simple project without currencies + + self.post_project("raclette") + + # add participants + self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + self.client.post("/raclette/members/add", data={"name": "tata"}) + self.client.post("/raclette/members/add", data={"name": "pépé"}) + + # create bills + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "fromage à raclette", + "payer": 1, + "payed_for": [1, 2, 3, 4], + "amount": "10.0", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "red wine", + "payer": 2, + "payed_for": [1, 3], + "amount": "200", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2017-01-01", + "what": "refund", + "payer": 3, + "payed_for": [2], + "amount": "13.33", + }, + ) + + # generate json export of bills + resp = self.client.get("/raclette/export/bills.json") + expected = [ + { + "date": "2017-01-01", + "what": "refund", + "amount": 13.33, + "currency": "XXX", + "payer_name": "tata", + "payer_weight": 1.0, + "owers": ["fred"], + }, + { + "date": "2016-12-31", + "what": "red wine", + "amount": 200.0, + "currency": "XXX", + "payer_name": "fred", + "payer_weight": 1.0, + "owers": ["zorglub", "tata"], + }, + { + "date": "2016-12-31", + "what": "fromage \xe0 raclette", + "amount": 10.0, + "currency": "XXX", + "payer_name": "zorglub", + "payer_weight": 2.0, + "owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"], + }, + ] + self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) + + # generate csv export of bills + resp = self.client.get("/raclette/export/bills.csv") + expected = [ + "date,what,amount,currency,payer_name,payer_weight,owers", + "2017-01-01,refund,XXX,13.33,tata,1.0,fred", + '2016-12-31,red wine,XXX,200.0,fred,1.0,"zorglub, tata"', + '2016-12-31,fromage à raclette,10.0,XXX,zorglub,2.0,"zorglub, fred, tata, pépé"', + ] + received_lines = resp.data.decode("utf-8").split("\n") + + for i, line in enumerate(expected): + self.assertEqual( + set(line.split(",")), set(received_lines[i].strip("\r").split(",")) + ) + + # generate json export of transactions + resp = self.client.get("/raclette/export/transactions.json") + expected = [ + { + "amount": 2.00, + "currency": "XXX", + "receiver": "fred", + "ower": "p\xe9p\xe9", + }, + {"amount": 55.34, "currency": "XXX", "receiver": "fred", "ower": "tata"}, + { + "amount": 127.33, + "currency": "XXX", + "receiver": "fred", + "ower": "zorglub", + }, + ] + + self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) + + # generate csv export of transactions + resp = self.client.get("/raclette/export/transactions.csv") + + expected = [ + "amount,currency,receiver,ower", + "2.0,XXX,fred,pépé", + "55.34,XXX,fred,tata", + "127.33,XXX,fred,zorglub", + ] + received_lines = resp.data.decode("utf-8").split("\n") + + for i, line in enumerate(expected): + self.assertEqual( + set(line.split(",")), set(received_lines[i].strip("\r").split(",")) + ) + + # wrong export_format should return a 404 + resp = self.client.get("/raclette/export/transactions.wrong") + self.assertEqual(resp.status_code, 404) + + def test_export_with_currencies(self): + self.post_project("raclette", default_currency="EUR") + + # add participants + self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + self.client.post("/raclette/members/add", data={"name": "tata"}) + self.client.post("/raclette/members/add", data={"name": "pépé"}) + + # create bills + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "fromage à raclette", + "payer": 1, + "payed_for": [1, 2, 3, 4], + "amount": "10.0", + "original_currency": "EUR", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "poutine from Québec", + "payer": 2, + "payed_for": [1, 3], + "amount": "100", + "original_currency": "CAD", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2017-01-01", + "what": "refund", + "payer": 3, + "payed_for": [2], + "amount": "13.33", + "original_currency": "EUR", + }, + ) + + # generate json export of bills + resp = self.client.get("/raclette/export/bills.json") + expected = [ + { + "date": "2017-01-01", + "what": "refund", + "amount": 13.33, + "currency": "EUR", + "payer_name": "tata", + "payer_weight": 1.0, + "owers": ["fred"], + }, + { + "date": "2016-12-31", + "what": "poutine from Qu\xe9bec", + "amount": 100.0, + "currency": "CAD", + "payer_name": "fred", + "payer_weight": 1.0, + "owers": ["zorglub", "tata"], + }, + { + "date": "2016-12-31", + "what": "fromage \xe0 raclette", + "amount": 10.0, + "currency": "EUR", + "payer_name": "zorglub", + "payer_weight": 2.0, + "owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"], + }, + ] + self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) + + # generate csv export of bills + resp = self.client.get("/raclette/export/bills.csv") + expected = [ + "date,what,amount,currency,payer_name,payer_weight,owers", + "2017-01-01,refund,13.33,EUR,tata,1.0,fred", + '2016-12-31,poutine from Québec,100.0,CAD,fred,1.0,"zorglub, tata"', + '2016-12-31,fromage à raclette,10.0,EUR,zorglub,2.0,"zorglub, fred, tata, pépé"', + ] + received_lines = resp.data.decode("utf-8").split("\n") + + for i, line in enumerate(expected): + self.assertEqual( + set(line.split(",")), set(received_lines[i].strip("\r").split(",")) + ) + + # generate json export of transactions (in EUR!) + resp = self.client.get("/raclette/export/transactions.json") + expected = [ + { + "amount": 2.00, + "currency": "EUR", + "receiver": "fred", + "ower": "p\xe9p\xe9", + }, + {"amount": 10.89, "currency": "EUR", "receiver": "fred", "ower": "tata"}, + {"amount": 38.45, "currency": "EUR", "receiver": "fred", "ower": "zorglub"}, + ] + + self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) + + # generate csv export of transactions + resp = self.client.get("/raclette/export/transactions.csv") + + expected = [ + "amount,currency,receiver,ower", + "2.0,EUR,fred,pépé", + "10.89,EUR,fred,tata", + "38.45,EUR,fred,zorglub", + ] + received_lines = resp.data.decode("utf-8").split("\n") + + for i, line in enumerate(expected): + self.assertEqual( + set(line.split(",")), set(received_lines[i].strip("\r").split(",")) + ) + + # Change project currency to CAD + project = self.get_project("raclette") + project.switch_currency("CAD") + + # generate json export of transactions (now in CAD!) + resp = self.client.get("/raclette/export/transactions.json") + expected = [ + { + "amount": 3.00, + "currency": "CAD", + "receiver": "fred", + "ower": "p\xe9p\xe9", + }, + {"amount": 16.34, "currency": "CAD", "receiver": "fred", "ower": "tata"}, + {"amount": 57.67, "currency": "CAD", "receiver": "fred", "ower": "zorglub"}, + ] + + self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) + + # generate csv export of transactions + resp = self.client.get("/raclette/export/transactions.csv") + + expected = [ + "amount,currency,receiver,ower", + "3.0,CAD,fred,pépé", + "16.34,CAD,fred,tata", + "57.67,CAD,fred,zorglub", + ] + received_lines = resp.data.decode("utf-8").split("\n") + + for i, line in enumerate(expected): + self.assertEqual( + set(line.split(",")), set(received_lines[i].strip("\r").split(",")) + ) + + +class ImportTestCaseJSON(CommonTestCase.Import): + def generate_form_data(self, data): + return {"file": (list_of_dicts2json(data), "test.json")} + + +class ImportTestCaseCSV(CommonTestCase.Import): + def generate_form_data(self, data): + formatted_data = copy.deepcopy(data) + for d in formatted_data: + d["owers"] = ", ".join([o for o in d.get("owers", [])]) + return {"file": (list_of_dicts2csv(formatted_data), "test.csv")} + + +if __name__ == "__main__": + unittest.main() diff --git a/ihatemoney/tests/main_test.py b/ihatemoney/tests/main_test.py index 9707bd8a..fac655d3 100644 --- a/ihatemoney/tests/main_test.py +++ b/ihatemoney/tests/main_test.py @@ -99,7 +99,7 @@ class CommandTestCase(BaseTestCase): def test_demo_project_deletion(self): self.create_project("demo") - self.assertEqual(models.Project.query.get("demo").name, "demo") + self.assertEqual(self.get_project("demo").name, "demo") runner = self.app.test_cli_runner() runner.invoke(delete_project, "demo") @@ -246,7 +246,7 @@ class CaptchaTestCase(IhatemoneyTestCase): ENABLE_CAPTCHA = True def test_project_creation_with_captcha(self): - with self.app.test_client() as c: + with self.client as c: c.post( "/create", data={ diff --git a/ihatemoney/utils.py b/ihatemoney/utils.py index 66f5b6a4..96b80816 100644 --- a/ihatemoney/utils.py +++ b/ihatemoney/utils.py @@ -2,7 +2,7 @@ import ast import csv from datetime import datetime, timedelta from enum import Enum -from io import BytesIO, StringIO +from io import BytesIO, StringIO, TextIOWrapper from json import JSONEncoder, dumps import operator import os @@ -150,6 +150,31 @@ def list_of_dicts2csv(dict_to_convert): return csv_file +def csv2list_of_dicts(csv_to_convert): + """Take a csv in-memory file and turns it into + a list of dictionnaries + """ + csv_file = TextIOWrapper(csv_to_convert, encoding="utf-8") + reader = csv.DictReader(csv_file) + result = [] + for r in reader: + """ + cospend embeds various data helping (cospend) imports + 'deleteMeIfYouWant' lines contains users + 'categoryname' table contains categories description + we don't need them as we determine users and categories from bills + """ + if r["what"] == "deleteMeIfYouWant": + continue + elif r["what"] == "categoryname": + break + r["amount"] = float(r["amount"]) + r["payer_weight"] = float(r["payer_weight"]) + r["owers"] = [o.strip() for o in r["owers"].split(",")] + result.append(r) + return result + + class LoginThrottler: """Simple login throttler used to limit authentication attempts based on client's ip address. When using multiple workers, remaining number of attempts can get inconsistent diff --git a/ihatemoney/web.py b/ihatemoney/web.py index 7f986ee5..e77b41f3 100644 --- a/ihatemoney/web.py +++ b/ihatemoney/web.py @@ -13,7 +13,6 @@ from functools import wraps import json import os -from dateutil.parser import parse from dateutil.relativedelta import relativedelta from flask import ( Blueprint, @@ -44,13 +43,13 @@ from ihatemoney.forms import ( DestructiveActionProjectForm, EditProjectForm, EmptyForm, + ImportProjectForm, InviteForm, MemberForm, PasswordReminder, ProjectForm, ProjectFormWithCaptcha, ResetPasswordForm, - UploadForm, get_billform_for, ) from ihatemoney.history import get_history, get_history_queries @@ -58,12 +57,11 @@ from ihatemoney.models import Bill, LoggingMode, Person, Project, db from ihatemoney.utils import ( LoginThrottler, Redirect303, + csv2list_of_dicts, format_form_errors, - get_members, list_of_dicts2csv, list_of_dicts2json, render_localized_template, - same_bill, send_email, ) @@ -412,17 +410,8 @@ def reset_password(): @main.route("//edit", methods=["GET", "POST"]) def edit_project(): edit_form = EditProjectForm(id=g.project.id) + import_form = ImportProjectForm(id=g.project.id) delete_form = DestructiveActionProjectForm(id=g.project.id) - import_form = UploadForm() - # Import form - if import_form.validate_on_submit(): - try: - import_project(import_form.file.data.stream, g.project) - flash(_("Project successfully uploaded")) - - return redirect(url_for("main.list_bills")) - except ValueError as e: - flash(e.args[0], category="danger") # Edit form if edit_form.validate_on_submit(): @@ -446,103 +435,71 @@ def edit_project(): return render_template( "edit_project.html", edit_form=edit_form, - delete_form=delete_form, import_form=import_form, + delete_form=delete_form, current_view="edit_project", ) -def import_project(file, project): - json_file = json.load(file) +@main.route("//import", methods=["POST"]) +def import_project(): + form = ImportProjectForm() + if form.validate(): + try: + data = form.file.data + if data.mimetype == "application/json": + bills = json.load(data.stream) + elif data.mimetype == "text/csv": + try: + bills = csv2list_of_dicts(data) + except Exception: + raise ValueError(_("Unable to parse CSV")) + else: + raise ValueError("Unsupported file type") - # Check if JSON is correct - attr = ["what", "payer_name", "payer_weight", "amount", "currency", "date", "owers"] - attr.sort() - currencies = set() - for e in json_file: - # If currency is absent, empty, or explicitly set to XXX - # set it to project default. - if e.get("currency", "") in ["", "XXX"]: - e["currency"] = project.default_currency - if len(e) != len(attr): - raise ValueError(_("Invalid JSON")) - list_attr = [] - for i in e: - list_attr.append(i) - list_attr.sort() - if list_attr != attr: - raise ValueError(_("Invalid JSON")) - # Keep track of currencies - currencies.add(e["currency"]) + # Check data + attr = [ + "amount", + "currency", + "date", + "owers", + "payer_name", + "payer_weight", + "what", + ] + currencies = set() + for b in bills: + if b.get("currency", "") in ["", "XXX"]: + b["currency"] = g.project.default_currency + for a in attr: + if a not in b: + raise ValueError(_("Missing attribute {}").format(a)) + currencies.add(b["currency"]) - # Additional checks if project has no default currency - if project.default_currency == CurrencyConverter.no_currency: - # If bills have currencies, they must be consistent - if len(currencies - {CurrencyConverter.no_currency}) >= 2: - raise ValueError( - _( - "Cannot add bills in multiple currencies to a project without default currency" - ) - ) - # Strip currency from bills (since it's the same for every bill) - for e in json_file: - e["currency"] = CurrencyConverter.no_currency + # Additional checks if project has no default currency + if g.project.default_currency == CurrencyConverter.no_currency: + # If bills have currencies, they must be consistent + if len(currencies - {CurrencyConverter.no_currency}) >= 2: + raise ValueError( + _( + "Cannot add bills in multiple currencies to a project without default " + "currency" + ) + ) + # Strip currency from bills (since it's the same for every bill) + for b in bills: + b["currency"] = CurrencyConverter.no_currency - # From json : export list of members - members_json = get_members(json_file) - members = project.members - members_already_here = list() - for m in members: - members_already_here.append(str(m)) + g.project.import_bills(bills) - # List all members not in the project and weight associated - # List of tuples (name,weight) - members_to_add = list() - for i in members_json: - if str(i[0]) not in members_already_here: - members_to_add.append(i) - - # List bills not in the project - # Same format than JSON element - project_bills = project.get_pretty_bills() - bill_to_add = list() - for j in json_file: - same = False - for p in project_bills: - if same_bill(p, j): - same = True - break - if not same: - bill_to_add.append(j) - - # Add users to DB - for m in members_to_add: - Person(name=m[0], project=project, weight=m[1]) - db.session.commit() - - id_dict = {} - for i in project.members: - id_dict[i.name] = i.id - - # Create bills - for b in bill_to_add: - owers_id = list() - for ower in b["owers"]: - owers_id.append(id_dict[ower]) - - bill = Bill() - form = get_billform_for(project) - form.what = b["what"] - form.amount = b["amount"] - form.original_currency = b["currency"] - form.date = parse(b["date"]) - form.payer = id_dict[b["payer_name"]] - form.payed_for = owers_id - - db.session.add(form.fake_form(bill, project)) - - # Add bills to DB - db.session.commit() + flash(_("Project successfully uploaded")) + return redirect(url_for("main.list_bills")) + except ValueError as b: + flash(b.args[0], category="danger") + else: + for component, errors in form.errors.items(): + flash(_(component + ": ") + ", ".join(errors), category="danger") + return redirect(request.headers.get("Referer") or url_for(".edit_project")) @main.route("//delete", methods=["POST"]) @@ -760,8 +717,7 @@ def add_bill(): session["last_selected_payer"] = form.payer.data session.update() - bill = Bill() - db.session.add(form.save(bill, g.project)) + db.session.add(form.export(g.project)) db.session.commit() flash(_("The bill has been added"))