mirror of
https://github.com/spiral-project/ihatemoney.git
synced 2025-04-29 01:42:37 +02:00
API: Migrate from flask-rest to flask-restful (#315)
The flask-rest custom json encoder is still needed and thus was added to ihatemoney's utils. Closes #298
This commit is contained in:
parent
830718e1fe
commit
b93ea4830d
6 changed files with 121 additions and 89 deletions
|
@ -6,6 +6,11 @@ This document describes changes between each past release.
|
||||||
2.1 (unreleased)
|
2.1 (unreleased)
|
||||||
----------------
|
----------------
|
||||||
|
|
||||||
|
Changed
|
||||||
|
=======
|
||||||
|
|
||||||
|
- Use flask-restful instead of deprecated flask-rest for the REST API (#315)
|
||||||
|
|
||||||
Fixed
|
Fixed
|
||||||
=====
|
=====
|
||||||
|
|
||||||
|
|
|
@ -1,62 +1,68 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from flask import Blueprint, request
|
from flask import Blueprint, request
|
||||||
from flask_rest import RESTResource, need_auth
|
from flask_restful import Resource, Api, abort
|
||||||
from wtforms.fields.core import BooleanField
|
from wtforms.fields.core import BooleanField
|
||||||
|
|
||||||
from ihatemoney.models import db, Project, Person, Bill
|
from ihatemoney.models import db, Project, Person, Bill
|
||||||
from ihatemoney.forms import (ProjectForm, EditProjectForm, MemberForm,
|
from ihatemoney.forms import (ProjectForm, EditProjectForm, MemberForm,
|
||||||
get_billform_for)
|
get_billform_for)
|
||||||
from werkzeug.security import check_password_hash
|
from werkzeug.security import check_password_hash
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
|
||||||
api = Blueprint("api", __name__, url_prefix="/api")
|
api = Blueprint("api", __name__, url_prefix="/api")
|
||||||
|
restful_api = Api(api)
|
||||||
|
|
||||||
|
|
||||||
def check_project(*args, **kwargs):
|
def need_auth(f):
|
||||||
"""Check the request for basic authentication for a given project.
|
"""Check the request for basic authentication for a given project.
|
||||||
|
|
||||||
Return the project if the authorization is good, False otherwise
|
Return the project if the authorization is good, abort the request with a 401 otherwise
|
||||||
"""
|
"""
|
||||||
|
@wraps(f)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
auth = request.authorization
|
auth = request.authorization
|
||||||
|
project_id = kwargs.get("project_id")
|
||||||
|
|
||||||
# project_id should be contained in kwargs and equal to the username
|
if auth and project_id and auth.username == project_id:
|
||||||
if auth and "project_id" in kwargs and \
|
|
||||||
auth.username == kwargs["project_id"]:
|
|
||||||
project = Project.query.get(auth.username)
|
project = Project.query.get(auth.username)
|
||||||
if project and check_password_hash(project.password, auth.password):
|
if project and check_password_hash(project.password, auth.password):
|
||||||
return project
|
# The whole project object will be passed instead of project_id
|
||||||
return False
|
kwargs.pop("project_id")
|
||||||
|
return f(*args, project=project, **kwargs)
|
||||||
|
abort(401)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class ProjectHandler(object):
|
class ProjectsHandler(Resource):
|
||||||
|
def post(self):
|
||||||
def add(self):
|
|
||||||
form = ProjectForm(meta={'csrf': False})
|
form = ProjectForm(meta={'csrf': False})
|
||||||
if form.validate():
|
if form.validate():
|
||||||
project = form.save()
|
project = form.save()
|
||||||
db.session.add(project)
|
db.session.add(project)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return 201, project.id
|
return project.id, 201
|
||||||
return 400, form.errors
|
return form.errors, 400
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectHandler(Resource):
|
||||||
|
method_decorators = [need_auth]
|
||||||
|
|
||||||
@need_auth(check_project, "project")
|
|
||||||
def get(self, project):
|
def get(self, project):
|
||||||
return 200, project
|
return project
|
||||||
|
|
||||||
@need_auth(check_project, "project")
|
|
||||||
def delete(self, project):
|
def delete(self, project):
|
||||||
db.session.delete(project)
|
db.session.delete(project)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return 200, "DELETED"
|
return "DELETED"
|
||||||
|
|
||||||
@need_auth(check_project, "project")
|
def put(self, project):
|
||||||
def update(self, project):
|
|
||||||
form = EditProjectForm(meta={'csrf': False})
|
form = EditProjectForm(meta={'csrf': False})
|
||||||
if form.validate():
|
if form.validate():
|
||||||
form.update(project)
|
form.update(project)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return 200, "UPDATED"
|
return "UPDATED"
|
||||||
return 400, form.errors
|
return form.errors, 400
|
||||||
|
|
||||||
|
|
||||||
class APIMemberForm(MemberForm):
|
class APIMemberForm(MemberForm):
|
||||||
|
@ -71,98 +77,92 @@ class APIMemberForm(MemberForm):
|
||||||
return super(APIMemberForm, self).save(project, person)
|
return super(APIMemberForm, self).save(project, person)
|
||||||
|
|
||||||
|
|
||||||
class MemberHandler(object):
|
class MembersHandler(Resource):
|
||||||
|
method_decorators = [need_auth]
|
||||||
|
|
||||||
def get(self, project, member_id):
|
def get(self, project):
|
||||||
member = Person.query.get(member_id, project)
|
return project.members
|
||||||
if not member or member.project != project:
|
|
||||||
return 404, "Not Found"
|
|
||||||
return 200, member
|
|
||||||
|
|
||||||
def list(self, project):
|
def post(self, project):
|
||||||
return 200, project.members
|
|
||||||
|
|
||||||
def add(self, project):
|
|
||||||
form = MemberForm(project, meta={'csrf': False})
|
form = MemberForm(project, meta={'csrf': False})
|
||||||
if form.validate():
|
if form.validate():
|
||||||
member = Person()
|
member = Person()
|
||||||
form.save(project, member)
|
form.save(project, member)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return 201, member.id
|
return member.id, 201
|
||||||
return 400, form.errors
|
return form.errors, 400
|
||||||
|
|
||||||
def update(self, project, member_id):
|
|
||||||
|
class MemberHandler(Resource):
|
||||||
|
method_decorators = [need_auth]
|
||||||
|
|
||||||
|
def get(self, project, member_id):
|
||||||
|
member = Person.query.get(member_id, project)
|
||||||
|
if not member or member.project != project:
|
||||||
|
return "Not Found", 404
|
||||||
|
return member
|
||||||
|
|
||||||
|
def put(self, project, member_id):
|
||||||
form = APIMemberForm(project, meta={'csrf': False}, edit=True)
|
form = APIMemberForm(project, meta={'csrf': False}, edit=True)
|
||||||
if form.validate():
|
if form.validate():
|
||||||
member = Person.query.get(member_id, project)
|
member = Person.query.get(member_id, project)
|
||||||
form.save(project, member)
|
form.save(project, member)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return 200, member
|
return member
|
||||||
return 400, form.errors
|
return form.errors, 400
|
||||||
|
|
||||||
def delete(self, project, member_id):
|
def delete(self, project, member_id):
|
||||||
if project.remove_member(member_id):
|
if project.remove_member(member_id):
|
||||||
return 200, "OK"
|
return "OK"
|
||||||
return 404, "Not Found"
|
return "Not Found", 404
|
||||||
|
|
||||||
|
|
||||||
class BillHandler(object):
|
class BillsHandler(Resource):
|
||||||
|
method_decorators = [need_auth]
|
||||||
|
|
||||||
def get(self, project, bill_id):
|
def get(self, project):
|
||||||
bill = Bill.query.get(project, bill_id)
|
|
||||||
if not bill:
|
|
||||||
return 404, "Not Found"
|
|
||||||
return 200, bill
|
|
||||||
|
|
||||||
def list(self, project):
|
|
||||||
return project.get_bills().all()
|
return project.get_bills().all()
|
||||||
|
|
||||||
def add(self, project):
|
def post(self, project):
|
||||||
form = get_billform_for(project, True, meta={'csrf': False})
|
form = get_billform_for(project, True, meta={'csrf': False})
|
||||||
if form.validate():
|
if form.validate():
|
||||||
bill = Bill()
|
bill = Bill()
|
||||||
form.save(bill, project)
|
form.save(bill, project)
|
||||||
db.session.add(bill)
|
db.session.add(bill)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return 201, bill.id
|
return bill.id, 201
|
||||||
return 400, form.errors
|
return form.errors, 400
|
||||||
|
|
||||||
def update(self, project, bill_id):
|
|
||||||
|
class BillHandler(Resource):
|
||||||
|
method_decorators = [need_auth]
|
||||||
|
|
||||||
|
def get(self, project, bill_id):
|
||||||
|
bill = Bill.query.get(project, bill_id)
|
||||||
|
if not bill:
|
||||||
|
return "Not Found", 404
|
||||||
|
return bill, 200
|
||||||
|
|
||||||
|
def put(self, project, bill_id):
|
||||||
form = get_billform_for(project, True, meta={'csrf': False})
|
form = get_billform_for(project, True, meta={'csrf': False})
|
||||||
if form.validate():
|
if form.validate():
|
||||||
bill = Bill.query.get(project, bill_id)
|
bill = Bill.query.get(project, bill_id)
|
||||||
form.save(bill, project)
|
form.save(bill, project)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return 200, bill.id
|
return bill.id, 200
|
||||||
return 400, form.errors
|
return form.errors, 400
|
||||||
|
|
||||||
def delete(self, project, bill_id):
|
def delete(self, project, bill_id):
|
||||||
bill = Bill.query.delete(project, bill_id)
|
bill = Bill.query.delete(project, bill_id)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
if not bill:
|
if not bill:
|
||||||
return 404, "Not Found"
|
return "Not Found", 404
|
||||||
return 200, "OK"
|
return "OK", 200
|
||||||
|
|
||||||
|
|
||||||
project_resource = RESTResource(
|
restful_api.add_resource(ProjectsHandler, '/projects')
|
||||||
name="project",
|
restful_api.add_resource(ProjectHandler, '/projects/<string:project_id>')
|
||||||
route="/projects",
|
restful_api.add_resource(MembersHandler, "/projects/<string:project_id>/members")
|
||||||
app=api,
|
restful_api.add_resource(MemberHandler, "/projects/<string:project_id>/members/<int:member_id>")
|
||||||
actions=["add", "update", "delete", "get"],
|
restful_api.add_resource(BillsHandler, "/projects/<string:project_id>/bills")
|
||||||
handler=ProjectHandler())
|
restful_api.add_resource(BillHandler, "/projects/<string:project_id>/bills/<int:bill_id>")
|
||||||
|
|
||||||
member_resource = RESTResource(
|
|
||||||
name="member",
|
|
||||||
inject_name="project",
|
|
||||||
route="/projects/<project_id>/members",
|
|
||||||
app=api,
|
|
||||||
handler=MemberHandler(),
|
|
||||||
authentifier=check_project)
|
|
||||||
|
|
||||||
bill_resource = RESTResource(
|
|
||||||
name="bill",
|
|
||||||
inject_name="project",
|
|
||||||
route="/projects/<project_id>/bills",
|
|
||||||
app=api,
|
|
||||||
handler=BillHandler(),
|
|
||||||
authentifier=check_project)
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ from werkzeug.contrib.fixers import ProxyFix
|
||||||
|
|
||||||
from ihatemoney.api import api
|
from ihatemoney.api import api
|
||||||
from ihatemoney.models import db
|
from ihatemoney.models import db
|
||||||
from ihatemoney.utils import PrefixedWSGI, minimal_round
|
from ihatemoney.utils import PrefixedWSGI, minimal_round, IhmJSONEncoder
|
||||||
from ihatemoney.web import main as web_interface
|
from ihatemoney.web import main as web_interface
|
||||||
|
|
||||||
from ihatemoney import default_settings
|
from ihatemoney import default_settings
|
||||||
|
@ -68,6 +68,8 @@ def load_configuration(app, configuration=None):
|
||||||
app.config.from_pyfile(env_var_config)
|
app.config.from_pyfile(env_var_config)
|
||||||
else:
|
else:
|
||||||
app.config.from_pyfile('ihatemoney.cfg', silent=True)
|
app.config.from_pyfile('ihatemoney.cfg', silent=True)
|
||||||
|
# Configure custom JSONEncoder used by the API
|
||||||
|
app.config['RESTFUL_JSON'] = {'cls': IhmJSONEncoder}
|
||||||
|
|
||||||
|
|
||||||
def validate_configuration(app):
|
def validate_configuration(app):
|
||||||
|
|
|
@ -1053,7 +1053,7 @@ class APITestCase(IhatemoneyTestCase):
|
||||||
})
|
})
|
||||||
|
|
||||||
self.assertTrue(400, resp.status_code)
|
self.assertTrue(400, resp.status_code)
|
||||||
self.assertEqual('{"contact_email": ["Invalid email address."]}',
|
self.assertEqual('{"contact_email": ["Invalid email address."]}\n',
|
||||||
resp.data.decode('utf-8'))
|
resp.data.decode('utf-8'))
|
||||||
|
|
||||||
# create it
|
# create it
|
||||||
|
@ -1139,7 +1139,7 @@ class APITestCase(IhatemoneyTestCase):
|
||||||
headers=self.get_auth("raclette"))
|
headers=self.get_auth("raclette"))
|
||||||
|
|
||||||
self.assertStatus(200, req)
|
self.assertStatus(200, req)
|
||||||
self.assertEqual('[]', req.data.decode('utf-8'))
|
self.assertEqual('[]\n', req.data.decode('utf-8'))
|
||||||
|
|
||||||
# add a member
|
# add a member
|
||||||
req = self.client.post("/api/projects/raclette/members", data={
|
req = self.client.post("/api/projects/raclette/members", data={
|
||||||
|
@ -1148,7 +1148,7 @@ class APITestCase(IhatemoneyTestCase):
|
||||||
|
|
||||||
# the id of the new member should be returned
|
# the id of the new member should be returned
|
||||||
self.assertStatus(201, req)
|
self.assertStatus(201, req)
|
||||||
self.assertEqual("1", req.data.decode('utf-8'))
|
self.assertEqual("1\n", req.data.decode('utf-8'))
|
||||||
|
|
||||||
# the list of members should contain one member
|
# the list of members should contain one member
|
||||||
req = self.client.get("/api/projects/raclette/members",
|
req = self.client.get("/api/projects/raclette/members",
|
||||||
|
@ -1223,7 +1223,7 @@ class APITestCase(IhatemoneyTestCase):
|
||||||
headers=self.get_auth("raclette"))
|
headers=self.get_auth("raclette"))
|
||||||
|
|
||||||
self.assertStatus(200, req)
|
self.assertStatus(200, req)
|
||||||
self.assertEqual('[]', req.data.decode('utf-8'))
|
self.assertEqual('[]\n', req.data.decode('utf-8'))
|
||||||
|
|
||||||
def test_bills(self):
|
def test_bills(self):
|
||||||
# create a project
|
# create a project
|
||||||
|
@ -1239,7 +1239,7 @@ class APITestCase(IhatemoneyTestCase):
|
||||||
headers=self.get_auth("raclette"))
|
headers=self.get_auth("raclette"))
|
||||||
self.assertStatus(200, req)
|
self.assertStatus(200, req)
|
||||||
|
|
||||||
self.assertEqual("[]", req.data.decode('utf-8'))
|
self.assertEqual("[]\n", req.data.decode('utf-8'))
|
||||||
|
|
||||||
# add a bill
|
# add a bill
|
||||||
req = self.client.post("/api/projects/raclette/bills", data={
|
req = self.client.post("/api/projects/raclette/bills", data={
|
||||||
|
@ -1252,7 +1252,7 @@ class APITestCase(IhatemoneyTestCase):
|
||||||
|
|
||||||
# should return the id
|
# should return the id
|
||||||
self.assertStatus(201, req)
|
self.assertStatus(201, req)
|
||||||
self.assertEqual(req.data.decode('utf-8'), "1")
|
self.assertEqual(req.data.decode('utf-8'), "1\n")
|
||||||
|
|
||||||
# get this bill details
|
# get this bill details
|
||||||
req = self.client.get("/api/projects/raclette/bills/1",
|
req = self.client.get("/api/projects/raclette/bills/1",
|
||||||
|
@ -1288,7 +1288,7 @@ class APITestCase(IhatemoneyTestCase):
|
||||||
}, headers=self.get_auth("raclette"))
|
}, headers=self.get_auth("raclette"))
|
||||||
|
|
||||||
self.assertStatus(400, req)
|
self.assertStatus(400, req)
|
||||||
self.assertEqual('{"date": ["This field is required."]}', req.data.decode('utf-8'))
|
self.assertEqual('{"date": ["This field is required."]}\n', req.data.decode('utf-8'))
|
||||||
|
|
||||||
# edit a bill
|
# edit a bill
|
||||||
req = self.client.put("/api/projects/raclette/bills/1", data={
|
req = self.client.put("/api/projects/raclette/bills/1", data={
|
||||||
|
|
|
@ -3,7 +3,7 @@ import re
|
||||||
|
|
||||||
from io import BytesIO, StringIO
|
from io import BytesIO, StringIO
|
||||||
from jinja2 import filters
|
from jinja2 import filters
|
||||||
from json import dumps
|
from json import dumps, JSONEncoder
|
||||||
from flask import redirect
|
from flask import redirect
|
||||||
from werkzeug.routing import HTTPException, RoutingException
|
from werkzeug.routing import HTTPException, RoutingException
|
||||||
import six
|
import six
|
||||||
|
@ -170,3 +170,28 @@ class LoginThrottler():
|
||||||
|
|
||||||
def reset(self, ip):
|
def reset(self, ip):
|
||||||
self._attempts.pop(ip, None)
|
self._attempts.pop(ip, None)
|
||||||
|
|
||||||
|
|
||||||
|
class IhmJSONEncoder(JSONEncoder):
|
||||||
|
"""Subclass of the default encoder to support custom objects.
|
||||||
|
Taken from the deprecated flask-rest package."""
|
||||||
|
def default(self, o):
|
||||||
|
if hasattr(o, "_to_serialize"):
|
||||||
|
# build up the object
|
||||||
|
data = {}
|
||||||
|
for attr in o._to_serialize:
|
||||||
|
data[attr] = getattr(o, attr)
|
||||||
|
return data
|
||||||
|
elif hasattr(o, "isoformat"):
|
||||||
|
return o.isoformat()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from flask_babel import speaklater
|
||||||
|
if isinstance(o, speaklater.LazyString):
|
||||||
|
try:
|
||||||
|
return unicode(o) # For python 2.
|
||||||
|
except NameError:
|
||||||
|
return str(o) # For python 3.
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
return JSONEncoder.default(self, o)
|
||||||
|
|
|
@ -5,7 +5,7 @@ flask-mail>=0.8
|
||||||
Flask-Migrate>=1.8.0
|
Flask-Migrate>=1.8.0
|
||||||
Flask-script
|
Flask-script
|
||||||
flask-babel
|
flask-babel
|
||||||
flask-rest>=1.3
|
flask-restful>=0.3.6
|
||||||
jinja2>=2.6
|
jinja2>=2.6
|
||||||
raven
|
raven
|
||||||
blinker
|
blinker
|
||||||
|
|
Loading…
Reference in a new issue