diff --git a/budget/api.py b/budget/api.py index 3df8ab2b..c5ae76b4 100644 --- a/budget/api.py +++ b/budget/api.py @@ -2,8 +2,8 @@ from flask import * from models import db, Project, Person, Bill -from forms import ProjectForm -from utils import for_all_methods +from forms import ProjectForm, EditProjectForm, MemberForm, BillForm +from utils import for_all_methods, get_billform_for from rest import RESTResource, need_auth# FIXME make it an ext from werkzeug import Response @@ -32,7 +32,7 @@ class ProjectHandler(object): def add(self): form = ProjectForm(csrf_enabled=False) if form.validate(): - project = form.save(Project()) + project = form.save() db.session.add(project) db.session.commit() return 201, project.id @@ -40,7 +40,7 @@ class ProjectHandler(object): @need_auth(check_project, "project") def get(self, project): - return project + return 200, project @need_auth(check_project, "project") def delete(self, project): @@ -50,9 +50,9 @@ class ProjectHandler(object): @need_auth(check_project, "project") def update(self, project): - form = ProjectForm(csrf_enabled=False) + form = EditProjectForm(csrf_enabled=False) if form.validate(): - form.save(project) + form.update(project) db.session.commit() return 200, "UPDATED" return 400, form.errors @@ -61,25 +61,25 @@ class ProjectHandler(object): class MemberHandler(object): def get(self, project, member_id): - member = Person.query.get(member_id) + member = Person.query.get(member_id, project) if not member or member.project != project: return 404, "Not Found" - return member + return 200, member def list(self, project): - return project.members + return 200, project.members def add(self, project): - form = MemberForm(csrf_enabled=False) + form = MemberForm(project, csrf_enabled=False) if form.validate(): member = Person() form.save(project, member) db.session.commit() - return 200, member.id + return 201, member.id return 400, form.errors def update(self, project, member_id): - form = MemberForm(csrf_enabled=False) + form = MemberForm(project, csrf_enabled=False) if form.validate(): member = Person.query.get(member_id, project) form.save(project, member) @@ -99,39 +99,41 @@ class BillHandler(object): bill = Bill.query.get(project, bill_id) if not bill: return 404, "Not Found" - return bill + return 200, bill def list(self, project): return project.get_bills().all() def add(self, project): - form = BillForm(csrf_enabled=False) + form = get_billform_for(project, True, csrf_enabled=False) if form.validate(): bill = Bill() - form.save(bill) + form.save(bill, project) db.session.add(bill) db.session.commit() - return 200, bill.id + return 201, bill.id return 400, form.errors def update(self, project, bill_id): - form = BillForm(csrf_enabled=False) + form = get_billform_for(project, True, csrf_enabled=False) if form.validate(): - form.save(bill) + bill = Bill.query.get(project, bill_id) + form.save(bill, project) db.session.commit() return 200, bill.id return 400, form.errors def delete(self, project, bill_id): bill = Bill.query.delete(project, bill_id) + db.session.commit() if not bill: return 404, "Not Found" - return bill + return 200, "OK" project_resource = RESTResource( name="project", - route="/project", + route="/projects", app=api, actions=["add", "update", "delete", "get"], handler=ProjectHandler()) @@ -139,7 +141,7 @@ project_resource = RESTResource( member_resource = RESTResource( name="member", inject_name="project", - route="/project//members", + route="/projects//members", app=api, handler=MemberHandler(), authentifier=check_project) @@ -147,7 +149,7 @@ member_resource = RESTResource( bill_resource = RESTResource( name="bill", inject_name="project", - route="/project//bills", + route="/projects//bills", app=api, handler=BillHandler(), authentifier=check_project) diff --git a/budget/forms.py b/budget/forms.py index 16fa0d61..25731bc6 100644 --- a/budget/forms.py +++ b/budget/forms.py @@ -95,12 +95,13 @@ class BillForm(Form): validators=[Required()], widget=select_multi_checkbox) submit = SubmitField("Send the bill") - def save(self, bill): + def save(self, bill, project): bill.payer_id=self.payer.data bill.amount=self.amount.data bill.what=self.what.data bill.date=self.date.data - bill.owers = [Person.query.get(ower) for ower in self.payed_for.data] + bill.owers = [Person.query.get(ower, project) + for ower in self.payed_for.data] return bill diff --git a/budget/models.py b/budget/models.py index 5ee7b07e..c938e978 100644 --- a/budget/models.py +++ b/budget/models.py @@ -60,14 +60,13 @@ class Project(db.Model): This method returns the status DELETED or DEACTIVATED regarding the changes made. """ - person = Person.query.get_or_404(member_id) - if person.project == self: - if not person.has_bills(): - db.session.delete(person) - db.session.commit() - else: - person.activated = False - db.session.commit() + person = Person.query.get(member_id, self) + if not person.has_bills(): + db.session.delete(person) + db.session.commit() + else: + person.activated = False + db.session.commit() return person def __repr__(self): diff --git a/budget/rest.py b/budget/rest.py index f2372173..992a61e9 100644 --- a/budget/rest.py +++ b/budget/rest.py @@ -90,7 +90,7 @@ def need_auth(authentifier, name=None, remove_attr=True): If the request is authorized, the object returned by the authentifier is added to the kwargs of the method. - If not, issue a 403 Forbidden error + If not, issue a 401 Unauthorized error :authentifier: The callable to check the context onto. @@ -114,7 +114,7 @@ def need_auth(authentifier, name=None, remove_attr=True): del kwargs["%s_id" % name] return func(*args, **kwargs) else: - return 403, "Forbidden" + return 401, "Unauthorized" return wrapped return wrapper @@ -126,7 +126,7 @@ def serialize(func): """ def wrapped(*args, **kwargs): # get the mimetype - mime = request.accept_mimetypes.best_match(SERIALIZERS.keys()) + mime = request.accept_mimetypes.best_match(SERIALIZERS.keys()) or "text/json" data = func(*args, **kwargs) serializer = SERIALIZERS[mime] diff --git a/budget/tests.py b/budget/tests.py index 4bb8e603..1541cbb4 100644 --- a/budget/tests.py +++ b/budget/tests.py @@ -2,6 +2,8 @@ import os import tempfile import unittest +import base64 +import json from flask import session @@ -13,7 +15,7 @@ class TestCase(unittest.TestCase): def setUp(self): run.app.config['TESTING'] = True - run.app.config['SQLALCHEMY_DATABASE_URI'] = "sqlite:///memory" + run.app.config['SQLALCHEMY_DATABASE_URI'] = "sqlite:///memory" run.app.config['CSRF_ENABLED'] = False # simplify the tests self.app = run.app.test_client() @@ -45,7 +47,7 @@ class TestCase(unittest.TestCase): }) def create_project(self, name): - models.db.session.add(models.Project(id=name, name=unicode(name), + models.db.session.add(models.Project(id=name, name=unicode(name), password=name, contact_email="%s@notmyidea.org" % name)) models.db.session.commit() @@ -76,7 +78,7 @@ class BudgetTestCase(TestCase): # only one message is sent to multiple persons self.assertEqual(len(outbox), 1) - self.assertEqual(outbox[0].recipients, + self.assertEqual(outbox[0].recipients, ["alexis@notmyidea.org", "toto@notmyidea.org"]) # mail address checking @@ -107,7 +109,7 @@ class BudgetTestCase(TestCase): # session is updated self.assertEqual(session['raclette'], 'party') - + # project is created self.assertEqual(len(models.Project.query.all()), 1) @@ -144,9 +146,9 @@ class BudgetTestCase(TestCase): # check fred is present in the bills page result = self.app.get("/raclette/") self.assertIn("fred", result.data) - + # remove fred - self.app.post("/raclette/members/%s/delete" % + self.app.post("/raclette/members/%s/delete" % models.Project.query.get("raclette").members[-1].id) # as fred is not bound to any bill, he is removed @@ -186,7 +188,7 @@ class BudgetTestCase(TestCase): self.assertEqual( len(models.Project.query.get("raclette").active_members), 2) - # adding an user with the same name as another user from a different + # adding an user with the same name as another user from a different # project should not cause any troubles self.post_project("randomid") self.login("randomid") @@ -198,7 +200,7 @@ class BudgetTestCase(TestCase): def test_demo(self): # Test that it is possible to connect automatically by going onto /demo with run.app.test_client() as c: - models.db.session.add(models.Project(id="demo", name=u"demonstration", + models.db.session.add(models.Project(id="demo", name=u"demonstration", password="demo", contact_email="demo@notmyidea.org")) models.db.session.commit() c.get("/demo") @@ -216,14 +218,14 @@ class BudgetTestCase(TestCase): # raclette that the login / logout process works self.create_project("raclette") - # try to see the project while not being authenticated should redirect + # try to see the project while not being authenticated should redirect # to the authentication page resp = self.app.post("/raclette", follow_redirects=True) self.assertIn("Authentication", resp.data) - + # try to connect with wrong credentials should not work with run.app.test_client() as c: - resp = c.post("/authenticate", + resp = c.post("/authenticate", data={'id': 'raclette', 'password': 'nope'}) self.assertIn("Authentication", resp.data) @@ -231,7 +233,7 @@ class BudgetTestCase(TestCase): # try to connect with the right credentials should work with run.app.test_client() as c: - resp = c.post("/authenticate", + resp = c.post("/authenticate", data={'id': 'raclette', 'password': 'raclette'}) self.assertNotIn("Authentication", resp.data) @@ -250,7 +252,7 @@ class BudgetTestCase(TestCase): self.app.post("/raclette/members/add", data={'name': 'fred' }) members_ids = [m.id for m in models.Project.query.get("raclette").members] - + # create a bill self.app.post("/raclette/add", data={ 'date': '2011-08-10', @@ -317,7 +319,7 @@ class BudgetTestCase(TestCase): 'password': 'didoudida' } - resp = self.app.post("/raclette/edit", data=new_data, + resp = self.app.post("/raclette/edit", data=new_data, follow_redirects=True) self.assertEqual(resp.status_code, 200) project = models.Project.query.get("raclette") @@ -333,5 +335,287 @@ class BudgetTestCase(TestCase): self.assertIn("Invalid email address", resp.data) +class APITestCase(TestCase): + """Tests the API""" + + def api_create(self, name, id=None, password=None, contact=None): + id = id or name + password = password or name + contact = contact or "%s@notmyidea.org" % name + + return self.app.post("/api/projects", data={ + 'name': name, + 'id': id, + 'password': password, + 'contact_email': contact + }) + + def api_add_member(self, project, name): + self.app.post("/api/projects/%s/members" % project, + data={"name": name}, headers=self.get_auth(project)) + + def get_auth(self, username, password=None): + password = password or username + base64string = base64.encodestring( + '%s:%s' % (username, password)).replace('\n', '') + return {"Authorization": "Basic %s" % base64string} + + def assertStatus(self, expected, resp, url=""): + + return self.assertEqual(expected, resp.status_code, + "%s expected %s, got %s" % (url, expected, resp.status_code)) + + def test_basic_auth(self): + # create a project + resp = self.api_create("raclette") + self.assertStatus(201, resp) + + # try to do something on it being unauth should return a 401 + resp = self.app.get("/api/projects/raclette") + self.assertStatus(401, resp) + + # PUT / POST / DELETE / GET on the different resources + # should also return a 401 + for verb in ('post',): + for resource in ("/raclette/members", "/raclette/bills"): + url = "/api/projects" + resource + self.assertStatus(401, getattr(self.app, verb)(url), + verb + resource) + + for verb in ('get', 'delete', 'put'): + for resource in ("/raclette", "/raclette/members/1", + "/raclette/bills/1"): + url = "/api/projects" + resource + + self.assertStatus(401, getattr(self.app, verb)(url), + verb + resource) + + def test_project(self): + # wrong email should return an error + resp = self.app.post("/api/projects", data={ + 'name': "raclette", + 'id': "raclette", + 'password': "raclette", + 'contact_email': "not-an-email" + }) + + self.assertTrue(400, resp.status_code) + self.assertEqual('{"contact_email": ["Invalid email address."]}', resp.data) + + # create it + resp = self.api_create("raclette") + self.assertTrue(201, resp.status_code) + + # create it twice should return a 400 + resp = self.api_create("raclette") + + self.assertTrue(400, resp.status_code) + self.assertEqual('{"id": ["This project id is already used"]}', resp.data) + # get information about it + resp = self.app.get("/api/projects/raclette", + headers=self.get_auth("raclette")) + + self.assertTrue(200, resp.status_code) + expected = { + "active_members": [], + "name": "raclette", + "contact_email": "raclette@notmyidea.org", + "members": [], + "password": "raclette", + "id": "raclette" + } + self.assertDictEqual(json.loads(resp.data), expected) + + # edit should work + resp = self.app.put("/api/projects/raclette", data = { + "contact_email": "yeah@notmyidea.org", + "password": "raclette", + "name": "The raclette party", + }, headers=self.get_auth("raclette")) + + self.assertEqual(200, resp.status_code) + + resp = self.app.get("/api/projects/raclette", + headers=self.get_auth("raclette")) + + self.assertEqual(200, resp.status_code) + expected = { + "active_members": [], + "name": "The raclette party", + "contact_email": "yeah@notmyidea.org", + "members": [], + "password": "raclette", + "id": "raclette" + } + self.assertDictEqual(json.loads(resp.data), expected) + + # delete should work + resp = self.app.delete("/api/projects/raclette", + headers=self.get_auth("raclette")) + + self.assertEqual(200, resp.status_code) + + # get should return a 401 on an unknown resource + resp = self.app.get("/api/projects/raclette", + headers=self.get_auth("raclette")) + self.assertEqual(401, resp.status_code) + + def test_member(self): + # create a project + self.api_create("raclette") + + # get the list of members (should be empty) + req = self.app.get("/api/projects/raclette/members", + headers=self.get_auth("raclette")) + + self.assertStatus(200, req) + self.assertEqual('[]', req.data) + + # add a member + req = self.app.post("/api/projects/raclette/members", data={ + "name": "Alexis" + }, headers=self.get_auth("raclette")) + + # the id of the new member should be returned + self.assertStatus(201, req) + self.assertEqual("1", req.data) + + # the list of members should contain one member + req = self.app.get("/api/projects/raclette/members", + headers=self.get_auth("raclette")) + + self.assertStatus(200, req) + self.assertEqual(len(json.loads(req.data)), 1) + + # edit this member + req = self.app.put("/api/projects/raclette/members/1", data={ + "name": "Fred" + }, headers=self.get_auth("raclette")) + + self.assertStatus(200, req) + + # get should return the new name + req = self.app.get("/api/projects/raclette/members/1", + headers=self.get_auth("raclette")) + + self.assertStatus(200, req) + self.assertEqual("Fred", json.loads(req.data)["name"]) + + # delete a member + + req = self.app.delete("/api/projects/raclette/members/1", + headers=self.get_auth("raclette")) + + self.assertStatus(200, req) + + # the list of members should be empty + # get the list of members (should be empty) + req = self.app.get("/api/projects/raclette/members", + headers=self.get_auth("raclette")) + + self.assertStatus(200, req) + self.assertEqual('[]', req.data) + + def test_bills(self): + # create a project + self.api_create("raclette") + + # add members + self.api_add_member("raclette", "alexis") + self.api_add_member("raclette", "fred") + self.api_add_member("raclette", "arnaud") + + # get the list of bills (should be empty) + req = self.app.get("/api/projects/raclette/bills", + headers=self.get_auth("raclette")) + self.assertStatus(200, req) + + self.assertEqual("[]", req.data) + + # add a bill + req = self.app.post("/api/projects/raclette/bills", data={ + 'date': '2011-08-10', + 'what': u'fromage', + 'payer': "1", + 'payed_for': ["1", "2"], + 'amount': '25', + }, headers=self.get_auth("raclette")) + + # should return the id + self.assertStatus(201, req) + self.assertEqual(req.data, "1") + + # get this bill details + req = self.app.get("/api/projects/raclette/bills/1", + headers=self.get_auth("raclette")) + + # compare with the added info + self.assertStatus(200, req) + expected = { + "what": "fromage", + "payer_id": 1, + "owers": [ + {"activated": True, "id": 1, "name": "alexis"}, + {"activated": True, "id": 2, "name": "fred"}], + "amount": 25.0, + "date": "2011-08-10", + "id": 1} + + self.assertDictEqual(expected, json.loads(req.data)) + + # the list of bills should lenght 1 + req = self.app.get("/api/projects/raclette/bills", + headers=self.get_auth("raclette")) + self.assertStatus(200, req) + self.assertEqual(1, len(json.loads(req.data))) + + # edit with errors should return an error + req = self.app.put("/api/projects/raclette/bills/1", data={ + 'date': '201111111-08-10', # not a date + 'what': u'fromage', + 'payer': "1", + 'payed_for': ["1", "2"], + 'amount': '25', + }, headers=self.get_auth("raclette")) + + self.assertStatus(400, req) + self.assertEqual('{"date": ["This field is required."]}', req.data) + + # edit a bill + req = self.app.put("/api/projects/raclette/bills/1", data={ + 'date': '2011-09-10', + 'what': u'beer', + 'payer': "2", + 'payed_for': ["1", "2"], + 'amount': '25', + }, headers=self.get_auth("raclette")) + + # check its fields + req = self.app.get("/api/projects/raclette/bills/1", + headers=self.get_auth("raclette")) + + expected = { + "what": "beer", + "payer_id": 2, + "owers": [ + {"activated": True, "id": 1, "name": "alexis"}, + {"activated": True, "id": 2, "name": "fred"}], + "amount": 25.0, + "date": "2011-09-10", + "id": 1} + + self.assertDictEqual(expected, json.loads(req.data)) + + # delete a bill + req = self.app.delete("/api/projects/raclette/bills/1", + headers=self.get_auth("raclette")) + self.assertStatus(200, req) + + # getting it should return a 404 + req = self.app.get("/api/projects/raclette/bills/1", + headers=self.get_auth("raclette")) + self.assertStatus(404, req) + + if __name__ == "__main__": unittest.main() diff --git a/budget/utils.py b/budget/utils.py index 88b85803..df165b50 100644 --- a/budget/utils.py +++ b/budget/utils.py @@ -17,6 +17,22 @@ def slugify(value): value = unicode(re.sub('[^\w\s-]', '', value).strip().lower()) return re.sub('[-\s]+', '-', value) + +def get_billform_for(project, set_default=True, **kwargs): + """Return an instance of BillForm configured for a particular project. + + :set_default: if set to True, on GET methods (usually when we want to + display the default form, it will call set_default on it. + + """ + form = BillForm(**kwargs) + form.payed_for.choices = form.payer.choices = [(str(m.id), m.name) for m in project.active_members] + form.payed_for.default = [str(m.id) for m in project.active_members] + + if set_default and request.method == "GET": + form.set_default() + return form + class Redirect303(HTTPException, RoutingException): """Raise if the map requests a redirect. This is for example the case if `strict_slashes` are activated and an url that requires a trailing slash. @@ -39,4 +55,3 @@ def for_all_methods(decorator): setattr(cls, name, decorator(method)) return cls return decorate - diff --git a/budget/web.py b/budget/web.py index 37c6415f..94c42d3c 100644 --- a/budget/web.py +++ b/budget/web.py @@ -262,7 +262,7 @@ def add_bill(): if request.method == 'POST': if form.validate(): bill = Bill() - db.session.add(form.save(bill)) + db.session.add(form.save(bill, g.project)) db.session.commit() flash("The bill has been added") @@ -295,7 +295,7 @@ def edit_bill(bill_id): form = get_billform_for(request, g.project, set_default=False) if request.method == 'POST' and form.validate(): - form.save(bill) + form.save(bill, g.project) db.session.commit() flash("The bill has been modified")