CSV bills import (cospend compatible) (#951)

* proper import form (fix messy errors)
* csv compatible import
* cospend compatible import
* localization (best effort)
* refactoring
* revert localization (best effort)
* import return 400 on error
* fix Person.query.get_by_ids calls
* Bill explicit init parameters
* fix tests
* refacto tests with self.get_project
* separate import tests
* fix tests
* csv import test case
* fix import csv parsing
* revert DestructiveActionProjectForm renaming
* fix csv import test
* fix error redirection on import
* fix lint
* import file input type hint
* various fixes from review

Co-authored-by: Youe Graillot <youe.graillot@gmail.com>
This commit is contained in:
Youe Graillot 2021-12-22 00:00:34 +01:00 committed by GitHub
parent 8b6a2afc63
commit 747824a298
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 895 additions and 973 deletions

View file

@ -159,8 +159,7 @@ class BillsHandler(Resource):
def post(self, project):
form = get_billform_for(project, True, meta={"csrf": False})
if form.validate():
bill = Bill()
form.save(bill, project)
bill = form.export(project)
db.session.add(bill)
db.session.commit()
return bill.id, 201

View file

@ -23,7 +23,7 @@ from wtforms.validators import (
)
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.models import LoggingMode, Person, Project
from ihatemoney.models import Bill, LoggingMode, Person, Project
from ihatemoney.utils import (
eval_arithmetic_expression,
render_localized_currency,
@ -182,13 +182,15 @@ class EditProjectForm(FlaskForm):
return project
class UploadForm(FlaskForm):
class ImportProjectForm(FlaskForm):
file = FileField(
"JSON",
validators=[FileRequired(), FileAllowed(["json", "JSON"], "JSON only!")],
description=_("Import previously exported JSON file"),
"File",
validators=[
FileRequired(),
FileAllowed(["json", "JSON", "csv", "CSV"], "Incorrect file format"),
],
description=_("Compatible with Cospend"),
)
submit = SubmitField(_("Import"))
class ProjectForm(EditProjectForm):
@ -319,33 +321,31 @@ class BillForm(FlaskForm):
submit = SubmitField(_("Submit"))
submit2 = SubmitField(_("Submit and add a new one"))
def export(self, project):
return Bill(
amount=float(self.amount.data),
date=self.date.data,
external_link=self.external_link.data,
original_currency=str(self.original_currency.data),
owers=Person.query.get_by_ids(self.payed_for.data, project),
payer_id=self.payer.data,
project_default_currency=project.default_currency,
what=self.what.data,
)
def save(self, bill, project):
bill.payer_id = self.payer.data
bill.amount = self.amount.data
bill.what = self.what.data
bill.external_link = self.external_link.data
bill.date = self.date.data
bill.owers = [Person.query.get(ower, project) for ower in self.payed_for.data]
bill.owers = Person.query.get_by_ids(self.payed_for.data, project)
bill.original_currency = self.original_currency.data
bill.converted_amount = self.currency_helper.exchange_currency(
bill.amount, bill.original_currency, project.default_currency
)
return bill
def fake_form(self, bill, project):
bill.payer_id = self.payer
bill.amount = self.amount
bill.what = self.what
bill.external_link = ""
bill.date = self.date
bill.owers = [Person.query.get(ower, project) for ower in self.payed_for]
bill.original_currency = self.original_currency
bill.converted_amount = self.currency_helper.exchange_currency(
bill.amount, bill.original_currency, project.default_currency
)
return bill
def fill(self, bill, project):
self.payer.data = bill.payer_id
self.amount.data = bill.amount

View file

@ -1,6 +1,7 @@
from collections import defaultdict
from datetime import datetime
from dateutil.parser import parse
from debts import settle
from flask import current_app, g
from flask_sqlalchemy import BaseQuery, SQLAlchemy
@ -19,6 +20,7 @@ from werkzeug.security import generate_password_hash
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder
from ihatemoney.utils import get_members, same_bill
from ihatemoney.versioning import (
ConditionalVersioningManager,
LoggingMode,
@ -320,6 +322,44 @@ class Project(db.Model):
db.session.add(self)
db.session.commit()
def import_bills(self, bills: list):
"""Import bills from a list of dictionaries"""
# Add members not already in the project
project_members = [str(m) for m in self.members]
new_members = [
m for m in get_members(bills) if str(m[0]) not in project_members
]
for m in new_members:
Person(name=m[0], project=self, weight=m[1])
db.session.commit()
# Import bills not already in the project
project_bills = self.get_pretty_bills()
id_dict = {m.name: m.id for m in self.members}
for b in bills:
same = False
for p_b in project_bills:
if same_bill(p_b, b):
same = True
break
if not same:
# Create bills
try:
new_bill = Bill(
amount=b["amount"],
date=parse(b["date"]),
external_link="",
original_currency=b["currency"],
owers=Person.query.get_by_names(b["owers"], self),
payer_id=id_dict[b["payer_name"]],
project_default_currency=self.default_currency,
what=b["what"],
)
except Exception as e:
raise ValueError(f"Unable to import csv data: {repr(e)}")
db.session.add(new_bill)
db.session.commit()
def remove_member(self, member_id):
"""Remove a member from the project.
@ -435,16 +475,17 @@ class Project(db.Model):
("Alice", 20, ("Amina", "Alice"), "Beer !"),
("Amina", 50, ("Amina", "Alice", "Georg"), "AMAP"),
)
for (payer, amount, owers, subject) in operations:
bill = Bill()
bill.payer_id = members[payer].id
bill.what = subject
bill.owers = [members[name] for name in owers]
bill.amount = amount
bill.original_currency = "XXX"
bill.converted_amount = amount
db.session.add(bill)
for (payer, amount, owers, what) in operations:
db.session.add(
Bill(
amount=amount,
original_currency=project.default_currency,
owers=[members[name] for name in owers],
payer_id=members[payer].id,
project_default_currency=project.default_currency,
what=what,
)
)
db.session.commit()
return project
@ -459,6 +500,13 @@ class Person(db.Model):
.one_or_none()
)
def get_by_names(self, names, project):
return (
Person.query.filter(Person.name.in_(names))
.filter(Person.project_id == project.id)
.all()
)
def get(self, id, project=None):
if not project:
project = g.project
@ -468,6 +516,15 @@ class Person(db.Model):
.one_or_none()
)
def get_by_ids(self, ids, project=None):
if not project:
project = g.project
return (
Person.query.filter(Person.id.in_(ids))
.filter(Person.project_id == project.id)
.all()
)
query_class = PersonQuery
# Direct SQLAlchemy-Continuum to track changes to this model
@ -561,6 +618,31 @@ class Bill(db.Model):
archive = db.Column(db.Integer, db.ForeignKey("archive.id"))
currency_helper = CurrencyConverter()
def __init__(
self,
amount: float,
date: datetime = None,
external_link: str = "",
original_currency: str = "",
owers: list = [],
payer_id: int = None,
project_default_currency: str = "",
what: str = "",
):
super().__init__()
self.amount = amount
self.date = date
self.external_link = external_link
self.original_currency = original_currency
self.owers = owers
self.payer_id = payer_id
self.what = what
self.converted_amount = self.currency_helper.exchange_currency(
self.amount, self.original_currency, project_default_currency
)
@property
def _to_serialize(self):
return {

View file

@ -29,7 +29,7 @@
<div class="container edit-project">
<h2>{{ _("Edit project") }}</h2>
<form class="form-horizontal" method="post">
<form id="edit-project" class="form-horizontal" method="post">
{{ forms.edit_project(edit_form) }}
</form>
@ -38,23 +38,10 @@
{{ forms.delete_project(delete_form) }}
</form>
<h2>{{ _("Import JSON") }}</h2>
<form class="form-horizontal" method="post" enctype="multipart/form-data">
{{ import_form.hidden_tag() }}
<div class="custom-file">
<div class="form-group">
{{ import_form.file(class="custom-file-input") }}
<small class="form-text text-muted">
{{ import_form.file.description }}
</small>
</div>
<label class="custom-file-label" for="customFile">{{ _('Choose file') }}</label>
</div>
<div class="actions">
{{ import_form.submit(class="btn btn-primary") }}
</div>
<h2>{{ _("Import project") }}</h2>
<form id="import-project" class="form-horizontal" action="{{ url_for(".import_project") }}" method="post" enctype="multipart/form-data">
{{ forms.import_project(import_form) }}
</form>
<h2>{{ _("Download project's data") }}</h2>

View file

@ -118,13 +118,27 @@
{% endmacro %}
{% macro upload_json(form) %}
{% macro import_project(form) %}
{% include "display_errors.html" %}
{{ form.hidden_tag() }}
{{ form.file }}
<div class="actions">
<button class="btn btn-primary">{{ _("Import") }}</button>
<p><strong>{{ _("Import previously exported project") }}</strong></p>
<div class="custom-file">
<div class="form-group">
{{ form.file(class="custom-file-input", accept=".json,.csv") }}
<small class="form-text text-muted">
{{ form.file.description }}
</small>
</div>
<label class="custom-file-label" for="customFile">{{ _('Choose file') }}</label>
</div>
<div class="actions">
<button class="btn btn-primary">{{ _("Import project") }}</button>
</div>
{% endmacro %}
{% macro delete_project_history(form) %}

File diff suppressed because it is too large Load diff

View file

@ -74,6 +74,14 @@ class BaseTestCase(TestCase):
follow_redirects=follow_redirects,
)
def import_project(self, id, data, success=True):
resp = self.client.post(
f"/{id}/import",
data=data,
# follow_redirects=True,
)
self.assertEqual("/{id}/edit" in str(resp.response), not success)
def create_project(self, id, default_currency="XXX", name=None, password=None):
name = name or str(id)
password = password or id
@ -87,6 +95,9 @@ class BaseTestCase(TestCase):
models.db.session.add(project)
models.db.session.commit()
def get_project(self, id) -> models.Project:
return models.Project.query.get(id)
class IhatemoneyTestCase(BaseTestCase):
TESTING = True

View file

@ -614,14 +614,14 @@ class HistoryTestCase(IhatemoneyTestCase):
models.db.session.add(b2)
models.db.session.commit()
history_list = history.get_history(models.Project.query.get("demo"))
history_list = history.get_history(self.get_project("demo"))
self.assertEqual(len(history_list), 5)
# Change just the amount
b1.amount = 5
models.db.session.commit()
history_list = history.get_history(models.Project.query.get("demo"))
history_list = history.get_history(self.get_project("demo"))
for entry in history_list:
if "prop_changed" in entry:
self.assertNotIn("owers", entry["prop_changed"])

View file

@ -0,0 +1,614 @@
import copy
import json
import unittest
from ihatemoney.tests.common.ihatemoney_testcase import IhatemoneyTestCase
from ihatemoney.utils import list_of_dicts2csv, list_of_dicts2json
class CommonTestCase(object):
class Import(IhatemoneyTestCase):
def setUp(self):
super().setUp()
self.data = [
{
"date": "2017-01-01",
"what": "refund",
"amount": 13.33,
"payer_name": "tata",
"payer_weight": 1.0,
"owers": ["fred"],
},
{
"date": "2016-12-31",
"what": "red wine",
"amount": 200.0,
"payer_name": "fred",
"payer_weight": 1.0,
"owers": ["zorglub", "tata"],
},
{
"date": "2016-12-31",
"what": "fromage a raclette",
"amount": 10.0,
"payer_name": "zorglub",
"payer_weight": 2.0,
"owers": ["zorglub", "fred", "tata", "pepe"],
},
]
def populate_data_with_currencies(self, currencies):
for d in range(len(self.data)):
self.data[d]["currency"] = currencies[d]
def test_import_currencies_in_empty_project_with_currency(self):
# Import JSON with currencies in an empty project with a default currency
self.post_project("raclette", default_currency="EUR")
self.login("raclette")
project = self.get_project("raclette")
self.populate_data_with_currencies(["EUR", "CAD", "EUR"])
self.import_project("raclette", self.generate_form_data(self.data))
bills = project.get_pretty_bills()
# Check if all bills have been added
self.assertEqual(len(bills), len(self.data))
# Check if name of bills are ok
b = [e["what"] for e in bills]
b.sort()
ref = [e["what"] for e in self.data]
ref.sort()
self.assertEqual(b, ref)
# Check if other informations in bill are ok
for d in self.data:
for b in bills:
if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"])
self.assertEqual(b["amount"], d["amount"])
self.assertEqual(b["currency"], d["currency"])
self.assertEqual(b["payer_weight"], d["payer_weight"])
self.assertEqual(b["date"], d["date"])
list_project = [ower for ower in b["owers"]]
list_project.sort()
list_json = [ower for ower in d["owers"]]
list_json.sort()
self.assertEqual(list_project, list_json)
def test_import_single_currency_in_empty_project_without_currency(self):
# Import JSON with a single currency in an empty project with no
# default currency. It should work by stripping the currency from
# bills.
self.post_project("raclette")
self.login("raclette")
project = self.get_project("raclette")
self.populate_data_with_currencies(["EUR", "EUR", "EUR"])
self.import_project("raclette", self.generate_form_data(self.data))
bills = project.get_pretty_bills()
# Check if all bills have been added
self.assertEqual(len(bills), len(self.data))
# Check if name of bills are ok
b = [e["what"] for e in bills]
b.sort()
ref = [e["what"] for e in self.data]
ref.sort()
self.assertEqual(b, ref)
# Check if other informations in bill are ok
for d in self.data:
for b in bills:
if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"])
self.assertEqual(b["amount"], d["amount"])
# Currency should have been stripped
self.assertEqual(b["currency"], "XXX")
self.assertEqual(b["payer_weight"], d["payer_weight"])
self.assertEqual(b["date"], d["date"])
list_project = [ower for ower in b["owers"]]
list_project.sort()
list_json = [ower for ower in d["owers"]]
list_json.sort()
self.assertEqual(list_project, list_json)
def test_import_multiple_currencies_in_empty_project_without_currency(self):
# Import JSON with multiple currencies in an empty project with no
# default currency. It should fail.
self.post_project("raclette")
self.login("raclette")
project = self.get_project("raclette")
self.populate_data_with_currencies(["EUR", "CAD", "EUR"])
# Import should fail
self.import_project("raclette", self.generate_form_data(self.data), 400)
bills = project.get_pretty_bills()
# Check that there are no bills
self.assertEqual(len(bills), 0)
def test_import_no_currency_in_empty_project_with_currency(self):
# Import JSON without currencies (from ihatemoney < 5) in an empty
# project with a default currency.
self.post_project("raclette", default_currency="EUR")
self.login("raclette")
project = self.get_project("raclette")
self.import_project("raclette", self.generate_form_data(self.data))
bills = project.get_pretty_bills()
# Check if all bills have been added
self.assertEqual(len(bills), len(self.data))
# Check if name of bills are ok
b = [e["what"] for e in bills]
b.sort()
ref = [e["what"] for e in self.data]
ref.sort()
self.assertEqual(b, ref)
# Check if other informations in bill are ok
for d in self.data:
for b in bills:
if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"])
self.assertEqual(b["amount"], d["amount"])
# All bills are converted to default project currency
self.assertEqual(b["currency"], "EUR")
self.assertEqual(b["payer_weight"], d["payer_weight"])
self.assertEqual(b["date"], d["date"])
list_project = [ower for ower in b["owers"]]
list_project.sort()
list_json = [ower for ower in d["owers"]]
list_json.sort()
self.assertEqual(list_project, list_json)
def test_import_no_currency_in_empty_project_without_currency(self):
# Import JSON without currencies (from ihatemoney < 5) in an empty
# project with no default currency.
self.post_project("raclette")
self.login("raclette")
project = self.get_project("raclette")
self.import_project("raclette", self.generate_form_data(self.data))
bills = project.get_pretty_bills()
# Check if all bills have been added
self.assertEqual(len(bills), len(self.data))
# Check if name of bills are ok
b = [e["what"] for e in bills]
b.sort()
ref = [e["what"] for e in self.data]
ref.sort()
self.assertEqual(b, ref)
# Check if other informations in bill are ok
for d in self.data:
for b in bills:
if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"])
self.assertEqual(b["amount"], d["amount"])
self.assertEqual(b["currency"], "XXX")
self.assertEqual(b["payer_weight"], d["payer_weight"])
self.assertEqual(b["date"], d["date"])
list_project = [ower for ower in b["owers"]]
list_project.sort()
list_json = [ower for ower in d["owers"]]
list_json.sort()
self.assertEqual(list_project, list_json)
def test_import_partial_project(self):
# Import a JSON in a project with already existing data
self.post_project("raclette")
self.login("raclette")
project = self.get_project("raclette")
self.client.post(
"/raclette/members/add", data={"name": "zorglub", "weight": 2}
)
self.client.post("/raclette/members/add", data={"name": "fred"})
self.client.post("/raclette/members/add", data={"name": "tata"})
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "red wine",
"payer": 2,
"payed_for": [1, 3],
"amount": "200",
},
)
self.populate_data_with_currencies(["XXX", "XXX", "XXX"])
self.import_project("raclette", self.generate_form_data(self.data))
bills = project.get_pretty_bills()
# Check if all bills have been added
self.assertEqual(len(bills), len(self.data))
# Check if name of bills are ok
b = [e["what"] for e in bills]
b.sort()
ref = [e["what"] for e in self.data]
ref.sort()
self.assertEqual(b, ref)
# Check if other informations in bill are ok
for d in self.data:
for b in bills:
if b["what"] == d["what"]:
self.assertEqual(b["payer_name"], d["payer_name"])
self.assertEqual(b["amount"], d["amount"])
self.assertEqual(b["currency"], d["currency"])
self.assertEqual(b["payer_weight"], d["payer_weight"])
self.assertEqual(b["date"], d["date"])
list_project = [ower for ower in b["owers"]]
list_project.sort()
list_json = [ower for ower in d["owers"]]
list_json.sort()
self.assertEqual(list_project, list_json)
def test_import_wrong_data(self):
self.post_project("raclette")
self.login("raclette")
data_wrong_keys = [
{
"checked": False,
"dimensions": {"width": 5, "height": 10},
"id": 1,
"name": "A green door",
"price": 12.5,
"tags": ["home", "green"],
}
]
data_amount_missing = [
{
"date": "2017-01-01",
"what": "refund",
"payer_name": "tata",
"payer_weight": 1.0,
"owers": ["fred"],
}
]
for data in [data_wrong_keys, data_amount_missing]:
# Import should fail
self.import_project("raclette", self.generate_form_data(data), 400)
class ExportTestCase(IhatemoneyTestCase):
def test_export(self):
# Export a simple project without currencies
self.post_project("raclette")
# add participants
self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2})
self.client.post("/raclette/members/add", data={"name": "fred"})
self.client.post("/raclette/members/add", data={"name": "tata"})
self.client.post("/raclette/members/add", data={"name": "pépé"})
# create bills
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "fromage à raclette",
"payer": 1,
"payed_for": [1, 2, 3, 4],
"amount": "10.0",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "red wine",
"payer": 2,
"payed_for": [1, 3],
"amount": "200",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2017-01-01",
"what": "refund",
"payer": 3,
"payed_for": [2],
"amount": "13.33",
},
)
# generate json export of bills
resp = self.client.get("/raclette/export/bills.json")
expected = [
{
"date": "2017-01-01",
"what": "refund",
"amount": 13.33,
"currency": "XXX",
"payer_name": "tata",
"payer_weight": 1.0,
"owers": ["fred"],
},
{
"date": "2016-12-31",
"what": "red wine",
"amount": 200.0,
"currency": "XXX",
"payer_name": "fred",
"payer_weight": 1.0,
"owers": ["zorglub", "tata"],
},
{
"date": "2016-12-31",
"what": "fromage \xe0 raclette",
"amount": 10.0,
"currency": "XXX",
"payer_name": "zorglub",
"payer_weight": 2.0,
"owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"],
},
]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
# generate csv export of bills
resp = self.client.get("/raclette/export/bills.csv")
expected = [
"date,what,amount,currency,payer_name,payer_weight,owers",
"2017-01-01,refund,XXX,13.33,tata,1.0,fred",
'2016-12-31,red wine,XXX,200.0,fred,1.0,"zorglub, tata"',
'2016-12-31,fromage à raclette,10.0,XXX,zorglub,2.0,"zorglub, fred, tata, pépé"',
]
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
# generate json export of transactions
resp = self.client.get("/raclette/export/transactions.json")
expected = [
{
"amount": 2.00,
"currency": "XXX",
"receiver": "fred",
"ower": "p\xe9p\xe9",
},
{"amount": 55.34, "currency": "XXX", "receiver": "fred", "ower": "tata"},
{
"amount": 127.33,
"currency": "XXX",
"receiver": "fred",
"ower": "zorglub",
},
]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
# generate csv export of transactions
resp = self.client.get("/raclette/export/transactions.csv")
expected = [
"amount,currency,receiver,ower",
"2.0,XXX,fred,pépé",
"55.34,XXX,fred,tata",
"127.33,XXX,fred,zorglub",
]
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
# wrong export_format should return a 404
resp = self.client.get("/raclette/export/transactions.wrong")
self.assertEqual(resp.status_code, 404)
def test_export_with_currencies(self):
self.post_project("raclette", default_currency="EUR")
# add participants
self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2})
self.client.post("/raclette/members/add", data={"name": "fred"})
self.client.post("/raclette/members/add", data={"name": "tata"})
self.client.post("/raclette/members/add", data={"name": "pépé"})
# create bills
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "fromage à raclette",
"payer": 1,
"payed_for": [1, 2, 3, 4],
"amount": "10.0",
"original_currency": "EUR",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2016-12-31",
"what": "poutine from Québec",
"payer": 2,
"payed_for": [1, 3],
"amount": "100",
"original_currency": "CAD",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2017-01-01",
"what": "refund",
"payer": 3,
"payed_for": [2],
"amount": "13.33",
"original_currency": "EUR",
},
)
# generate json export of bills
resp = self.client.get("/raclette/export/bills.json")
expected = [
{
"date": "2017-01-01",
"what": "refund",
"amount": 13.33,
"currency": "EUR",
"payer_name": "tata",
"payer_weight": 1.0,
"owers": ["fred"],
},
{
"date": "2016-12-31",
"what": "poutine from Qu\xe9bec",
"amount": 100.0,
"currency": "CAD",
"payer_name": "fred",
"payer_weight": 1.0,
"owers": ["zorglub", "tata"],
},
{
"date": "2016-12-31",
"what": "fromage \xe0 raclette",
"amount": 10.0,
"currency": "EUR",
"payer_name": "zorglub",
"payer_weight": 2.0,
"owers": ["zorglub", "fred", "tata", "p\xe9p\xe9"],
},
]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
# generate csv export of bills
resp = self.client.get("/raclette/export/bills.csv")
expected = [
"date,what,amount,currency,payer_name,payer_weight,owers",
"2017-01-01,refund,13.33,EUR,tata,1.0,fred",
'2016-12-31,poutine from Québec,100.0,CAD,fred,1.0,"zorglub, tata"',
'2016-12-31,fromage à raclette,10.0,EUR,zorglub,2.0,"zorglub, fred, tata, pépé"',
]
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
# generate json export of transactions (in EUR!)
resp = self.client.get("/raclette/export/transactions.json")
expected = [
{
"amount": 2.00,
"currency": "EUR",
"receiver": "fred",
"ower": "p\xe9p\xe9",
},
{"amount": 10.89, "currency": "EUR", "receiver": "fred", "ower": "tata"},
{"amount": 38.45, "currency": "EUR", "receiver": "fred", "ower": "zorglub"},
]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
# generate csv export of transactions
resp = self.client.get("/raclette/export/transactions.csv")
expected = [
"amount,currency,receiver,ower",
"2.0,EUR,fred,pépé",
"10.89,EUR,fred,tata",
"38.45,EUR,fred,zorglub",
]
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
# Change project currency to CAD
project = self.get_project("raclette")
project.switch_currency("CAD")
# generate json export of transactions (now in CAD!)
resp = self.client.get("/raclette/export/transactions.json")
expected = [
{
"amount": 3.00,
"currency": "CAD",
"receiver": "fred",
"ower": "p\xe9p\xe9",
},
{"amount": 16.34, "currency": "CAD", "receiver": "fred", "ower": "tata"},
{"amount": 57.67, "currency": "CAD", "receiver": "fred", "ower": "zorglub"},
]
self.assertEqual(json.loads(resp.data.decode("utf-8")), expected)
# generate csv export of transactions
resp = self.client.get("/raclette/export/transactions.csv")
expected = [
"amount,currency,receiver,ower",
"3.0,CAD,fred,pépé",
"16.34,CAD,fred,tata",
"57.67,CAD,fred,zorglub",
]
received_lines = resp.data.decode("utf-8").split("\n")
for i, line in enumerate(expected):
self.assertEqual(
set(line.split(",")), set(received_lines[i].strip("\r").split(","))
)
class ImportTestCaseJSON(CommonTestCase.Import):
def generate_form_data(self, data):
return {"file": (list_of_dicts2json(data), "test.json")}
class ImportTestCaseCSV(CommonTestCase.Import):
def generate_form_data(self, data):
formatted_data = copy.deepcopy(data)
for d in formatted_data:
d["owers"] = ", ".join([o for o in d.get("owers", [])])
return {"file": (list_of_dicts2csv(formatted_data), "test.csv")}
if __name__ == "__main__":
unittest.main()

View file

@ -99,7 +99,7 @@ class CommandTestCase(BaseTestCase):
def test_demo_project_deletion(self):
self.create_project("demo")
self.assertEqual(models.Project.query.get("demo").name, "demo")
self.assertEqual(self.get_project("demo").name, "demo")
runner = self.app.test_cli_runner()
runner.invoke(delete_project, "demo")
@ -246,7 +246,7 @@ class CaptchaTestCase(IhatemoneyTestCase):
ENABLE_CAPTCHA = True
def test_project_creation_with_captcha(self):
with self.app.test_client() as c:
with self.client as c:
c.post(
"/create",
data={

View file

@ -2,7 +2,7 @@ import ast
import csv
from datetime import datetime, timedelta
from enum import Enum
from io import BytesIO, StringIO
from io import BytesIO, StringIO, TextIOWrapper
from json import JSONEncoder, dumps
import operator
import os
@ -150,6 +150,31 @@ def list_of_dicts2csv(dict_to_convert):
return csv_file
def csv2list_of_dicts(csv_to_convert):
"""Take a csv in-memory file and turns it into
a list of dictionnaries
"""
csv_file = TextIOWrapper(csv_to_convert, encoding="utf-8")
reader = csv.DictReader(csv_file)
result = []
for r in reader:
"""
cospend embeds various data helping (cospend) imports
'deleteMeIfYouWant' lines contains users
'categoryname' table contains categories description
we don't need them as we determine users and categories from bills
"""
if r["what"] == "deleteMeIfYouWant":
continue
elif r["what"] == "categoryname":
break
r["amount"] = float(r["amount"])
r["payer_weight"] = float(r["payer_weight"])
r["owers"] = [o.strip() for o in r["owers"].split(",")]
result.append(r)
return result
class LoginThrottler:
"""Simple login throttler used to limit authentication attempts based on client's ip address.
When using multiple workers, remaining number of attempts can get inconsistent

View file

@ -13,7 +13,6 @@ from functools import wraps
import json
import os
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from flask import (
Blueprint,
@ -44,13 +43,13 @@ from ihatemoney.forms import (
DestructiveActionProjectForm,
EditProjectForm,
EmptyForm,
ImportProjectForm,
InviteForm,
MemberForm,
PasswordReminder,
ProjectForm,
ProjectFormWithCaptcha,
ResetPasswordForm,
UploadForm,
get_billform_for,
)
from ihatemoney.history import get_history, get_history_queries
@ -58,12 +57,11 @@ from ihatemoney.models import Bill, LoggingMode, Person, Project, db
from ihatemoney.utils import (
LoginThrottler,
Redirect303,
csv2list_of_dicts,
format_form_errors,
get_members,
list_of_dicts2csv,
list_of_dicts2json,
render_localized_template,
same_bill,
send_email,
)
@ -412,17 +410,8 @@ def reset_password():
@main.route("/<project_id>/edit", methods=["GET", "POST"])
def edit_project():
edit_form = EditProjectForm(id=g.project.id)
import_form = ImportProjectForm(id=g.project.id)
delete_form = DestructiveActionProjectForm(id=g.project.id)
import_form = UploadForm()
# Import form
if import_form.validate_on_submit():
try:
import_project(import_form.file.data.stream, g.project)
flash(_("Project successfully uploaded"))
return redirect(url_for("main.list_bills"))
except ValueError as e:
flash(e.args[0], category="danger")
# Edit form
if edit_form.validate_on_submit():
@ -446,103 +435,71 @@ def edit_project():
return render_template(
"edit_project.html",
edit_form=edit_form,
delete_form=delete_form,
import_form=import_form,
delete_form=delete_form,
current_view="edit_project",
)
def import_project(file, project):
json_file = json.load(file)
@main.route("/<project_id>/import", methods=["POST"])
def import_project():
form = ImportProjectForm()
if form.validate():
try:
data = form.file.data
if data.mimetype == "application/json":
bills = json.load(data.stream)
elif data.mimetype == "text/csv":
try:
bills = csv2list_of_dicts(data)
except Exception:
raise ValueError(_("Unable to parse CSV"))
else:
raise ValueError("Unsupported file type")
# Check if JSON is correct
attr = ["what", "payer_name", "payer_weight", "amount", "currency", "date", "owers"]
attr.sort()
currencies = set()
for e in json_file:
# If currency is absent, empty, or explicitly set to XXX
# set it to project default.
if e.get("currency", "") in ["", "XXX"]:
e["currency"] = project.default_currency
if len(e) != len(attr):
raise ValueError(_("Invalid JSON"))
list_attr = []
for i in e:
list_attr.append(i)
list_attr.sort()
if list_attr != attr:
raise ValueError(_("Invalid JSON"))
# Keep track of currencies
currencies.add(e["currency"])
# Check data
attr = [
"amount",
"currency",
"date",
"owers",
"payer_name",
"payer_weight",
"what",
]
currencies = set()
for b in bills:
if b.get("currency", "") in ["", "XXX"]:
b["currency"] = g.project.default_currency
for a in attr:
if a not in b:
raise ValueError(_("Missing attribute {}").format(a))
currencies.add(b["currency"])
# Additional checks if project has no default currency
if project.default_currency == CurrencyConverter.no_currency:
# If bills have currencies, they must be consistent
if len(currencies - {CurrencyConverter.no_currency}) >= 2:
raise ValueError(
_(
"Cannot add bills in multiple currencies to a project without default currency"
)
)
# Strip currency from bills (since it's the same for every bill)
for e in json_file:
e["currency"] = CurrencyConverter.no_currency
# Additional checks if project has no default currency
if g.project.default_currency == CurrencyConverter.no_currency:
# If bills have currencies, they must be consistent
if len(currencies - {CurrencyConverter.no_currency}) >= 2:
raise ValueError(
_(
"Cannot add bills in multiple currencies to a project without default "
"currency"
)
)
# Strip currency from bills (since it's the same for every bill)
for b in bills:
b["currency"] = CurrencyConverter.no_currency
# From json : export list of members
members_json = get_members(json_file)
members = project.members
members_already_here = list()
for m in members:
members_already_here.append(str(m))
g.project.import_bills(bills)
# List all members not in the project and weight associated
# List of tuples (name,weight)
members_to_add = list()
for i in members_json:
if str(i[0]) not in members_already_here:
members_to_add.append(i)
# List bills not in the project
# Same format than JSON element
project_bills = project.get_pretty_bills()
bill_to_add = list()
for j in json_file:
same = False
for p in project_bills:
if same_bill(p, j):
same = True
break
if not same:
bill_to_add.append(j)
# Add users to DB
for m in members_to_add:
Person(name=m[0], project=project, weight=m[1])
db.session.commit()
id_dict = {}
for i in project.members:
id_dict[i.name] = i.id
# Create bills
for b in bill_to_add:
owers_id = list()
for ower in b["owers"]:
owers_id.append(id_dict[ower])
bill = Bill()
form = get_billform_for(project)
form.what = b["what"]
form.amount = b["amount"]
form.original_currency = b["currency"]
form.date = parse(b["date"])
form.payer = id_dict[b["payer_name"]]
form.payed_for = owers_id
db.session.add(form.fake_form(bill, project))
# Add bills to DB
db.session.commit()
flash(_("Project successfully uploaded"))
return redirect(url_for("main.list_bills"))
except ValueError as b:
flash(b.args[0], category="danger")
else:
for component, errors in form.errors.items():
flash(_(component + ": ") + ", ".join(errors), category="danger")
return redirect(request.headers.get("Referer") or url_for(".edit_project"))
@main.route("/<project_id>/delete", methods=["POST"])
@ -760,8 +717,7 @@ def add_bill():
session["last_selected_payer"] = form.payer.data
session.update()
bill = Bill()
db.session.add(form.save(bill, g.project))
db.session.add(form.export(g.project))
db.session.commit()
flash(_("The bill has been added"))