diff --git a/ihatemoney/tests/tests.py b/ihatemoney/tests/tests.py index 1d9646e1..59e78802 100644 --- a/ihatemoney/tests/tests.py +++ b/ihatemoney/tests/tests.py @@ -1190,10 +1190,14 @@ class BudgetTestCase(IhatemoneyTestCase): resp = self.client.get("/raclette/export/transactions.wrong") self.assertEqual(resp.status_code, 404) - def test_import(self): + def test_import_new_project(self): + # Import JSON in an empty project + self.post_project("raclette") self.login("raclette") + project = models.Project.query.get("raclette") + json_to_import = [ { "date": "2017-01-01", @@ -1221,28 +1225,104 @@ class BudgetTestCase(IhatemoneyTestCase): }, ] - with open("json_test.json", "w+") as file: - json.dump(json_to_import, file) + from ihatemoney.web import import_project - self.client.post("/raclette/upload_json", data={"file": "json_test.json"}) + import_project(json.dumps(json_to_import), project) - os.remove("json_test.json") - - project = models.Project.query.get("raclette") bills = project.get_pretty_bills() # Check if all bills has been add - # self.assertEqual(len(bills), 3) #FAIL HERE + self.assertEqual(len(bills), len(json_to_import)) # Check if name of bills are ok - b = list() - for j in bills: - b.append(j["what"]) + b = [e["what"] for e in bills] b.sort() - ref = ["refund", "red wine", "fromage a raclette"] + ref = [e["what"] for e in json_to_import] ref.sort() - # self.assertEqual(b, ref) + self.assertEqual(b, ref) + + # Check if other informations in bill are ok + for i in json_to_import: + for j in bills: + if j["what"] == i["what"]: + self.assertEqual(j["payer_name"], i["payer_name"]) + self.assertEqual(j["amount"], i["amount"]) + self.assertEqual(j["payer_weight"], i["payer_weight"]) + self.assertEqual(j["date"], i["date"]) + + list_project = [ower for ower in j["owers"]] + list_project.sort() + list_json = [ower for ower in i["owers"]] + list_json.sort() + + self.assertEqual(list_project, list_json) + + def test_import_partial_project(self): + # Import a JSON in a project with already existing data + + self.post_project("raclette") + self.login("raclette") + + project = models.Project.query.get("raclette") + + self.client.post("/raclette/members/add", data={"name": "alexis", "weight": 2}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + self.client.post("/raclette/members/add", data={"name": "tata"}) + self.client.post( + "/raclette/add", + data={ + "date": "2016-12-31", + "what": "red wine", + "payer": 2, + "payed_for": [1, 3], + "amount": "200", + }, + ) + + json_to_import = [ + { + "date": "2017-01-01", + "what": "refund", + "amount": 13.33, + "payer_name": "tata", + "payer_weight": 1.0, + "owers": ["fred"], + }, + { # This expense does not have to be present twice. + "date": "2016-12-31", + "what": "red wine", + "amount": 200.0, + "payer_name": "fred", + "payer_weight": 1.0, + "owers": ["alexis", "tata"], + }, + { + "date": "2016-12-31", + "what": "fromage a raclette", + "amount": 10.0, + "payer_name": "alexis", + "payer_weight": 2.0, + "owers": ["alexis", "fred", "tata", "pepe"], + }, + ] + + from ihatemoney.web import import_project + + import_project(json.dumps(json_to_import), project) + + bills = project.get_pretty_bills() + + # Check if all bills has been add + self.assertEqual(len(bills), len(json_to_import)) + + # Check if name of bills are ok + b = [e["what"] for e in bills] + b.sort() + ref = [e["what"] for e in json_to_import] + ref.sort() + + self.assertEqual(b, ref) # Check if other informations in bill are ok for i in json_to_import: diff --git a/ihatemoney/web.py b/ihatemoney/web.py index ada7688a..c390548e 100644 --- a/ihatemoney/web.py +++ b/ihatemoney/web.py @@ -402,7 +402,7 @@ def upload_json(): if form.validate_on_submit(): file = form.file.data.stream.read() try: - import_project(file) + import_project(file, g.project) flash(_("Project successfully uploaded")) except ValueError: flash(_("Invalid JSON"), category="error") @@ -411,11 +411,11 @@ def upload_json(): return render_template("upload_json.html", form=form) -def import_project(file): +def import_project(file, project): # From json : export list of members json_file = json.loads(file) members_json = get_members(json_file) - members = g.project.members + members = project.members members_already_here = list() for m in members: members_already_here.append(str(m)) @@ -429,7 +429,7 @@ def import_project(file): # List bills not in the project # Same format than JSON element - project_bills = g.project.get_pretty_bills() + project_bills = project.get_pretty_bills() bill_to_add = list() for j in json_file: same = False @@ -442,11 +442,11 @@ def import_project(file): # Add users to DB for m in members_to_add: - Person(name=m[0], project=g.project, weight=m[1]) + Person(name=m[0], project=project, weight=m[1]) db.session.commit() id_dict = {} - for i in g.project.members: + for i in project.members: id_dict[i.name] = i.id # Create bills @@ -456,14 +456,14 @@ def import_project(file): owers_id.append(id_dict[ower]) bill = Bill() - form = get_billform_for(g.project) + form = get_billform_for(project) form.what = b["what"] form.amount = b["amount"] form.date = parse(b["date"]) form.payer = id_dict[b["payer_name"]] form.payed_for = owers_id - db.session.add(form.fake_form(bill, g.project)) + db.session.add(form.fake_form(bill, project)) # Add bills to DB db.session.commit()