diff --git a/ihatemoney/models.py b/ihatemoney/models.py index 7cc12003..83776142 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -18,6 +18,9 @@ db = SQLAlchemy() class Project(db.Model): + class ProjectQuery(BaseQuery): + def get_by_name(self, name): + return Project.query.filter(Project.name == name).one() id = db.Column(db.String(64), primary_key=True) @@ -26,6 +29,8 @@ class Project(db.Model): contact_email = db.Column(db.String(128)) members = db.relationship("Person", backref="project") + query_class = ProjectQuery + @property def _to_serialize(self): obj = { diff --git a/ihatemoney/tests/tests.py b/ihatemoney/tests/tests.py index 0c99bca8..a12613c1 100644 --- a/ihatemoney/tests/tests.py +++ b/ihatemoney/tests/tests.py @@ -19,6 +19,7 @@ from ihatemoney.run import create_app, db, load_configuration from ihatemoney.manage import GenerateConfig, GeneratePasswordHash, DeleteProject from ihatemoney import models from ihatemoney import utils +from sqlalchemy import orm # Unset configuration file env var if previously set os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None) @@ -2140,5 +2141,68 @@ class CommandTestCase(BaseTestCase): self.assertEqual(len(models.Project.query.all()), 0) +class ModelsTestCase(IhatemoneyTestCase): + def test_bill_pay_each(self): + + self.post_project("raclette") + + # add members + self.client.post("/raclette/members/add", data={"name": "alexis", "weight": 2}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + self.client.post("/raclette/members/add", data={"name": "tata"}) + # Add a member with a balance=0 : + self.client.post("/raclette/members/add", data={"name": "toto"}) + + # create bills + self.client.post( + "/raclette/add", + data={ + "date": "2011-08-10", + "what": "fromage à raclette", + "payer": 1, + "payed_for": [1, 2, 3], + "amount": "10.0", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2011-08-10", + "what": "red wine", + "payer": 2, + "payed_for": [1], + "amount": "20", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2011-08-10", + "what": "delicatessen", + "payer": 1, + "payed_for": [1, 2], + "amount": "10", + }, + ) + + project = models.Project.query.get_by_name(name="raclette") + alexis = models.Person.query.get_by_name(name="alexis", project=project) + alexis_bills = models.Bill.query.options( + orm.subqueryload(models.Bill.owers) + ).filter(models.Bill.owers.contains(alexis)) + for bill in alexis_bills.all(): + if bill.what == "red wine": + pay_each_expected = 20 / 2 + self.assertEqual(bill.pay_each(), pay_each_expected) + if bill.what == "fromage à raclette": + pay_each_expected = 10 / 4 + self.assertEqual(bill.pay_each(), pay_each_expected) + if bill.what == "delicatessen": + pay_each_expected = 10 / 3 + self.assertEqual(bill.pay_each(), pay_each_expected) + + if __name__ == "__main__": unittest.main()