API: Migrate from flask-rest to flask-restful (#315)

The flask-rest custom json encoder is still needed
and thus was added to ihatemoney's utils.

Closes #298
This commit is contained in:
0livd 2018-01-25 17:41:28 +01:00 committed by Alexis Metaireau
parent 830718e1fe
commit b93ea4830d
6 changed files with 121 additions and 89 deletions

View file

@ -6,6 +6,11 @@ This document describes changes between each past release.
2.1 (unreleased)
----------------
Changed
=======
- Use flask-restful instead of deprecated flask-rest for the REST API (#315)
Fixed
=====

View file

@ -1,62 +1,68 @@
# -*- coding: utf-8 -*-
from flask import Blueprint, request
from flask_rest import RESTResource, need_auth
from flask_restful import Resource, Api, abort
from wtforms.fields.core import BooleanField
from ihatemoney.models import db, Project, Person, Bill
from ihatemoney.forms import (ProjectForm, EditProjectForm, MemberForm,
get_billform_for)
from werkzeug.security import check_password_hash
from functools import wraps
api = Blueprint("api", __name__, url_prefix="/api")
restful_api = Api(api)
def check_project(*args, **kwargs):
def need_auth(f):
"""Check the request for basic authentication for a given project.
Return the project if the authorization is good, False otherwise
Return the project if the authorization is good, abort the request with a 401 otherwise
"""
auth = request.authorization
@wraps(f)
def wrapper(*args, **kwargs):
auth = request.authorization
project_id = kwargs.get("project_id")
# project_id should be contained in kwargs and equal to the username
if auth and "project_id" in kwargs and \
auth.username == kwargs["project_id"]:
project = Project.query.get(auth.username)
if project and check_password_hash(project.password, auth.password):
return project
return False
if auth and project_id and auth.username == project_id:
project = Project.query.get(auth.username)
if project and check_password_hash(project.password, auth.password):
# The whole project object will be passed instead of project_id
kwargs.pop("project_id")
return f(*args, project=project, **kwargs)
abort(401)
return wrapper
class ProjectHandler(object):
def add(self):
class ProjectsHandler(Resource):
def post(self):
form = ProjectForm(meta={'csrf': False})
if form.validate():
project = form.save()
db.session.add(project)
db.session.commit()
return 201, project.id
return 400, form.errors
return project.id, 201
return form.errors, 400
class ProjectHandler(Resource):
method_decorators = [need_auth]
@need_auth(check_project, "project")
def get(self, project):
return 200, project
return project
@need_auth(check_project, "project")
def delete(self, project):
db.session.delete(project)
db.session.commit()
return 200, "DELETED"
return "DELETED"
@need_auth(check_project, "project")
def update(self, project):
def put(self, project):
form = EditProjectForm(meta={'csrf': False})
if form.validate():
form.update(project)
db.session.commit()
return 200, "UPDATED"
return 400, form.errors
return "UPDATED"
return form.errors, 400
class APIMemberForm(MemberForm):
@ -71,98 +77,92 @@ class APIMemberForm(MemberForm):
return super(APIMemberForm, self).save(project, person)
class MemberHandler(object):
class MembersHandler(Resource):
method_decorators = [need_auth]
def get(self, project, member_id):
member = Person.query.get(member_id, project)
if not member or member.project != project:
return 404, "Not Found"
return 200, member
def get(self, project):
return project.members
def list(self, project):
return 200, project.members
def add(self, project):
def post(self, project):
form = MemberForm(project, meta={'csrf': False})
if form.validate():
member = Person()
form.save(project, member)
db.session.commit()
return 201, member.id
return 400, form.errors
return member.id, 201
return form.errors, 400
def update(self, project, member_id):
class MemberHandler(Resource):
method_decorators = [need_auth]
def get(self, project, member_id):
member = Person.query.get(member_id, project)
if not member or member.project != project:
return "Not Found", 404
return member
def put(self, project, member_id):
form = APIMemberForm(project, meta={'csrf': False}, edit=True)
if form.validate():
member = Person.query.get(member_id, project)
form.save(project, member)
db.session.commit()
return 200, member
return 400, form.errors
return member
return form.errors, 400
def delete(self, project, member_id):
if project.remove_member(member_id):
return 200, "OK"
return 404, "Not Found"
return "OK"
return "Not Found", 404
class BillHandler(object):
class BillsHandler(Resource):
method_decorators = [need_auth]
def get(self, project, bill_id):
bill = Bill.query.get(project, bill_id)
if not bill:
return 404, "Not Found"
return 200, bill
def list(self, project):
def get(self, project):
return project.get_bills().all()
def add(self, project):
def post(self, project):
form = get_billform_for(project, True, meta={'csrf': False})
if form.validate():
bill = Bill()
form.save(bill, project)
db.session.add(bill)
db.session.commit()
return 201, bill.id
return 400, form.errors
return bill.id, 201
return form.errors, 400
def update(self, project, bill_id):
class BillHandler(Resource):
method_decorators = [need_auth]
def get(self, project, bill_id):
bill = Bill.query.get(project, bill_id)
if not bill:
return "Not Found", 404
return bill, 200
def put(self, project, bill_id):
form = get_billform_for(project, True, meta={'csrf': False})
if form.validate():
bill = Bill.query.get(project, bill_id)
form.save(bill, project)
db.session.commit()
return 200, bill.id
return 400, form.errors
return bill.id, 200
return form.errors, 400
def delete(self, project, bill_id):
bill = Bill.query.delete(project, bill_id)
db.session.commit()
if not bill:
return 404, "Not Found"
return 200, "OK"
return "Not Found", 404
return "OK", 200
project_resource = RESTResource(
name="project",
route="/projects",
app=api,
actions=["add", "update", "delete", "get"],
handler=ProjectHandler())
member_resource = RESTResource(
name="member",
inject_name="project",
route="/projects/<project_id>/members",
app=api,
handler=MemberHandler(),
authentifier=check_project)
bill_resource = RESTResource(
name="bill",
inject_name="project",
route="/projects/<project_id>/bills",
app=api,
handler=BillHandler(),
authentifier=check_project)
restful_api.add_resource(ProjectsHandler, '/projects')
restful_api.add_resource(ProjectHandler, '/projects/<string:project_id>')
restful_api.add_resource(MembersHandler, "/projects/<string:project_id>/members")
restful_api.add_resource(MemberHandler, "/projects/<string:project_id>/members/<int:member_id>")
restful_api.add_resource(BillsHandler, "/projects/<string:project_id>/bills")
restful_api.add_resource(BillHandler, "/projects/<string:project_id>/bills/<int:bill_id>")

View file

@ -11,7 +11,7 @@ from werkzeug.contrib.fixers import ProxyFix
from ihatemoney.api import api
from ihatemoney.models import db
from ihatemoney.utils import PrefixedWSGI, minimal_round
from ihatemoney.utils import PrefixedWSGI, minimal_round, IhmJSONEncoder
from ihatemoney.web import main as web_interface
from ihatemoney import default_settings
@ -68,6 +68,8 @@ def load_configuration(app, configuration=None):
app.config.from_pyfile(env_var_config)
else:
app.config.from_pyfile('ihatemoney.cfg', silent=True)
# Configure custom JSONEncoder used by the API
app.config['RESTFUL_JSON'] = {'cls': IhmJSONEncoder}
def validate_configuration(app):

View file

@ -1053,7 +1053,7 @@ class APITestCase(IhatemoneyTestCase):
})
self.assertTrue(400, resp.status_code)
self.assertEqual('{"contact_email": ["Invalid email address."]}',
self.assertEqual('{"contact_email": ["Invalid email address."]}\n',
resp.data.decode('utf-8'))
# create it
@ -1139,7 +1139,7 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
self.assertEqual('[]', req.data.decode('utf-8'))
self.assertEqual('[]\n', req.data.decode('utf-8'))
# add a member
req = self.client.post("/api/projects/raclette/members", data={
@ -1148,7 +1148,7 @@ class APITestCase(IhatemoneyTestCase):
# the id of the new member should be returned
self.assertStatus(201, req)
self.assertEqual("1", req.data.decode('utf-8'))
self.assertEqual("1\n", req.data.decode('utf-8'))
# the list of members should contain one member
req = self.client.get("/api/projects/raclette/members",
@ -1223,7 +1223,7 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
self.assertEqual('[]', req.data.decode('utf-8'))
self.assertEqual('[]\n', req.data.decode('utf-8'))
def test_bills(self):
# create a project
@ -1239,7 +1239,7 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
self.assertEqual("[]", req.data.decode('utf-8'))
self.assertEqual("[]\n", req.data.decode('utf-8'))
# add a bill
req = self.client.post("/api/projects/raclette/bills", data={
@ -1252,7 +1252,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id
self.assertStatus(201, req)
self.assertEqual(req.data.decode('utf-8'), "1")
self.assertEqual(req.data.decode('utf-8'), "1\n")
# get this bill details
req = self.client.get("/api/projects/raclette/bills/1",
@ -1288,7 +1288,7 @@ class APITestCase(IhatemoneyTestCase):
}, headers=self.get_auth("raclette"))
self.assertStatus(400, req)
self.assertEqual('{"date": ["This field is required."]}', req.data.decode('utf-8'))
self.assertEqual('{"date": ["This field is required."]}\n', req.data.decode('utf-8'))
# edit a bill
req = self.client.put("/api/projects/raclette/bills/1", data={

View file

@ -3,7 +3,7 @@ import re
from io import BytesIO, StringIO
from jinja2 import filters
from json import dumps
from json import dumps, JSONEncoder
from flask import redirect
from werkzeug.routing import HTTPException, RoutingException
import six
@ -170,3 +170,28 @@ class LoginThrottler():
def reset(self, ip):
self._attempts.pop(ip, None)
class IhmJSONEncoder(JSONEncoder):
"""Subclass of the default encoder to support custom objects.
Taken from the deprecated flask-rest package."""
def default(self, o):
if hasattr(o, "_to_serialize"):
# build up the object
data = {}
for attr in o._to_serialize:
data[attr] = getattr(o, attr)
return data
elif hasattr(o, "isoformat"):
return o.isoformat()
else:
try:
from flask_babel import speaklater
if isinstance(o, speaklater.LazyString):
try:
return unicode(o) # For python 2.
except NameError:
return str(o) # For python 3.
except ImportError:
pass
return JSONEncoder.default(self, o)

View file

@ -5,7 +5,7 @@ flask-mail>=0.8
Flask-Migrate>=1.8.0
Flask-script
flask-babel
flask-rest>=1.3
flask-restful>=0.3.6
jinja2>=2.6
raven
blinker