From b93ea4830d5290def99d597f17292a8aa5d4c090 Mon Sep 17 00:00:00 2001 From: 0livd Date: Thu, 25 Jan 2018 17:41:28 +0100 Subject: [PATCH] API: Migrate from flask-rest to flask-restful (#315) The flask-rest custom json encoder is still needed and thus was added to ihatemoney's utils. Closes #298 --- CHANGELOG.rst | 5 ++ ihatemoney/api.py | 158 +++++++++++++++++++------------------- ihatemoney/run.py | 4 +- ihatemoney/tests/tests.py | 14 ++-- ihatemoney/utils.py | 27 ++++++- requirements.txt | 2 +- 6 files changed, 121 insertions(+), 89 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index efd0648b..ea2b61e5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,11 @@ This document describes changes between each past release. 2.1 (unreleased) ---------------- +Changed +======= + +- Use flask-restful instead of deprecated flask-rest for the REST API (#315) + Fixed ===== diff --git a/ihatemoney/api.py b/ihatemoney/api.py index 827202c8..31ed06cc 100644 --- a/ihatemoney/api.py +++ b/ihatemoney/api.py @@ -1,62 +1,68 @@ # -*- coding: utf-8 -*- from flask import Blueprint, request -from flask_rest import RESTResource, need_auth +from flask_restful import Resource, Api, abort from wtforms.fields.core import BooleanField from ihatemoney.models import db, Project, Person, Bill from ihatemoney.forms import (ProjectForm, EditProjectForm, MemberForm, get_billform_for) from werkzeug.security import check_password_hash +from functools import wraps api = Blueprint("api", __name__, url_prefix="/api") +restful_api = Api(api) -def check_project(*args, **kwargs): +def need_auth(f): """Check the request for basic authentication for a given project. - Return the project if the authorization is good, False otherwise + Return the project if the authorization is good, abort the request with a 401 otherwise """ - auth = request.authorization + @wraps(f) + def wrapper(*args, **kwargs): + auth = request.authorization + project_id = kwargs.get("project_id") - # project_id should be contained in kwargs and equal to the username - if auth and "project_id" in kwargs and \ - auth.username == kwargs["project_id"]: - project = Project.query.get(auth.username) - if project and check_password_hash(project.password, auth.password): - return project - return False + if auth and project_id and auth.username == project_id: + project = Project.query.get(auth.username) + if project and check_password_hash(project.password, auth.password): + # The whole project object will be passed instead of project_id + kwargs.pop("project_id") + return f(*args, project=project, **kwargs) + abort(401) + return wrapper -class ProjectHandler(object): - - def add(self): +class ProjectsHandler(Resource): + def post(self): form = ProjectForm(meta={'csrf': False}) if form.validate(): project = form.save() db.session.add(project) db.session.commit() - return 201, project.id - return 400, form.errors + return project.id, 201 + return form.errors, 400 + + +class ProjectHandler(Resource): + method_decorators = [need_auth] - @need_auth(check_project, "project") def get(self, project): - return 200, project + return project - @need_auth(check_project, "project") def delete(self, project): db.session.delete(project) db.session.commit() - return 200, "DELETED" + return "DELETED" - @need_auth(check_project, "project") - def update(self, project): + def put(self, project): form = EditProjectForm(meta={'csrf': False}) if form.validate(): form.update(project) db.session.commit() - return 200, "UPDATED" - return 400, form.errors + return "UPDATED" + return form.errors, 400 class APIMemberForm(MemberForm): @@ -71,98 +77,92 @@ class APIMemberForm(MemberForm): return super(APIMemberForm, self).save(project, person) -class MemberHandler(object): +class MembersHandler(Resource): + method_decorators = [need_auth] - def get(self, project, member_id): - member = Person.query.get(member_id, project) - if not member or member.project != project: - return 404, "Not Found" - return 200, member + def get(self, project): + return project.members - def list(self, project): - return 200, project.members - - def add(self, project): + def post(self, project): form = MemberForm(project, meta={'csrf': False}) if form.validate(): member = Person() form.save(project, member) db.session.commit() - return 201, member.id - return 400, form.errors + return member.id, 201 + return form.errors, 400 - def update(self, project, member_id): + +class MemberHandler(Resource): + method_decorators = [need_auth] + + def get(self, project, member_id): + member = Person.query.get(member_id, project) + if not member or member.project != project: + return "Not Found", 404 + return member + + def put(self, project, member_id): form = APIMemberForm(project, meta={'csrf': False}, edit=True) if form.validate(): member = Person.query.get(member_id, project) form.save(project, member) db.session.commit() - return 200, member - return 400, form.errors + return member + return form.errors, 400 def delete(self, project, member_id): if project.remove_member(member_id): - return 200, "OK" - return 404, "Not Found" + return "OK" + return "Not Found", 404 -class BillHandler(object): +class BillsHandler(Resource): + method_decorators = [need_auth] - def get(self, project, bill_id): - bill = Bill.query.get(project, bill_id) - if not bill: - return 404, "Not Found" - return 200, bill - - def list(self, project): + def get(self, project): return project.get_bills().all() - def add(self, project): + def post(self, project): form = get_billform_for(project, True, meta={'csrf': False}) if form.validate(): bill = Bill() form.save(bill, project) db.session.add(bill) db.session.commit() - return 201, bill.id - return 400, form.errors + return bill.id, 201 + return form.errors, 400 - def update(self, project, bill_id): + +class BillHandler(Resource): + method_decorators = [need_auth] + + def get(self, project, bill_id): + bill = Bill.query.get(project, bill_id) + if not bill: + return "Not Found", 404 + return bill, 200 + + def put(self, project, bill_id): form = get_billform_for(project, True, meta={'csrf': False}) if form.validate(): bill = Bill.query.get(project, bill_id) form.save(bill, project) db.session.commit() - return 200, bill.id - return 400, form.errors + return bill.id, 200 + return form.errors, 400 def delete(self, project, bill_id): bill = Bill.query.delete(project, bill_id) db.session.commit() if not bill: - return 404, "Not Found" - return 200, "OK" + return "Not Found", 404 + return "OK", 200 -project_resource = RESTResource( - name="project", - route="/projects", - app=api, - actions=["add", "update", "delete", "get"], - handler=ProjectHandler()) - -member_resource = RESTResource( - name="member", - inject_name="project", - route="/projects//members", - app=api, - handler=MemberHandler(), - authentifier=check_project) - -bill_resource = RESTResource( - name="bill", - inject_name="project", - route="/projects//bills", - app=api, - handler=BillHandler(), - authentifier=check_project) +restful_api.add_resource(ProjectsHandler, '/projects') +restful_api.add_resource(ProjectHandler, '/projects/') +restful_api.add_resource(MembersHandler, "/projects//members") +restful_api.add_resource(MemberHandler, "/projects//members/") +restful_api.add_resource(BillsHandler, "/projects//bills") +restful_api.add_resource(BillHandler, "/projects//bills/") diff --git a/ihatemoney/run.py b/ihatemoney/run.py index e3a7c1e5..a8de26f0 100644 --- a/ihatemoney/run.py +++ b/ihatemoney/run.py @@ -11,7 +11,7 @@ from werkzeug.contrib.fixers import ProxyFix from ihatemoney.api import api from ihatemoney.models import db -from ihatemoney.utils import PrefixedWSGI, minimal_round +from ihatemoney.utils import PrefixedWSGI, minimal_round, IhmJSONEncoder from ihatemoney.web import main as web_interface from ihatemoney import default_settings @@ -68,6 +68,8 @@ def load_configuration(app, configuration=None): app.config.from_pyfile(env_var_config) else: app.config.from_pyfile('ihatemoney.cfg', silent=True) + # Configure custom JSONEncoder used by the API + app.config['RESTFUL_JSON'] = {'cls': IhmJSONEncoder} def validate_configuration(app): diff --git a/ihatemoney/tests/tests.py b/ihatemoney/tests/tests.py index d4b6d7a1..c13131c4 100644 --- a/ihatemoney/tests/tests.py +++ b/ihatemoney/tests/tests.py @@ -1053,7 +1053,7 @@ class APITestCase(IhatemoneyTestCase): }) self.assertTrue(400, resp.status_code) - self.assertEqual('{"contact_email": ["Invalid email address."]}', + self.assertEqual('{"contact_email": ["Invalid email address."]}\n', resp.data.decode('utf-8')) # create it @@ -1139,7 +1139,7 @@ class APITestCase(IhatemoneyTestCase): headers=self.get_auth("raclette")) self.assertStatus(200, req) - self.assertEqual('[]', req.data.decode('utf-8')) + self.assertEqual('[]\n', req.data.decode('utf-8')) # add a member req = self.client.post("/api/projects/raclette/members", data={ @@ -1148,7 +1148,7 @@ class APITestCase(IhatemoneyTestCase): # the id of the new member should be returned self.assertStatus(201, req) - self.assertEqual("1", req.data.decode('utf-8')) + self.assertEqual("1\n", req.data.decode('utf-8')) # the list of members should contain one member req = self.client.get("/api/projects/raclette/members", @@ -1223,7 +1223,7 @@ class APITestCase(IhatemoneyTestCase): headers=self.get_auth("raclette")) self.assertStatus(200, req) - self.assertEqual('[]', req.data.decode('utf-8')) + self.assertEqual('[]\n', req.data.decode('utf-8')) def test_bills(self): # create a project @@ -1239,7 +1239,7 @@ class APITestCase(IhatemoneyTestCase): headers=self.get_auth("raclette")) self.assertStatus(200, req) - self.assertEqual("[]", req.data.decode('utf-8')) + self.assertEqual("[]\n", req.data.decode('utf-8')) # add a bill req = self.client.post("/api/projects/raclette/bills", data={ @@ -1252,7 +1252,7 @@ class APITestCase(IhatemoneyTestCase): # should return the id self.assertStatus(201, req) - self.assertEqual(req.data.decode('utf-8'), "1") + self.assertEqual(req.data.decode('utf-8'), "1\n") # get this bill details req = self.client.get("/api/projects/raclette/bills/1", @@ -1288,7 +1288,7 @@ class APITestCase(IhatemoneyTestCase): }, headers=self.get_auth("raclette")) self.assertStatus(400, req) - self.assertEqual('{"date": ["This field is required."]}', req.data.decode('utf-8')) + self.assertEqual('{"date": ["This field is required."]}\n', req.data.decode('utf-8')) # edit a bill req = self.client.put("/api/projects/raclette/bills/1", data={ diff --git a/ihatemoney/utils.py b/ihatemoney/utils.py index 6af0112c..aaae2a08 100644 --- a/ihatemoney/utils.py +++ b/ihatemoney/utils.py @@ -3,7 +3,7 @@ import re from io import BytesIO, StringIO from jinja2 import filters -from json import dumps +from json import dumps, JSONEncoder from flask import redirect from werkzeug.routing import HTTPException, RoutingException import six @@ -170,3 +170,28 @@ class LoginThrottler(): def reset(self, ip): self._attempts.pop(ip, None) + + +class IhmJSONEncoder(JSONEncoder): + """Subclass of the default encoder to support custom objects. + Taken from the deprecated flask-rest package.""" + def default(self, o): + if hasattr(o, "_to_serialize"): + # build up the object + data = {} + for attr in o._to_serialize: + data[attr] = getattr(o, attr) + return data + elif hasattr(o, "isoformat"): + return o.isoformat() + else: + try: + from flask_babel import speaklater + if isinstance(o, speaklater.LazyString): + try: + return unicode(o) # For python 2. + except NameError: + return str(o) # For python 3. + except ImportError: + pass + return JSONEncoder.default(self, o) diff --git a/requirements.txt b/requirements.txt index 64610abd..c2fe5348 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ flask-mail>=0.8 Flask-Migrate>=1.8.0 Flask-script flask-babel -flask-rest>=1.3 +flask-restful>=0.3.6 jinja2>=2.6 raven blinker