diff --git a/ihatemoney/api/common.py b/ihatemoney/api/common.py index 1cfc34c5..44727f5a 100644 --- a/ihatemoney/api/common.py +++ b/ihatemoney/api/common.py @@ -159,8 +159,7 @@ class BillsHandler(Resource): def post(self, project): form = get_billform_for(project, True, meta={"csrf": False}) if form.validate(): - bill = Bill() - form.save(bill, project) + bill = form.export(project) db.session.add(bill) db.session.commit() return bill.id, 201 diff --git a/ihatemoney/forms.py b/ihatemoney/forms.py index 4e241c86..fe966778 100644 --- a/ihatemoney/forms.py +++ b/ihatemoney/forms.py @@ -23,7 +23,7 @@ from wtforms.validators import ( ) from ihatemoney.currency_convertor import CurrencyConverter -from ihatemoney.models import LoggingMode, Person, Project +from ihatemoney.models import Bill, LoggingMode, Person, Project from ihatemoney.utils import ( eval_arithmetic_expression, render_localized_currency, @@ -182,13 +182,15 @@ class EditProjectForm(FlaskForm): return project -class UploadForm(FlaskForm): +class ImportProjectForm(FlaskForm): file = FileField( - "JSON", - validators=[FileRequired(), FileAllowed(["json", "JSON"], "JSON only!")], - description=_("Import previously exported JSON file"), + "File", + validators=[ + FileRequired(), + FileAllowed(["json", "JSON", "csv", "CSV"], "Incorrect file format"), + ], + description=_("Compatible with Cospend"), ) - submit = SubmitField(_("Import")) class ProjectForm(EditProjectForm): @@ -319,33 +321,31 @@ class BillForm(FlaskForm): submit = SubmitField(_("Submit")) submit2 = SubmitField(_("Submit and add a new one")) + def export(self, project): + return Bill( + amount=float(self.amount.data), + date=self.date.data, + external_link=self.external_link.data, + original_currency=str(self.original_currency.data), + owers=Person.query.get_by_ids(self.payed_for.data, project), + payer_id=self.payer.data, + project_default_currency=project.default_currency, + what=self.what.data, + ) + def save(self, bill, project): bill.payer_id = self.payer.data bill.amount = self.amount.data bill.what = self.what.data bill.external_link = self.external_link.data bill.date = self.date.data - bill.owers = [Person.query.get(ower, project) for ower in self.payed_for.data] + bill.owers = Person.query.get_by_ids(self.payed_for.data, project) bill.original_currency = self.original_currency.data bill.converted_amount = self.currency_helper.exchange_currency( bill.amount, bill.original_currency, project.default_currency ) return bill - def fake_form(self, bill, project): - bill.payer_id = self.payer - bill.amount = self.amount - bill.what = self.what - bill.external_link = "" - bill.date = self.date - bill.owers = [Person.query.get(ower, project) for ower in self.payed_for] - bill.original_currency = self.original_currency - bill.converted_amount = self.currency_helper.exchange_currency( - bill.amount, bill.original_currency, project.default_currency - ) - - return bill - def fill(self, bill, project): self.payer.data = bill.payer_id self.amount.data = bill.amount diff --git a/ihatemoney/models.py b/ihatemoney/models.py index 473e7c0b..7877b410 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -1,6 +1,7 @@ from collections import defaultdict from datetime import datetime +from dateutil.parser import parse from debts import settle from flask import current_app, g from flask_sqlalchemy import BaseQuery, SQLAlchemy @@ -19,6 +20,7 @@ from werkzeug.security import generate_password_hash from ihatemoney.currency_convertor import CurrencyConverter from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder +from ihatemoney.utils import get_members, same_bill from ihatemoney.versioning import ( ConditionalVersioningManager, LoggingMode, @@ -320,6 +322,44 @@ class Project(db.Model): db.session.add(self) db.session.commit() + def import_bills(self, bills: list): + """Import bills from a list of dictionaries""" + # Add members not already in the project + project_members = [str(m) for m in self.members] + new_members = [ + m for m in get_members(bills) if str(m[0]) not in project_members + ] + for m in new_members: + Person(name=m[0], project=self, weight=m[1]) + db.session.commit() + + # Import bills not already in the project + project_bills = self.get_pretty_bills() + id_dict = {m.name: m.id for m in self.members} + for b in bills: + same = False + for p_b in project_bills: + if same_bill(p_b, b): + same = True + break + if not same: + # Create bills + try: + new_bill = Bill( + amount=b["amount"], + date=parse(b["date"]), + external_link="", + original_currency=b["currency"], + owers=Person.query.get_by_names(b["owers"], self), + payer_id=id_dict[b["payer_name"]], + project_default_currency=self.default_currency, + what=b["what"], + ) + except Exception as e: + raise ValueError(f"Unable to import csv data: {repr(e)}") + db.session.add(new_bill) + db.session.commit() + def remove_member(self, member_id): """Remove a member from the project. @@ -435,16 +475,17 @@ class Project(db.Model): ("Alice", 20, ("Amina", "Alice"), "Beer !"), ("Amina", 50, ("Amina", "Alice", "Georg"), "AMAP"), ) - for (payer, amount, owers, subject) in operations: - bill = Bill() - bill.payer_id = members[payer].id - bill.what = subject - bill.owers = [members[name] for name in owers] - bill.amount = amount - bill.original_currency = "XXX" - bill.converted_amount = amount - - db.session.add(bill) + for (payer, amount, owers, what) in operations: + db.session.add( + Bill( + amount=amount, + original_currency=project.default_currency, + owers=[members[name] for name in owers], + payer_id=members[payer].id, + project_default_currency=project.default_currency, + what=what, + ) + ) db.session.commit() return project @@ -459,6 +500,13 @@ class Person(db.Model): .one_or_none() ) + def get_by_names(self, names, project): + return ( + Person.query.filter(Person.name.in_(names)) + .filter(Person.project_id == project.id) + .all() + ) + def get(self, id, project=None): if not project: project = g.project @@ -468,6 +516,15 @@ class Person(db.Model): .one_or_none() ) + def get_by_ids(self, ids, project=None): + if not project: + project = g.project + return ( + Person.query.filter(Person.id.in_(ids)) + .filter(Person.project_id == project.id) + .all() + ) + query_class = PersonQuery # Direct SQLAlchemy-Continuum to track changes to this model @@ -561,6 +618,31 @@ class Bill(db.Model): archive = db.Column(db.Integer, db.ForeignKey("archive.id")) + currency_helper = CurrencyConverter() + + def __init__( + self, + amount: float, + date: datetime = None, + external_link: str = "", + original_currency: str = "", + owers: list = [], + payer_id: int = None, + project_default_currency: str = "", + what: str = "", + ): + super().__init__() + self.amount = amount + self.date = date + self.external_link = external_link + self.original_currency = original_currency + self.owers = owers + self.payer_id = payer_id + self.what = what + self.converted_amount = self.currency_helper.exchange_currency( + self.amount, self.original_currency, project_default_currency + ) + @property def _to_serialize(self): return { diff --git a/ihatemoney/templates/edit_project.html b/ihatemoney/templates/edit_project.html index 7fcc725e..7ea47d9f 100644 --- a/ihatemoney/templates/edit_project.html +++ b/ihatemoney/templates/edit_project.html @@ -29,7 +29,7 @@
{{ _("Import previously exported project") }}
+ +', resp.data.decode("utf-8")) - self.assertEqual(len(models.Project.query.get("raclette").members), 1) - self.assertEqual(models.Project.query.get("raclette").members[0].weight, 1) + self.assertEqual(len(self.get_project("raclette").members), 1) + self.assertEqual(self.get_project("raclette").members[0].weight, 1) def test_rounding(self): self.post_project("raclette") @@ -840,11 +836,11 @@ class BudgetTestCase(IhatemoneyTestCase): }, ) - balance = models.Project.query.get("raclette").balance + balance = self.get_project("raclette").balance result = {} - result[models.Project.query.get("raclette").members[0].id] = 8.12 - result[models.Project.query.get("raclette").members[1].id] = 0.0 - result[models.Project.query.get("raclette").members[2].id] = -8.12 + result[self.get_project("raclette").members[0].id] = 8.12 + result[self.get_project("raclette").members[1].id] = 0.0 + result[self.get_project("raclette").members[2].id] = -8.12 # Since we're using floating point to store currency, we can have some # rounding issues that prevent test from working. # However, we should obtain the same values as the theoretical ones if we @@ -866,7 +862,7 @@ class BudgetTestCase(IhatemoneyTestCase): resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=True) self.assertEqual(resp.status_code, 200) - project = models.Project.query.get("raclette") + project = self.get_project("raclette") self.assertEqual(project.name, new_data["name"]) self.assertEqual(project.contact_email, new_data["contact_email"]) @@ -1030,14 +1026,14 @@ class BudgetTestCase(IhatemoneyTestCase): "amount": "10", }, ) - project = models.Project.query.get("raclette") + project = self.get_project("raclette") transactions = project.get_transactions_to_settle_bill() members = defaultdict(int) # We should have the same values between transactions and project balances for t in transactions: members[t["ower"]] -= t["amount"] members[t["receiver"]] += t["amount"] - balance = models.Project.query.get("raclette").balance + balance = self.get_project("raclette").balance for m, a in members.items(): assert abs(a - balance[m.id]) < 0.01 return @@ -1083,7 +1079,7 @@ class BudgetTestCase(IhatemoneyTestCase): "amount": "13.33", }, ) - project = models.Project.query.get("raclette") + project = self.get_project("raclette") transactions = project.get_transactions_to_settle_bill() # There should not be any zero-amount transfer after rounding @@ -1095,768 +1091,6 @@ class BudgetTestCase(IhatemoneyTestCase): msg=f"{t['amount']} is equal to zero after rounding", ) - def test_export(self): - # Export a simple project without currencies - - self.post_project("raclette") - - # add participants - self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) - self.client.post("/raclette/members/add", data={"name": "fred"}) - self.client.post("/raclette/members/add", data={"name": "tata"}) - self.client.post("/raclette/members/add", data={"name": "pépé"}) - - # create bills - self.client.post( - "/raclette/add", - data={ - "date": "2016-12-31", - "what": "fromage à raclette", - "payer": 1, - "payed_for": [1, 2, 3, 4], - "amount": "10.0", - }, - ) - - self.client.post( - "/raclette/add", - data={ - "date": "2016-12-31", - "what": "red wine", - "payer": 2, - "payed_for": [1, 3], - "amount": "200", - }, - ) - - self.client.post( - "/raclette/add", - data={ - "date": "2017-01-01", - "what": "refund", - "payer": 3, - "payed_for": [2], - "amount": "13.33", - }, - ) - - # generate json export of bills - resp = self.client.get("/raclette/export/bills.json") - expected = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "XXX", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "red wine", - "amount": 200.0, - "currency": "XXX", - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage \xe0 raclette", - "amount": 10.0, - "currency": "XXX", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"], - }, - ] - self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) - - # generate csv export of bills - resp = self.client.get("/raclette/export/bills.csv") - expected = [ - "date,what,amount,currency,payer_name,payer_weight,owers", - "2017-01-01,refund,XXX,13.33,tata,1.0,fred", - '2016-12-31,red wine,XXX,200.0,fred,1.0,"zorglub, tata"', - '2016-12-31,fromage à raclette,10.0,XXX,zorglub,2.0,"zorglub, fred, tata, pépé"', - ] - received_lines = resp.data.decode("utf-8").split("\n") - - for i, line in enumerate(expected): - self.assertEqual( - set(line.split(",")), set(received_lines[i].strip("\r").split(",")) - ) - - # generate json export of transactions - resp = self.client.get("/raclette/export/transactions.json") - expected = [ - { - "amount": 2.00, - "currency": "XXX", - "receiver": "fred", - "ower": "p\xe9p\xe9", - }, - {"amount": 55.34, "currency": "XXX", "receiver": "fred", "ower": "tata"}, - { - "amount": 127.33, - "currency": "XXX", - "receiver": "fred", - "ower": "zorglub", - }, - ] - - self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) - - # generate csv export of transactions - resp = self.client.get("/raclette/export/transactions.csv") - - expected = [ - "amount,currency,receiver,ower", - "2.0,XXX,fred,pépé", - "55.34,XXX,fred,tata", - "127.33,XXX,fred,zorglub", - ] - received_lines = resp.data.decode("utf-8").split("\n") - - for i, line in enumerate(expected): - self.assertEqual( - set(line.split(",")), set(received_lines[i].strip("\r").split(",")) - ) - - # wrong export_format should return a 404 - resp = self.client.get("/raclette/export/transactions.wrong") - self.assertEqual(resp.status_code, 404) - - def test_export_with_currencies(self): - self.post_project("raclette", default_currency="EUR") - - # add participants - self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) - self.client.post("/raclette/members/add", data={"name": "fred"}) - self.client.post("/raclette/members/add", data={"name": "tata"}) - self.client.post("/raclette/members/add", data={"name": "pépé"}) - - # create bills - self.client.post( - "/raclette/add", - data={ - "date": "2016-12-31", - "what": "fromage à raclette", - "payer": 1, - "payed_for": [1, 2, 3, 4], - "amount": "10.0", - "original_currency": "EUR", - }, - ) - - self.client.post( - "/raclette/add", - data={ - "date": "2016-12-31", - "what": "poutine from Québec", - "payer": 2, - "payed_for": [1, 3], - "amount": "100", - "original_currency": "CAD", - }, - ) - - self.client.post( - "/raclette/add", - data={ - "date": "2017-01-01", - "what": "refund", - "payer": 3, - "payed_for": [2], - "amount": "13.33", - "original_currency": "EUR", - }, - ) - - # generate json export of bills - resp = self.client.get("/raclette/export/bills.json") - expected = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "EUR", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "poutine from Qu\xe9bec", - "amount": 100.0, - "currency": "CAD", - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage \xe0 raclette", - "amount": 10.0, - "currency": "EUR", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"], - }, - ] - self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) - - # generate csv export of bills - resp = self.client.get("/raclette/export/bills.csv") - expected = [ - "date,what,amount,currency,payer_name,payer_weight,owers", - "2017-01-01,refund,13.33,EUR,tata,1.0,fred", - '2016-12-31,poutine from Québec,100.0,CAD,fred,1.0,"zorglub, tata"', - '2016-12-31,fromage à raclette,10.0,EUR,zorglub,2.0,"zorglub, fred, tata, pépé"', - ] - received_lines = resp.data.decode("utf-8").split("\n") - - for i, line in enumerate(expected): - self.assertEqual( - set(line.split(",")), set(received_lines[i].strip("\r").split(",")) - ) - - # generate json export of transactions (in EUR!) - resp = self.client.get("/raclette/export/transactions.json") - expected = [ - { - "amount": 2.00, - "currency": "EUR", - "receiver": "fred", - "ower": "p\xe9p\xe9", - }, - {"amount": 10.89, "currency": "EUR", "receiver": "fred", "ower": "tata"}, - {"amount": 38.45, "currency": "EUR", "receiver": "fred", "ower": "zorglub"}, - ] - - self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) - - # generate csv export of transactions - resp = self.client.get("/raclette/export/transactions.csv") - - expected = [ - "amount,currency,receiver,ower", - "2.0,EUR,fred,pépé", - "10.89,EUR,fred,tata", - "38.45,EUR,fred,zorglub", - ] - received_lines = resp.data.decode("utf-8").split("\n") - - for i, line in enumerate(expected): - self.assertEqual( - set(line.split(",")), set(received_lines[i].strip("\r").split(",")) - ) - - # Change project currency to CAD - project = models.Project.query.get("raclette") - project.switch_currency("CAD") - - # generate json export of transactions (now in CAD!) - resp = self.client.get("/raclette/export/transactions.json") - expected = [ - { - "amount": 3.00, - "currency": "CAD", - "receiver": "fred", - "ower": "p\xe9p\xe9", - }, - {"amount": 16.34, "currency": "CAD", "receiver": "fred", "ower": "tata"}, - {"amount": 57.67, "currency": "CAD", "receiver": "fred", "ower": "zorglub"}, - ] - - self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) - - # generate csv export of transactions - resp = self.client.get("/raclette/export/transactions.csv") - - expected = [ - "amount,currency,receiver,ower", - "3.0,CAD,fred,pépé", - "16.34,CAD,fred,tata", - "57.67,CAD,fred,zorglub", - ] - received_lines = resp.data.decode("utf-8").split("\n") - - for i, line in enumerate(expected): - self.assertEqual( - set(line.split(",")), set(received_lines[i].strip("\r").split(",")) - ) - - def test_import_currencies_in_empty_project_with_currency(self): - # Import JSON with currencies in an empty project with a default currency - - self.post_project("raclette", default_currency="EUR") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "EUR", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "poutine from québec", - "amount": 50.0, - "currency": "CAD", - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "currency": "EUR", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check if all bills have been added - self.assertEqual(len(bills), len(json_to_import)) - - # Check if name of bills are ok - b = [e["what"] for e in bills] - b.sort() - ref = [e["what"] for e in json_to_import] - ref.sort() - - self.assertEqual(b, ref) - - # Check if other informations in bill are ok - for i in json_to_import: - for j in bills: - if j["what"] == i["what"]: - self.assertEqual(j["payer_name"], i["payer_name"]) - self.assertEqual(j["amount"], i["amount"]) - self.assertEqual(j["currency"], i["currency"]) - self.assertEqual(j["payer_weight"], i["payer_weight"]) - self.assertEqual(j["date"], i["date"]) - - list_project = [ower for ower in j["owers"]] - list_project.sort() - list_json = [ower for ower in i["owers"]] - list_json.sort() - - self.assertEqual(list_project, list_json) - - def test_import_single_currency_in_empty_project_without_currency(self): - # Import JSON with a single currency in an empty project with no - # default currency. It should work by stripping the currency from - # bills. - - self.post_project("raclette") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "EUR", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "currency": "EUR", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check if all bills have been added - self.assertEqual(len(bills), len(json_to_import)) - - # Check if name of bills are ok - b = [e["what"] for e in bills] - b.sort() - ref = [e["what"] for e in json_to_import] - ref.sort() - - self.assertEqual(b, ref) - - # Check if other informations in bill are ok - for i in json_to_import: - for j in bills: - if j["what"] == i["what"]: - self.assertEqual(j["payer_name"], i["payer_name"]) - self.assertEqual(j["amount"], i["amount"]) - # Currency should have been stripped - self.assertEqual(j["currency"], "XXX") - self.assertEqual(j["payer_weight"], i["payer_weight"]) - self.assertEqual(j["date"], i["date"]) - - list_project = [ower for ower in j["owers"]] - list_project.sort() - list_json = [ower for ower in i["owers"]] - list_json.sort() - - self.assertEqual(list_project, list_json) - - def test_import_multiple_currencies_in_empty_project_without_currency(self): - # Import JSON with multiple currencies in an empty project with no - # default currency. It should fail. - - self.post_project("raclette") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "EUR", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "poutine from québec", - "amount": 50.0, - "currency": "CAD", - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "currency": "EUR", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - # Import should fail - with pytest.raises(ValueError): - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check that there are no bills - self.assertEqual(len(bills), 0) - - def test_import_no_currency_in_empty_project_with_currency(self): - # Import JSON without currencies (from ihatemoney < 5) in an empty - # project with a default currency. - - self.post_project("raclette", default_currency="EUR") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "red wine", - "amount": 200.0, - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check if all bills have been added - self.assertEqual(len(bills), len(json_to_import)) - - # Check if name of bills are ok - b = [e["what"] for e in bills] - b.sort() - ref = [e["what"] for e in json_to_import] - ref.sort() - - self.assertEqual(b, ref) - - # Check if other informations in bill are ok - for i in json_to_import: - for j in bills: - if j["what"] == i["what"]: - self.assertEqual(j["payer_name"], i["payer_name"]) - self.assertEqual(j["amount"], i["amount"]) - # All bills are converted to default project currency - self.assertEqual(j["currency"], "EUR") - self.assertEqual(j["payer_weight"], i["payer_weight"]) - self.assertEqual(j["date"], i["date"]) - - list_project = [ower for ower in j["owers"]] - list_project.sort() - list_json = [ower for ower in i["owers"]] - list_json.sort() - - self.assertEqual(list_project, list_json) - - def test_import_no_currency_in_empty_project_without_currency(self): - # Import JSON without currencies (from ihatemoney < 5) in an empty - # project with no default currency. - - self.post_project("raclette") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { - "date": "2016-12-31", - "what": "red wine", - "amount": 200.0, - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check if all bills have been added - self.assertEqual(len(bills), len(json_to_import)) - - # Check if name of bills are ok - b = [e["what"] for e in bills] - b.sort() - ref = [e["what"] for e in json_to_import] - ref.sort() - - self.assertEqual(b, ref) - - # Check if other informations in bill are ok - for i in json_to_import: - for j in bills: - if j["what"] == i["what"]: - self.assertEqual(j["payer_name"], i["payer_name"]) - self.assertEqual(j["amount"], i["amount"]) - self.assertEqual(j["currency"], "XXX") - self.assertEqual(j["payer_weight"], i["payer_weight"]) - self.assertEqual(j["date"], i["date"]) - - list_project = [ower for ower in j["owers"]] - list_project.sort() - list_json = [ower for ower in i["owers"]] - list_json.sort() - - self.assertEqual(list_project, list_json) - - def test_import_partial_project(self): - # Import a JSON in a project with already existing data - - self.post_project("raclette") - self.login("raclette") - - project = models.Project.query.get("raclette") - - self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2}) - self.client.post("/raclette/members/add", data={"name": "fred"}) - self.client.post("/raclette/members/add", data={"name": "tata"}) - self.client.post( - "/raclette/add", - data={ - "date": "2016-12-31", - "what": "red wine", - "payer": 2, - "payed_for": [1, 3], - "amount": "200", - }, - ) - - json_to_import = [ - { - "date": "2017-01-01", - "what": "refund", - "amount": 13.33, - "currency": "XXX", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - }, - { # This expense does not have to be present twice. - "date": "2016-12-31", - "what": "red wine", - "amount": 200.0, - "currency": "XXX", - "payer_name": "fred", - "payer_weight": 1.0, - "owers": ["zorglub", "tata"], - }, - { - "date": "2016-12-31", - "what": "fromage a raclette", - "amount": 10.0, - "currency": "XXX", - "payer_name": "zorglub", - "payer_weight": 2.0, - "owers": ["zorglub", "fred", "tata", "pepe"], - }, - ] - - from ihatemoney.web import import_project - - file = io.StringIO() - json.dump(json_to_import, file) - file.seek(0) - import_project(file, project) - - bills = project.get_pretty_bills() - - # Check if all bills have been added - self.assertEqual(len(bills), len(json_to_import)) - - # Check if name of bills are ok - b = [e["what"] for e in bills] - b.sort() - ref = [e["what"] for e in json_to_import] - ref.sort() - - self.assertEqual(b, ref) - - # Check if other informations in bill are ok - for i in json_to_import: - for j in bills: - if j["what"] == i["what"]: - self.assertEqual(j["payer_name"], i["payer_name"]) - self.assertEqual(j["amount"], i["amount"]) - self.assertEqual(j["currency"], i["currency"]) - self.assertEqual(j["payer_weight"], i["payer_weight"]) - self.assertEqual(j["date"], i["date"]) - - list_project = [ower for ower in j["owers"]] - list_project.sort() - list_json = [ower for ower in i["owers"]] - list_json.sort() - - self.assertEqual(list_project, list_json) - - def test_import_wrong_json(self): - self.post_project("raclette") - self.login("raclette") - - project = models.Project.query.get("raclette") - - json_1 = [ - { # wrong keys - "checked": False, - "dimensions": {"width": 5, "height": 10}, - "id": 1, - "name": "A green door", - "price": 12.5, - "tags": ["home", "green"], - } - ] - - json_2 = [ - { # amount missing - "date": "2017-01-01", - "what": "refund", - "payer_name": "tata", - "payer_weight": 1.0, - "owers": ["fred"], - } - ] - - from ihatemoney.web import import_project - - for data in [json_1, json_2]: - file = io.StringIO() - json.dump(data, file) - file.seek(0) - with pytest.raises(ValueError): - import_project(file, project) - def test_access_other_projects(self): """Test that accessing or editing bills and participants from another project fails""" # Create project @@ -1880,7 +1114,7 @@ class BudgetTestCase(IhatemoneyTestCase): }, ) # Ensure it has been created - raclette = models.Project.query.get("raclette") + raclette = self.get_project("raclette") self.assertEqual(raclette.get_bills().count(), 1) # Log out @@ -2025,7 +1259,7 @@ class BudgetTestCase(IhatemoneyTestCase): }, ) - project = models.Project.query.get("raclette") + project = self.get_project("raclette") # First all converted_amount should be the same as amount, with no currency for bill in project.get_bills(): @@ -2106,7 +1340,7 @@ class BudgetTestCase(IhatemoneyTestCase): # A user displayed error should be generated, and its currency should be the same. self.assertStatus(200, resp) self.assertIn('
', resp.data.decode("utf-8"))
- self.assertEqual(models.Project.query.get("raclette").default_currency, "USD")
+ self.assertEqual(self.get_project("raclette").default_currency, "USD")
def test_currency_switch_to_bill_currency(self):
@@ -2130,7 +1364,7 @@ class BudgetTestCase(IhatemoneyTestCase):
},
)
- project = models.Project.query.get("raclette")
+ project = self.get_project("raclette")
bill = project.get_bills().first()
assert bill.converted_amount == self.converter.exchange_currency(
@@ -2176,7 +1410,7 @@ class BudgetTestCase(IhatemoneyTestCase):
},
)
- project = models.Project.query.get("raclette")
+ project = self.get_project("raclette")
for bill in project.get_bills_unordered():
assert bill.converted_amount == self.converter.exchange_currency(
diff --git a/ihatemoney/tests/common/ihatemoney_testcase.py b/ihatemoney/tests/common/ihatemoney_testcase.py
index 5ad0a56a..eda86ab7 100644
--- a/ihatemoney/tests/common/ihatemoney_testcase.py
+++ b/ihatemoney/tests/common/ihatemoney_testcase.py
@@ -74,6 +74,14 @@ class BaseTestCase(TestCase):
follow_redirects=follow_redirects,
)
+ def import_project(self, id, data, success=True):
+ resp = self.client.post(
+ f"/{id}/import",
+ data=data,
+ # follow_redirects=True,
+ )
+ self.assertEqual("/{id}/edit" in str(resp.response), not success)
+
def create_project(self, id, default_currency="XXX", name=None, password=None):
name = name or str(id)
password = password or id
@@ -87,6 +95,9 @@ class BaseTestCase(TestCase):
models.db.session.add(project)
models.db.session.commit()
+ def get_project(self, id) -> models.Project:
+ return models.Project.query.get(id)
+
class IhatemoneyTestCase(BaseTestCase):
TESTING = True
diff --git a/ihatemoney/tests/history_test.py b/ihatemoney/tests/history_test.py
index 38a3740e..b742a8c8 100644
--- a/ihatemoney/tests/history_test.py
+++ b/ihatemoney/tests/history_test.py
@@ -614,14 +614,14 @@ class HistoryTestCase(IhatemoneyTestCase):
models.db.session.add(b2)
models.db.session.commit()
- history_list = history.get_history(models.Project.query.get("demo"))
+ history_list = history.get_history(self.get_project("demo"))
self.assertEqual(len(history_list), 5)
# Change just the amount
b1.amount = 5
models.db.session.commit()
- history_list = history.get_history(models.Project.query.get("demo"))
+ history_list = history.get_history(self.get_project("demo"))
for entry in history_list:
if "prop_changed" in entry:
self.assertNotIn("owers", entry["prop_changed"])
diff --git a/ihatemoney/tests/import_test.py b/ihatemoney/tests/import_test.py
new file mode 100644
index 00000000..f1ee1d07
--- /dev/null
+++ b/ihatemoney/tests/import_test.py
@@ -0,0 +1,614 @@
+import copy
+import json
+import unittest
+
+from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase
+from ihatemoney.utils import list_of_dicts2csv, list_of_dicts2json
+
+
+class CommonTestCase(object):
+ class Import(IhatemoneyTestCase):
+ def setUp(self):
+ super().setUp()
+ self.data = [
+ {
+ "date": "2017-01-01",
+ "what": "refund",
+ "amount": 13.33,
+ "payer_name": "tata",
+ "payer_weight": 1.0,
+ "owers": ["fred"],
+ },
+ {
+ "date": "2016-12-31",
+ "what": "red wine",
+ "amount": 200.0,
+ "payer_name": "fred",
+ "payer_weight": 1.0,
+ "owers": ["zorglub", "tata"],
+ },
+ {
+ "date": "2016-12-31",
+ "what": "fromage a raclette",
+ "amount": 10.0,
+ "payer_name": "zorglub",
+ "payer_weight": 2.0,
+ "owers": ["zorglub", "fred", "tata", "pepe"],
+ },
+ ]
+
+ def populate_data_with_currencies(self, currencies):
+ for d in range(len(self.data)):
+ self.data[d]["currency"] = currencies[d]
+
+ def test_import_currencies_in_empty_project_with_currency(self):
+ # Import JSON with currencies in an empty project with a default currency
+
+ self.post_project("raclette", default_currency="EUR")
+ self.login("raclette")
+
+ project = self.get_project("raclette")
+
+ self.populate_data_with_currencies(["EUR", "CAD", "EUR"])
+ self.import_project("raclette", self.generate_form_data(self.data))
+
+ bills = project.get_pretty_bills()
+
+ # Check if all bills have been added
+ self.assertEqual(len(bills), len(self.data))
+
+ # Check if name of bills are ok
+ b = [e["what"] for e in bills]
+ b.sort()
+ ref = [e["what"] for e in self.data]
+ ref.sort()
+
+ self.assertEqual(b, ref)
+
+ # Check if other informations in bill are ok
+ for d in self.data:
+ for b in bills:
+ if b["what"] == d["what"]:
+ self.assertEqual(b["payer_name"], d["payer_name"])
+ self.assertEqual(b["amount"], d["amount"])
+ self.assertEqual(b["currency"], d["currency"])
+ self.assertEqual(b["payer_weight"], d["payer_weight"])
+ self.assertEqual(b["date"], d["date"])
+ list_project = [ower for ower in b["owers"]]
+ list_project.sort()
+ list_json = [ower for ower in d["owers"]]
+ list_json.sort()
+ self.assertEqual(list_project, list_json)
+
+ def test_import_single_currency_in_empty_project_without_currency(self):
+ # Import JSON with a single currency in an empty project with no
+ # default currency. It should work by stripping the currency from
+ # bills.
+
+ self.post_project("raclette")
+ self.login("raclette")
+
+ project = self.get_project("raclette")
+
+ self.populate_data_with_currencies(["EUR", "EUR", "EUR"])
+ self.import_project("raclette", self.generate_form_data(self.data))
+
+ bills = project.get_pretty_bills()
+
+ # Check if all bills have been added
+ self.assertEqual(len(bills), len(self.data))
+
+ # Check if name of bills are ok
+ b = [e["what"] for e in bills]
+ b.sort()
+ ref = [e["what"] for e in self.data]
+ ref.sort()
+
+ self.assertEqual(b, ref)
+
+ # Check if other informations in bill are ok
+ for d in self.data:
+ for b in bills:
+ if b["what"] == d["what"]:
+ self.assertEqual(b["payer_name"], d["payer_name"])
+ self.assertEqual(b["amount"], d["amount"])
+ # Currency should have been stripped
+ self.assertEqual(b["currency"], "XXX")
+ self.assertEqual(b["payer_weight"], d["payer_weight"])
+ self.assertEqual(b["date"], d["date"])
+ list_project = [ower for ower in b["owers"]]
+ list_project.sort()
+ list_json = [ower for ower in d["owers"]]
+ list_json.sort()
+ self.assertEqual(list_project, list_json)
+
+ def test_import_multiple_currencies_in_empty_project_without_currency(self):
+ # Import JSON with multiple currencies in an empty project with no
+ # default currency. It should fail.
+
+ self.post_project("raclette")
+ self.login("raclette")
+
+ project = self.get_project("raclette")
+
+ self.populate_data_with_currencies(["EUR", "CAD", "EUR"])
+ # Import should fail
+ self.import_project("raclette", self.generate_form_data(self.data), 400)
+
+ bills = project.get_pretty_bills()
+
+ # Check that there are no bills
+ self.assertEqual(len(bills), 0)
+
+ def test_import_no_currency_in_empty_project_with_currency(self):
+ # Import JSON without currencies (from ihatemoney < 5) in an empty
+ # project with a default currency.
+
+ self.post_project("raclette", default_currency="EUR")
+ self.login("raclette")
+
+ project = self.get_project("raclette")
+
+ self.import_project("raclette", self.generate_form_data(self.data))
+
+ bills = project.get_pretty_bills()
+
+ # Check if all bills have been added
+ self.assertEqual(len(bills), len(self.data))
+
+ # Check if name of bills are ok
+ b = [e["what"] for e in bills]
+ b.sort()
+ ref = [e["what"] for e in self.data]
+ ref.sort()
+
+ self.assertEqual(b, ref)
+
+ # Check if other informations in bill are ok
+ for d in self.data:
+ for b in bills:
+ if b["what"] == d["what"]:
+ self.assertEqual(b["payer_name"], d["payer_name"])
+ self.assertEqual(b["amount"], d["amount"])
+ # All bills are converted to default project currency
+ self.assertEqual(b["currency"], "EUR")
+ self.assertEqual(b["payer_weight"], d["payer_weight"])
+ self.assertEqual(b["date"], d["date"])
+ list_project = [ower for ower in b["owers"]]
+ list_project.sort()
+ list_json = [ower for ower in d["owers"]]
+ list_json.sort()
+ self.assertEqual(list_project, list_json)
+
+ def test_import_no_currency_in_empty_project_without_currency(self):
+ # Import JSON without currencies (from ihatemoney < 5) in an empty
+ # project with no default currency.
+
+ self.post_project("raclette")
+ self.login("raclette")
+
+ project = self.get_project("raclette")
+
+ self.import_project("raclette", self.generate_form_data(self.data))
+
+ bills = project.get_pretty_bills()
+
+ # Check if all bills have been added
+ self.assertEqual(len(bills), len(self.data))
+
+ # Check if name of bills are ok
+ b = [e["what"] for e in bills]
+ b.sort()
+ ref = [e["what"] for e in self.data]
+ ref.sort()
+
+ self.assertEqual(b, ref)
+
+ # Check if other informations in bill are ok
+ for d in self.data:
+ for b in bills:
+ if b["what"] == d["what"]:
+ self.assertEqual(b["payer_name"], d["payer_name"])
+ self.assertEqual(b["amount"], d["amount"])
+ self.assertEqual(b["currency"], "XXX")
+ self.assertEqual(b["payer_weight"], d["payer_weight"])
+ self.assertEqual(b["date"], d["date"])
+ list_project = [ower for ower in b["owers"]]
+ list_project.sort()
+ list_json = [ower for ower in d["owers"]]
+ list_json.sort()
+ self.assertEqual(list_project, list_json)
+
+ def test_import_partial_project(self):
+ # Import a JSON in a project with already existing data
+
+ self.post_project("raclette")
+ self.login("raclette")
+
+ project = self.get_project("raclette")
+
+ self.client.post(
+ "/raclette/members/add", data={"name": "zorglub", "weight": 2}
+ )
+ self.client.post("/raclette/members/add", data={"name": "fred"})
+ self.client.post("/raclette/members/add", data={"name": "tata"})
+ self.client.post(
+ "/raclette/add",
+ data={
+ "date": "2016-12-31",
+ "what": "red wine",
+ "payer": 2,
+ "payed_for": [1, 3],
+ "amount": "200",
+ },
+ )
+
+ self.populate_data_with_currencies(["XXX", "XXX", "XXX"])
+
+ self.import_project("raclette", self.generate_form_data(self.data))
+
+ bills = project.get_pretty_bills()
+
+ # Check if all bills have been added
+ self.assertEqual(len(bills), len(self.data))
+
+ # Check if name of bills are ok
+ b = [e["what"] for e in bills]
+ b.sort()
+ ref = [e["what"] for e in self.data]
+ ref.sort()
+
+ self.assertEqual(b, ref)
+
+ # Check if other informations in bill are ok
+ for d in self.data:
+ for b in bills:
+ if b["what"] == d["what"]:
+ self.assertEqual(b["payer_name"], d["payer_name"])
+ self.assertEqual(b["amount"], d["amount"])
+ self.assertEqual(b["currency"], d["currency"])
+ self.assertEqual(b["payer_weight"], d["payer_weight"])
+ self.assertEqual(b["date"], d["date"])
+ list_project = [ower for ower in b["owers"]]
+ list_project.sort()
+ list_json = [ower for ower in d["owers"]]
+ list_json.sort()
+ self.assertEqual(list_project, list_json)
+
+ def test_import_wrong_data(self):
+ self.post_project("raclette")
+ self.login("raclette")
+ data_wrong_keys = [
+ {
+ "checked": False,
+ "dimensions": {"width": 5, "height": 10},
+ "id": 1,
+ "name": "A green door",
+ "price": 12.5,
+ "tags": ["home", "green"],
+ }
+ ]
+ data_amount_missing = [
+ {
+ "date": "2017-01-01",
+ "what": "refund",
+ "payer_name": "tata",
+ "payer_weight": 1.0,
+ "owers": ["fred"],
+ }
+ ]
+ for data in [data_wrong_keys, data_amount_missing]:
+ # Import should fail
+ self.import_project("raclette", self.generate_form_data(data), 400)
+
+
+class ExportTestCase(IhatemoneyTestCase):
+ def test_export(self):
+ # Export a simple project without currencies
+
+ self.post_project("raclette")
+
+ # add participants
+ self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2})
+ self.client.post("/raclette/members/add", data={"name": "fred"})
+ self.client.post("/raclette/members/add", data={"name": "tata"})
+ self.client.post("/raclette/members/add", data={"name": "pépé"})
+
+ # create bills
+ self.client.post(
+ "/raclette/add",
+ data={
+ "date": "2016-12-31",
+ "what": "fromage à raclette",
+ "payer": 1,
+ "payed_for": [1, 2, 3, 4],
+ "amount": "10.0",
+ },
+ )
+
+ self.client.post(
+ "/raclette/add",
+ data={
+ "date": "2016-12-31",
+ "what": "red wine",
+ "payer": 2,
+ "payed_for": [1, 3],
+ "amount": "200",
+ },
+ )
+
+ self.client.post(
+ "/raclette/add",
+ data={
+ "date": "2017-01-01",
+ "what": "refund",
+ "payer": 3,
+ "payed_for": [2],
+ "amount": "13.33",
+ },
+ )
+
+ # generate json export of bills
+ resp = self.client.get("/raclette/export/bills.json")
+ expected = [
+ {
+ "date": "2017-01-01",
+ "what": "refund",
+ "amount": 13.33,
+ "currency": "XXX",
+ "payer_name": "tata",
+ "payer_weight": 1.0,
+ "owers": ["fred"],
+ },
+ {
+ "date": "2016-12-31",
+ "what": "red wine",
+ "amount": 200.0,
+ "currency": "XXX",
+ "payer_name": "fred",
+ "payer_weight": 1.0,
+ "owers": ["zorglub", "tata"],
+ },
+ {
+ "date": "2016-12-31",
+ "what": "fromage \xe0 raclette",
+ "amount": 10.0,
+ "currency": "XXX",
+ "payer_name": "zorglub",
+ "payer_weight": 2.0,
+ "owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"],
+ },
+ ]
+ self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
+
+ # generate csv export of bills
+ resp = self.client.get("/raclette/export/bills.csv")
+ expected = [
+ "date,what,amount,currency,payer_name,payer_weight,owers",
+ "2017-01-01,refund,XXX,13.33,tata,1.0,fred",
+ '2016-12-31,red wine,XXX,200.0,fred,1.0,"zorglub, tata"',
+ '2016-12-31,fromage à raclette,10.0,XXX,zorglub,2.0,"zorglub, fred, tata, pépé"',
+ ]
+ received_lines = resp.data.decode("utf-8").split("\n")
+
+ for i, line in enumerate(expected):
+ self.assertEqual(
+ set(line.split(",")), set(received_lines[i].strip("\r").split(","))
+ )
+
+ # generate json export of transactions
+ resp = self.client.get("/raclette/export/transactions.json")
+ expected = [
+ {
+ "amount": 2.00,
+ "currency": "XXX",
+ "receiver": "fred",
+ "ower": "p\xe9p\xe9",
+ },
+ {"amount": 55.34, "currency": "XXX", "receiver": "fred", "ower": "tata"},
+ {
+ "amount": 127.33,
+ "currency": "XXX",
+ "receiver": "fred",
+ "ower": "zorglub",
+ },
+ ]
+
+ self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
+
+ # generate csv export of transactions
+ resp = self.client.get("/raclette/export/transactions.csv")
+
+ expected = [
+ "amount,currency,receiver,ower",
+ "2.0,XXX,fred,pépé",
+ "55.34,XXX,fred,tata",
+ "127.33,XXX,fred,zorglub",
+ ]
+ received_lines = resp.data.decode("utf-8").split("\n")
+
+ for i, line in enumerate(expected):
+ self.assertEqual(
+ set(line.split(",")), set(received_lines[i].strip("\r").split(","))
+ )
+
+ # wrong export_format should return a 404
+ resp = self.client.get("/raclette/export/transactions.wrong")
+ self.assertEqual(resp.status_code, 404)
+
+ def test_export_with_currencies(self):
+ self.post_project("raclette", default_currency="EUR")
+
+ # add participants
+ self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2})
+ self.client.post("/raclette/members/add", data={"name": "fred"})
+ self.client.post("/raclette/members/add", data={"name": "tata"})
+ self.client.post("/raclette/members/add", data={"name": "pépé"})
+
+ # create bills
+ self.client.post(
+ "/raclette/add",
+ data={
+ "date": "2016-12-31",
+ "what": "fromage à raclette",
+ "payer": 1,
+ "payed_for": [1, 2, 3, 4],
+ "amount": "10.0",
+ "original_currency": "EUR",
+ },
+ )
+
+ self.client.post(
+ "/raclette/add",
+ data={
+ "date": "2016-12-31",
+ "what": "poutine from Québec",
+ "payer": 2,
+ "payed_for": [1, 3],
+ "amount": "100",
+ "original_currency": "CAD",
+ },
+ )
+
+ self.client.post(
+ "/raclette/add",
+ data={
+ "date": "2017-01-01",
+ "what": "refund",
+ "payer": 3,
+ "payed_for": [2],
+ "amount": "13.33",
+ "original_currency": "EUR",
+ },
+ )
+
+ # generate json export of bills
+ resp = self.client.get("/raclette/export/bills.json")
+ expected = [
+ {
+ "date": "2017-01-01",
+ "what": "refund",
+ "amount": 13.33,
+ "currency": "EUR",
+ "payer_name": "tata",
+ "payer_weight": 1.0,
+ "owers": ["fred"],
+ },
+ {
+ "date": "2016-12-31",
+ "what": "poutine from Qu\xe9bec",
+ "amount": 100.0,
+ "currency": "CAD",
+ "payer_name": "fred",
+ "payer_weight": 1.0,
+ "owers": ["zorglub", "tata"],
+ },
+ {
+ "date": "2016-12-31",
+ "what": "fromage \xe0 raclette",
+ "amount": 10.0,
+ "currency": "EUR",
+ "payer_name": "zorglub",
+ "payer_weight": 2.0,
+ "owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"],
+ },
+ ]
+ self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
+
+ # generate csv export of bills
+ resp = self.client.get("/raclette/export/bills.csv")
+ expected = [
+ "date,what,amount,currency,payer_name,payer_weight,owers",
+ "2017-01-01,refund,13.33,EUR,tata,1.0,fred",
+ '2016-12-31,poutine from Québec,100.0,CAD,fred,1.0,"zorglub, tata"',
+ '2016-12-31,fromage à raclette,10.0,EUR,zorglub,2.0,"zorglub, fred, tata, pépé"',
+ ]
+ received_lines = resp.data.decode("utf-8").split("\n")
+
+ for i, line in enumerate(expected):
+ self.assertEqual(
+ set(line.split(",")), set(received_lines[i].strip("\r").split(","))
+ )
+
+ # generate json export of transactions (in EUR!)
+ resp = self.client.get("/raclette/export/transactions.json")
+ expected = [
+ {
+ "amount": 2.00,
+ "currency": "EUR",
+ "receiver": "fred",
+ "ower": "p\xe9p\xe9",
+ },
+ {"amount": 10.89, "currency": "EUR", "receiver": "fred", "ower": "tata"},
+ {"amount": 38.45, "currency": "EUR", "receiver": "fred", "ower": "zorglub"},
+ ]
+
+ self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
+
+ # generate csv export of transactions
+ resp = self.client.get("/raclette/export/transactions.csv")
+
+ expected = [
+ "amount,currency,receiver,ower",
+ "2.0,EUR,fred,pépé",
+ "10.89,EUR,fred,tata",
+ "38.45,EUR,fred,zorglub",
+ ]
+ received_lines = resp.data.decode("utf-8").split("\n")
+
+ for i, line in enumerate(expected):
+ self.assertEqual(
+ set(line.split(",")), set(received_lines[i].strip("\r").split(","))
+ )
+
+ # Change project currency to CAD
+ project = self.get_project("raclette")
+ project.switch_currency("CAD")
+
+ # generate json export of transactions (now in CAD!)
+ resp = self.client.get("/raclette/export/transactions.json")
+ expected = [
+ {
+ "amount": 3.00,
+ "currency": "CAD",
+ "receiver": "fred",
+ "ower": "p\xe9p\xe9",
+ },
+ {"amount": 16.34, "currency": "CAD", "receiver": "fred", "ower": "tata"},
+ {"amount": 57.67, "currency": "CAD", "receiver": "fred", "ower": "zorglub"},
+ ]
+
+ self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
+
+ # generate csv export of transactions
+ resp = self.client.get("/raclette/export/transactions.csv")
+
+ expected = [
+ "amount,currency,receiver,ower",
+ "3.0,CAD,fred,pépé",
+ "16.34,CAD,fred,tata",
+ "57.67,CAD,fred,zorglub",
+ ]
+ received_lines = resp.data.decode("utf-8").split("\n")
+
+ for i, line in enumerate(expected):
+ self.assertEqual(
+ set(line.split(",")), set(received_lines[i].strip("\r").split(","))
+ )
+
+
+class ImportTestCaseJSON(CommonTestCase.Import):
+ def generate_form_data(self, data):
+ return {"file": (list_of_dicts2json(data), "test.json")}
+
+
+class ImportTestCaseCSV(CommonTestCase.Import):
+ def generate_form_data(self, data):
+ formatted_data = copy.deepcopy(data)
+ for d in formatted_data:
+ d["owers"] = ", ".join([o for o in d.get("owers", [])])
+ return {"file": (list_of_dicts2csv(formatted_data), "test.csv")}
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/ihatemoney/tests/main_test.py b/ihatemoney/tests/main_test.py
index 9707bd8a..fac655d3 100644
--- a/ihatemoney/tests/main_test.py
+++ b/ihatemoney/tests/main_test.py
@@ -99,7 +99,7 @@ class CommandTestCase(BaseTestCase):
def test_demo_project_deletion(self):
self.create_project("demo")
- self.assertEqual(models.Project.query.get("demo").name, "demo")
+ self.assertEqual(self.get_project("demo").name, "demo")
runner = self.app.test_cli_runner()
runner.invoke(delete_project, "demo")
@@ -246,7 +246,7 @@ class CaptchaTestCase(IhatemoneyTestCase):
ENABLE_CAPTCHA = True
def test_project_creation_with_captcha(self):
- with self.app.test_client() as c:
+ with self.client as c:
c.post(
"/create",
data={
diff --git a/ihatemoney/utils.py b/ihatemoney/utils.py
index 66f5b6a4..96b80816 100644
--- a/ihatemoney/utils.py
+++ b/ihatemoney/utils.py
@@ -2,7 +2,7 @@ import ast
import csv
from datetime import datetime, timedelta
from enum import Enum
-from io import BytesIO, StringIO
+from io import BytesIO, StringIO, TextIOWrapper
from json import JSONEncoder, dumps
import operator
import os
@@ -150,6 +150,31 @@ def list_of_dicts2csv(dict_to_convert):
return csv_file
+def csv2list_of_dicts(csv_to_convert):
+ """Take a csv in-memory file and turns it into
+ a list of dictionnaries
+ """
+ csv_file = TextIOWrapper(csv_to_convert, encoding="utf-8")
+ reader = csv.DictReader(csv_file)
+ result = []
+ for r in reader:
+ """
+ cospend embeds various data helping (cospend) imports
+ 'deleteMeIfYouWant' lines contains users
+ 'categoryname' table contains categories description
+ we don't need them as we determine users and categories from bills
+ """
+ if r["what"] == "deleteMeIfYouWant":
+ continue
+ elif r["what"] == "categoryname":
+ break
+ r["amount"] = float(r["amount"])
+ r["payer_weight"] = float(r["payer_weight"])
+ r["owers"] = [o.strip() for o in r["owers"].split(",")]
+ result.append(r)
+ return result
+
+
class LoginThrottler:
"""Simple login throttler used to limit authentication attempts based on client's ip address.
When using multiple workers, remaining number of attempts can get inconsistent
diff --git a/ihatemoney/web.py b/ihatemoney/web.py
index 7f986ee5..e77b41f3 100644
--- a/ihatemoney/web.py
+++ b/ihatemoney/web.py
@@ -13,7 +13,6 @@ from functools import wraps
import json
import os
-from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from flask import (
Blueprint,
@@ -44,13 +43,13 @@ from ihatemoney.forms import (
DestructiveActionProjectForm,
EditProjectForm,
EmptyForm,
+ ImportProjectForm,
InviteForm,
MemberForm,
PasswordReminder,
ProjectForm,
ProjectFormWithCaptcha,
ResetPasswordForm,
- UploadForm,
get_billform_for,
)
from ihatemoney.history import get_history, get_history_queries
@@ -58,12 +57,11 @@ from ihatemoney.models import Bill, LoggingMode, Person, Project, db
from ihatemoney.utils import (
LoginThrottler,
Redirect303,
+ csv2list_of_dicts,
format_form_errors,
- get_members,
list_of_dicts2csv,
list_of_dicts2json,
render_localized_template,
- same_bill,
send_email,
)
@@ -412,17 +410,8 @@ def reset_password():
@main.route("/