Use SQL statement for summing up weights

* Update models: Bill.pay_each()
* Import sql func
* reformatted using black
* Added ModelsTestCase.test_bill_pay_each() in order to test the SQL query change within pay_each.
Had to add Project.ProjectQuery.get_by_name() for the test.
This commit is contained in:
DavidRThrashJr 2020-02-17 12:39:51 -05:00 committed by GitHub
parent e4f18f0600
commit 32d76178c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 2 deletions

View file

@ -6,6 +6,7 @@ from flask import g, current_app
from debts import settle from debts import settle
from sqlalchemy import orm from sqlalchemy import orm
from sqlalchemy.sql import func
from itsdangerous import ( from itsdangerous import (
TimedJSONWebSignatureSerializer, TimedJSONWebSignatureSerializer,
URLSafeSerializer, URLSafeSerializer,
@ -17,6 +18,9 @@ db = SQLAlchemy()
class Project(db.Model): class Project(db.Model):
class ProjectQuery(BaseQuery):
def get_by_name(self, name):
return Project.query.filter(Project.name == name).one()
id = db.Column(db.String(64), primary_key=True) id = db.Column(db.String(64), primary_key=True)
@ -25,6 +29,8 @@ class Project(db.Model):
contact_email = db.Column(db.String(128)) contact_email = db.Column(db.String(128))
members = db.relationship("Person", backref="project") members = db.relationship("Person", backref="project")
query_class = ProjectQuery
@property @property
def _to_serialize(self): def _to_serialize(self):
obj = { obj = {
@ -388,8 +394,11 @@ class Bill(db.Model):
def pay_each(self): def pay_each(self):
"""Compute what each share has to pay""" """Compute what each share has to pay"""
if self.owers: if self.owers:
# FIXME: SQL might do that more efficiently weights = (
weights = sum(i.weight for i in self.owers) db.session.query(func.sum(Person.weight))
.join(billowers, Bill)
.filter(Bill.id == self.id)
).scalar()
return self.amount / weights return self.amount / weights
else: else:
return 0 return 0

View file

@ -19,6 +19,7 @@ from ihatemoney.run import create_app, db, load_configuration
from ihatemoney.manage import GenerateConfig, GeneratePasswordHash, DeleteProject from ihatemoney.manage import GenerateConfig, GeneratePasswordHash, DeleteProject
from ihatemoney import models from ihatemoney import models
from ihatemoney import utils from ihatemoney import utils
from sqlalchemy import orm
# Unset configuration file env var if previously set # Unset configuration file env var if previously set
os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None) os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None)
@ -2140,5 +2141,68 @@ class CommandTestCase(BaseTestCase):
self.assertEqual(len(models.Project.query.all()), 0) self.assertEqual(len(models.Project.query.all()), 0)
class ModelsTestCase(IhatemoneyTestCase):
def test_bill_pay_each(self):
self.post_project("raclette")
# add members
self.client.post("/raclette/members/add", data={"name": "alexis", "weight": 2})
self.client.post("/raclette/members/add", data={"name": "fred"})
self.client.post("/raclette/members/add", data={"name": "tata"})
# Add a member with a balance=0 :
self.client.post("/raclette/members/add", data={"name": "toto"})
# create bills
self.client.post(
"/raclette/add",
data={
"date": "2011-08-10",
"what": "fromage à raclette",
"payer": 1,
"payed_for": [1, 2, 3],
"amount": "10.0",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2011-08-10",
"what": "red wine",
"payer": 2,
"payed_for": [1],
"amount": "20",
},
)
self.client.post(
"/raclette/add",
data={
"date": "2011-08-10",
"what": "delicatessen",
"payer": 1,
"payed_for": [1, 2],
"amount": "10",
},
)
project = models.Project.query.get_by_name(name="raclette")
alexis = models.Person.query.get_by_name(name="alexis", project=project)
alexis_bills = models.Bill.query.options(
orm.subqueryload(models.Bill.owers)
).filter(models.Bill.owers.contains(alexis))
for bill in alexis_bills.all():
if bill.what == "red wine":
pay_each_expected = 20 / 2
self.assertEqual(bill.pay_each(), pay_each_expected)
if bill.what == "fromage à raclette":
pay_each_expected = 10 / 4
self.assertEqual(bill.pay_each(), pay_each_expected)
if bill.what == "delicatessen":
pay_each_expected = 10 / 3
self.assertEqual(bill.pay_each(), pay_each_expected)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()