csv compatible import

This commit is contained in:
Youe Graillot 2021-11-28 02:13:10 +01:00
parent d24efac10a
commit e7d991ecc2
3 changed files with 24 additions and 7 deletions

View file

@ -187,8 +187,7 @@ class ImportProjectForm(FlaskForm):
"File", "File",
validators=[ validators=[
FileRequired(), FileRequired(),
"JSON", FileAllowed(["json", "JSON", "csv", "CSV"], "Incorrect file format"),
validators=[FileRequired(), FileAllowed(["json", "JSON"], "JSON only!")],
], ],
description=_("Import previously exported JSON file"), description=_("Import previously exported JSON file"),
) )

View file

@ -2,7 +2,7 @@ import ast
import csv import csv
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from io import BytesIO, StringIO from io import BytesIO, StringIO, TextIOWrapper
from json import JSONEncoder, dumps from json import JSONEncoder, dumps
import operator import operator
import os import os
@ -150,6 +150,21 @@ def list_of_dicts2csv(dict_to_convert):
return csv_file return csv_file
def csv2list_of_dicts(csv_to_convert):
"""Take a csv in-memory file and turns it into
a list of dictionnaries
"""
csv_file = TextIOWrapper(csv_to_convert)
reader = csv.DictReader(csv_file)
result = []
for r in reader:
r["amount"] = float(r["amount"])
r["payer_weight"] = float(r["payer_weight"])
r["owers"] = [o.strip() for o in r["owers"].split(",")]
result.append(r)
return result
class LoginThrottler: class LoginThrottler:
"""Simple login throttler used to limit authentication attempts based on client's ip address. """Simple login throttler used to limit authentication attempts based on client's ip address.
When using multiple workers, remaining number of attempts can get inconsistent When using multiple workers, remaining number of attempts can get inconsistent

View file

@ -58,6 +58,7 @@ from ihatemoney.models import Bill, LoggingMode, Person, Project, db
from ihatemoney.utils import ( from ihatemoney.utils import (
LoginThrottler, LoginThrottler,
Redirect303, Redirect303,
csv2list_of_dicts,
format_form_errors, format_form_errors,
get_members, get_members,
list_of_dicts2csv, list_of_dicts2csv,
@ -451,6 +452,11 @@ def import_project():
data = form.file.data data = form.file.data
if data.mimetype == "application/json": if data.mimetype == "application/json":
json_file = json.load(data.stream) json_file = json.load(data.stream)
elif data.mimetype == "text/csv":
try:
json_file = csv2list_of_dicts(data)
except Exception as e:
raise ValueError(_("Unable to parse CSV"))
else: else:
raise ValueError("Unsupported file type") raise ValueError("Unsupported file type")
@ -554,10 +560,7 @@ def import_project():
return redirect(url_for("main.list_bills")) return redirect(url_for("main.list_bills"))
except ValueError as e: except ValueError as e:
flash(e.args[0], category="danger") flash(e.args[0], category="danger")
return render_template( return redirect(url_for(".edit_project"))
"edit_project.html",
current_view="edit_project",
)
else: else:
for component, errors in form.errors.items(): for component, errors in form.errors.items():
flash(_(component + ": ") + ", ".join(errors), category="danger") flash(_(component + ": ") + ", ".join(errors), category="danger")