API: Migrate from flask-rest to flask-restful (#315)

The flask-rest custom json encoder is still needed
and thus was added to ihatemoney's utils.

Closes #298
This commit is contained in:
0livd 2018-01-25 17:41:28 +01:00 committed by Alexis Metaireau
parent 830718e1fe
commit b93ea4830d
6 changed files with 121 additions and 89 deletions

View file

@ -6,6 +6,11 @@ This document describes changes between each past release.
2.1 (unreleased) 2.1 (unreleased)
---------------- ----------------
Changed
=======
- Use flask-restful instead of deprecated flask-rest for the REST API (#315)
Fixed Fixed
===== =====

View file

@ -1,62 +1,68 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from flask import Blueprint, request from flask import Blueprint, request
from flask_rest import RESTResource, need_auth from flask_restful import Resource, Api, abort
from wtforms.fields.core import BooleanField from wtforms.fields.core import BooleanField
from ihatemoney.models import db, Project, Person, Bill from ihatemoney.models import db, Project, Person, Bill
from ihatemoney.forms import (ProjectForm, EditProjectForm, MemberForm, from ihatemoney.forms import (ProjectForm, EditProjectForm, MemberForm,
get_billform_for) get_billform_for)
from werkzeug.security import check_password_hash from werkzeug.security import check_password_hash
from functools import wraps
api = Blueprint("api", __name__, url_prefix="/api") api = Blueprint("api", __name__, url_prefix="/api")
restful_api = Api(api)
def check_project(*args, **kwargs): def need_auth(f):
"""Check the request for basic authentication for a given project. """Check the request for basic authentication for a given project.
Return the project if the authorization is good, False otherwise Return the project if the authorization is good, abort the request with a 401 otherwise
""" """
auth = request.authorization @wraps(f)
def wrapper(*args, **kwargs):
auth = request.authorization
project_id = kwargs.get("project_id")
# project_id should be contained in kwargs and equal to the username if auth and project_id and auth.username == project_id:
if auth and "project_id" in kwargs and \ project = Project.query.get(auth.username)
auth.username == kwargs["project_id"]: if project and check_password_hash(project.password, auth.password):
project = Project.query.get(auth.username) # The whole project object will be passed instead of project_id
if project and check_password_hash(project.password, auth.password): kwargs.pop("project_id")
return project return f(*args, project=project, **kwargs)
return False abort(401)
return wrapper
class ProjectHandler(object): class ProjectsHandler(Resource):
def post(self):
def add(self):
form = ProjectForm(meta={'csrf': False}) form = ProjectForm(meta={'csrf': False})
if form.validate(): if form.validate():
project = form.save() project = form.save()
db.session.add(project) db.session.add(project)
db.session.commit() db.session.commit()
return 201, project.id return project.id, 201
return 400, form.errors return form.errors, 400
class ProjectHandler(Resource):
method_decorators = [need_auth]
@need_auth(check_project, "project")
def get(self, project): def get(self, project):
return 200, project return project
@need_auth(check_project, "project")
def delete(self, project): def delete(self, project):
db.session.delete(project) db.session.delete(project)
db.session.commit() db.session.commit()
return 200, "DELETED" return "DELETED"
@need_auth(check_project, "project") def put(self, project):
def update(self, project):
form = EditProjectForm(meta={'csrf': False}) form = EditProjectForm(meta={'csrf': False})
if form.validate(): if form.validate():
form.update(project) form.update(project)
db.session.commit() db.session.commit()
return 200, "UPDATED" return "UPDATED"
return 400, form.errors return form.errors, 400
class APIMemberForm(MemberForm): class APIMemberForm(MemberForm):
@ -71,98 +77,92 @@ class APIMemberForm(MemberForm):
return super(APIMemberForm, self).save(project, person) return super(APIMemberForm, self).save(project, person)
class MemberHandler(object): class MembersHandler(Resource):
method_decorators = [need_auth]
def get(self, project, member_id): def get(self, project):
member = Person.query.get(member_id, project) return project.members
if not member or member.project != project:
return 404, "Not Found"
return 200, member
def list(self, project): def post(self, project):
return 200, project.members
def add(self, project):
form = MemberForm(project, meta={'csrf': False}) form = MemberForm(project, meta={'csrf': False})
if form.validate(): if form.validate():
member = Person() member = Person()
form.save(project, member) form.save(project, member)
db.session.commit() db.session.commit()
return 201, member.id return member.id, 201
return 400, form.errors return form.errors, 400
def update(self, project, member_id):
class MemberHandler(Resource):
method_decorators = [need_auth]
def get(self, project, member_id):
member = Person.query.get(member_id, project)
if not member or member.project != project:
return "Not Found", 404
return member
def put(self, project, member_id):
form = APIMemberForm(project, meta={'csrf': False}, edit=True) form = APIMemberForm(project, meta={'csrf': False}, edit=True)
if form.validate(): if form.validate():
member = Person.query.get(member_id, project) member = Person.query.get(member_id, project)
form.save(project, member) form.save(project, member)
db.session.commit() db.session.commit()
return 200, member return member
return 400, form.errors return form.errors, 400
def delete(self, project, member_id): def delete(self, project, member_id):
if project.remove_member(member_id): if project.remove_member(member_id):
return 200, "OK" return "OK"
return 404, "Not Found" return "Not Found", 404
class BillHandler(object): class BillsHandler(Resource):
method_decorators = [need_auth]
def get(self, project, bill_id): def get(self, project):
bill = Bill.query.get(project, bill_id)
if not bill:
return 404, "Not Found"
return 200, bill
def list(self, project):
return project.get_bills().all() return project.get_bills().all()
def add(self, project): def post(self, project):
form = get_billform_for(project, True, meta={'csrf': False}) form = get_billform_for(project, True, meta={'csrf': False})
if form.validate(): if form.validate():
bill = Bill() bill = Bill()
form.save(bill, project) form.save(bill, project)
db.session.add(bill) db.session.add(bill)
db.session.commit() db.session.commit()
return 201, bill.id return bill.id, 201
return 400, form.errors return form.errors, 400
def update(self, project, bill_id):
class BillHandler(Resource):
method_decorators = [need_auth]
def get(self, project, bill_id):
bill = Bill.query.get(project, bill_id)
if not bill:
return "Not Found", 404
return bill, 200
def put(self, project, bill_id):
form = get_billform_for(project, True, meta={'csrf': False}) form = get_billform_for(project, True, meta={'csrf': False})
if form.validate(): if form.validate():
bill = Bill.query.get(project, bill_id) bill = Bill.query.get(project, bill_id)
form.save(bill, project) form.save(bill, project)
db.session.commit() db.session.commit()
return 200, bill.id return bill.id, 200
return 400, form.errors return form.errors, 400
def delete(self, project, bill_id): def delete(self, project, bill_id):
bill = Bill.query.delete(project, bill_id) bill = Bill.query.delete(project, bill_id)
db.session.commit() db.session.commit()
if not bill: if not bill:
return 404, "Not Found" return "Not Found", 404
return 200, "OK" return "OK", 200
project_resource = RESTResource( restful_api.add_resource(ProjectsHandler, '/projects')
name="project", restful_api.add_resource(ProjectHandler, '/projects/<string:project_id>')
route="/projects", restful_api.add_resource(MembersHandler, "/projects/<string:project_id>/members")
app=api, restful_api.add_resource(MemberHandler, "/projects/<string:project_id>/members/<int:member_id>")
actions=["add", "update", "delete", "get"], restful_api.add_resource(BillsHandler, "/projects/<string:project_id>/bills")
handler=ProjectHandler()) restful_api.add_resource(BillHandler, "/projects/<string:project_id>/bills/<int:bill_id>")
member_resource = RESTResource(
name="member",
inject_name="project",
route="/projects/<project_id>/members",
app=api,
handler=MemberHandler(),
authentifier=check_project)
bill_resource = RESTResource(
name="bill",
inject_name="project",
route="/projects/<project_id>/bills",
app=api,
handler=BillHandler(),
authentifier=check_project)

View file

@ -11,7 +11,7 @@ from werkzeug.contrib.fixers import ProxyFix
from ihatemoney.api import api from ihatemoney.api import api
from ihatemoney.models import db from ihatemoney.models import db
from ihatemoney.utils import PrefixedWSGI, minimal_round from ihatemoney.utils import PrefixedWSGI, minimal_round, IhmJSONEncoder
from ihatemoney.web import main as web_interface from ihatemoney.web import main as web_interface
from ihatemoney import default_settings from ihatemoney import default_settings
@ -68,6 +68,8 @@ def load_configuration(app, configuration=None):
app.config.from_pyfile(env_var_config) app.config.from_pyfile(env_var_config)
else: else:
app.config.from_pyfile('ihatemoney.cfg', silent=True) app.config.from_pyfile('ihatemoney.cfg', silent=True)
# Configure custom JSONEncoder used by the API
app.config['RESTFUL_JSON'] = {'cls': IhmJSONEncoder}
def validate_configuration(app): def validate_configuration(app):

View file

@ -1053,7 +1053,7 @@ class APITestCase(IhatemoneyTestCase):
}) })
self.assertTrue(400, resp.status_code) self.assertTrue(400, resp.status_code)
self.assertEqual('{"contact_email": ["Invalid email address."]}', self.assertEqual('{"contact_email": ["Invalid email address."]}\n',
resp.data.decode('utf-8')) resp.data.decode('utf-8'))
# create it # create it
@ -1139,7 +1139,7 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette")) headers=self.get_auth("raclette"))
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual('[]', req.data.decode('utf-8')) self.assertEqual('[]\n', req.data.decode('utf-8'))
# add a member # add a member
req = self.client.post("/api/projects/raclette/members", data={ req = self.client.post("/api/projects/raclette/members", data={
@ -1148,7 +1148,7 @@ class APITestCase(IhatemoneyTestCase):
# the id of the new member should be returned # the id of the new member should be returned
self.assertStatus(201, req) self.assertStatus(201, req)
self.assertEqual("1", req.data.decode('utf-8')) self.assertEqual("1\n", req.data.decode('utf-8'))
# the list of members should contain one member # the list of members should contain one member
req = self.client.get("/api/projects/raclette/members", req = self.client.get("/api/projects/raclette/members",
@ -1223,7 +1223,7 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette")) headers=self.get_auth("raclette"))
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual('[]', req.data.decode('utf-8')) self.assertEqual('[]\n', req.data.decode('utf-8'))
def test_bills(self): def test_bills(self):
# create a project # create a project
@ -1239,7 +1239,7 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette")) headers=self.get_auth("raclette"))
self.assertStatus(200, req) self.assertStatus(200, req)
self.assertEqual("[]", req.data.decode('utf-8')) self.assertEqual("[]\n", req.data.decode('utf-8'))
# add a bill # add a bill
req = self.client.post("/api/projects/raclette/bills", data={ req = self.client.post("/api/projects/raclette/bills", data={
@ -1252,7 +1252,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id # should return the id
self.assertStatus(201, req) self.assertStatus(201, req)
self.assertEqual(req.data.decode('utf-8'), "1") self.assertEqual(req.data.decode('utf-8'), "1\n")
# get this bill details # get this bill details
req = self.client.get("/api/projects/raclette/bills/1", req = self.client.get("/api/projects/raclette/bills/1",
@ -1288,7 +1288,7 @@ class APITestCase(IhatemoneyTestCase):
}, headers=self.get_auth("raclette")) }, headers=self.get_auth("raclette"))
self.assertStatus(400, req) self.assertStatus(400, req)
self.assertEqual('{"date": ["This field is required."]}', req.data.decode('utf-8')) self.assertEqual('{"date": ["This field is required."]}\n', req.data.decode('utf-8'))
# edit a bill # edit a bill
req = self.client.put("/api/projects/raclette/bills/1", data={ req = self.client.put("/api/projects/raclette/bills/1", data={

View file

@ -3,7 +3,7 @@ import re
from io import BytesIO, StringIO from io import BytesIO, StringIO
from jinja2 import filters from jinja2 import filters
from json import dumps from json import dumps, JSONEncoder
from flask import redirect from flask import redirect
from werkzeug.routing import HTTPException, RoutingException from werkzeug.routing import HTTPException, RoutingException
import six import six
@ -170,3 +170,28 @@ class LoginThrottler():
def reset(self, ip): def reset(self, ip):
self._attempts.pop(ip, None) self._attempts.pop(ip, None)
class IhmJSONEncoder(JSONEncoder):
"""Subclass of the default encoder to support custom objects.
Taken from the deprecated flask-rest package."""
def default(self, o):
if hasattr(o, "_to_serialize"):
# build up the object
data = {}
for attr in o._to_serialize:
data[attr] = getattr(o, attr)
return data
elif hasattr(o, "isoformat"):
return o.isoformat()
else:
try:
from flask_babel import speaklater
if isinstance(o, speaklater.LazyString):
try:
return unicode(o) # For python 2.
except NameError:
return str(o) # For python 3.
except ImportError:
pass
return JSONEncoder.default(self, o)

View file

@ -5,7 +5,7 @@ flask-mail>=0.8
Flask-Migrate>=1.8.0 Flask-Migrate>=1.8.0
Flask-script Flask-script
flask-babel flask-babel
flask-rest>=1.3 flask-restful>=0.3.6
jinja2>=2.6 jinja2>=2.6
raven raven
blinker blinker