ihatemoney/ihatemoney/models.py
2019-12-08 14:03:20 -05:00

413 lines
13 KiB
Python

from collections import defaultdict
from datetime import datetime
from flask_sqlalchemy import SQLAlchemy, BaseQuery
from flask import g, current_app
from debts import settle
from sqlalchemy import orm
from itsdangerous import (
TimedJSONWebSignatureSerializer,
URLSafeSerializer,
BadSignature,
SignatureExpired,
)
db = SQLAlchemy()
class Project(db.Model):
id = db.Column(db.String(64), primary_key=True)
name = db.Column(db.UnicodeText)
password = db.Column(db.String(128))
contact_email = db.Column(db.String(128))
members = db.relationship("Person", backref="project")
default_currency = db.Column(db.String(3))
@property
def _to_serialize(self):
obj = {
"id": self.id,
"name": self.name,
"contact_email": self.contact_email,
"members": [],
"default_currency": self.default_currency,
}
balance = self.balance
for member in self.members:
member_obj = member._to_serialize
member_obj["balance"] = balance.get(member.id, 0)
obj["members"].append(member_obj)
return obj
@property
def active_members(self):
return [m for m in self.members if m.activated]
@property
def balance(self):
balances, should_pay, should_receive = (defaultdict(int) for time in (1, 2, 3))
# for each person
for person in self.members:
# get the list of bills he has to pay
bills = Bill.query.options(orm.subqueryload(Bill.owers)).filter(
Bill.owers.contains(person)
)
for bill in bills.all():
if person != bill.payer:
share = bill.pay_each() * person.weight
should_pay[person] += share
should_receive[bill.payer] += share
for person in self.members:
balance = should_receive[person] - should_pay[person]
balances[person.id] = balance
return balances
@property
def members_stats(self):
"""Compute what each member has paid
:return: one stat dict per member
:rtype list:
"""
return [
{
"member": member,
"paid": sum(
[bill.amount for bill in self.get_member_bills(member.id).all()]
),
"spent": sum(
[
bill.pay_each() * member.weight
for bill in self.get_bills().all()
if member in bill.owers
]
),
"balance": self.balance[member.id],
}
for member in self.active_members
]
@property
def uses_weights(self):
return len([i for i in self.members if i.weight != 1]) > 0
def get_transactions_to_settle_bill(self, pretty_output=False):
"""Return a list of transactions that could be made to settle the bill"""
def prettify(transactions, pretty_output):
""" Return pretty transactions
"""
if not pretty_output:
return transactions
pretty_transactions = []
for transaction in transactions:
pretty_transactions.append(
{
"ower": transaction["ower"].name,
"receiver": transaction["receiver"].name,
"amount": round(transaction["amount"], 2),
}
)
return pretty_transactions
# cache value for better performance
members = {person.id: person for person in self.members}
settle_plan = settle(self.balance.items()) or []
transactions = [
{
"ower": members[ower_id],
"receiver": members[receiver_id],
"amount": amount,
}
for ower_id, amount, receiver_id in settle_plan
]
return prettify(transactions, pretty_output)
def exactmatch(self, credit, debts):
"""Recursively try and find subsets of 'debts' whose sum is equal to credit"""
if not debts:
return None
if debts[0]["balance"] > credit:
return self.exactmatch(credit, debts[1:])
elif debts[0]["balance"] == credit:
return [debts[0]]
else:
match = self.exactmatch(credit - debts[0]["balance"], debts[1:])
if match:
match.append(debts[0])
else:
match = self.exactmatch(credit, debts[1:])
return match
def has_bills(self):
"""return if the project do have bills or not"""
return self.get_bills().count() > 0
def get_bills(self):
"""Return the list of bills related to this project"""
return (
Bill.query.join(Person, Project)
.filter(Bill.payer_id == Person.id)
.filter(Person.project_id == Project.id)
.filter(Project.id == self.id)
.order_by(Bill.date.desc())
.order_by(Bill.creation_date.desc())
.order_by(Bill.id.desc())
)
def get_member_bills(self, member_id):
"""Return the list of bills related to a specific member"""
return (
Bill.query.join(Person, Project)
.filter(Bill.payer_id == Person.id)
.filter(Person.project_id == Project.id)
.filter(Person.id == member_id)
.filter(Project.id == self.id)
.order_by(Bill.date.desc())
.order_by(Bill.id.desc())
)
def get_pretty_bills(self, export_format="json"):
"""Return a list of project's bills with pretty formatting"""
bills = self.get_bills()
pretty_bills = []
for bill in bills:
if export_format == "json":
owers = [ower.name for ower in bill.owers]
else:
owers = ", ".join([ower.name for ower in bill.owers])
pretty_bills.append(
{
"what": bill.what,
"amount": round(bill.amount, 2),
"date": str(bill.date),
"payer_name": Person.query.get(bill.payer_id).name,
"payer_weight": Person.query.get(bill.payer_id).weight,
"owers": owers,
}
)
return pretty_bills
def remove_member(self, member_id):
"""Remove a member from the project.
If the member is not bound to a bill, then he is deleted, otherwise
he is only deactivated.
This method returns the status DELETED or DEACTIVATED regarding the
changes made.
"""
try:
person = Person.query.get(member_id, self)
except orm.exc.NoResultFound:
return None
if not person.has_bills():
db.session.delete(person)
db.session.commit()
else:
person.activated = False
db.session.commit()
return person
def remove_project(self):
db.session.delete(self)
db.session.commit()
def generate_token(self, expiration=0):
"""Generate a timed and serialized JsonWebToken
:param expiration: Token expiration time (in seconds)
"""
if expiration:
serializer = TimedJSONWebSignatureSerializer(
current_app.config["SECRET_KEY"], expiration
)
token = serializer.dumps({"project_id": self.id}).decode("utf-8")
else:
serializer = URLSafeSerializer(current_app.config["SECRET_KEY"])
token = serializer.dumps({"project_id": self.id})
return token
@staticmethod
def verify_token(token, token_type="timed_token"):
"""Return the project id associated to the provided token,
None if the provided token is expired or not valid.
:param token: Serialized TimedJsonWebToken
"""
if token_type == "timed_token":
serializer = TimedJSONWebSignatureSerializer(
current_app.config["SECRET_KEY"]
)
else:
serializer = URLSafeSerializer(current_app.config["SECRET_KEY"])
try:
data = serializer.loads(token)
except SignatureExpired:
return None
except BadSignature:
return None
return data["project_id"]
def __repr__(self):
return "<Project %s>" % self.name
class Person(db.Model):
class PersonQuery(BaseQuery):
def get_by_name(self, name, project):
return (
Person.query.filter(Person.name == name)
.filter(Project.id == project.id)
.one()
)
def get(self, id, project=None):
if not project:
project = g.project
return (
Person.query.filter(Person.id == id)
.filter(Project.id == project.id)
.one()
)
query_class = PersonQuery
id = db.Column(db.Integer, primary_key=True)
project_id = db.Column(db.String(64), db.ForeignKey("project.id"))
bills = db.relationship("Bill", backref="payer")
name = db.Column(db.UnicodeText)
weight = db.Column(db.Float, default=1)
activated = db.Column(db.Boolean, default=True)
@property
def _to_serialize(self):
return {
"id": self.id,
"name": self.name,
"weight": self.weight,
"activated": self.activated,
}
def has_bills(self):
"""return if the user do have bills or not"""
bills_as_ower_number = (
db.session.query(billowers)
.filter(billowers.columns.get("person_id") == self.id)
.count()
)
return bills_as_ower_number != 0 or len(self.bills) != 0
def __str__(self):
return self.name
def __repr__(self):
return "<Person %s for project %s>" % (self.name, self.project.name)
# We need to manually define a join table for m2m relations
billowers = db.Table(
"billowers",
db.Column("bill_id", db.Integer, db.ForeignKey("bill.id")),
db.Column("person_id", db.Integer, db.ForeignKey("person.id")),
)
class Bill(db.Model):
class BillQuery(BaseQuery):
def get(self, project, id):
try:
return (
self.join(Person, Project)
.filter(Bill.payer_id == Person.id)
.filter(Person.project_id == Project.id)
.filter(Project.id == project.id)
.filter(Bill.id == id)
.one()
)
except orm.exc.NoResultFound:
return None
def delete(self, project, id):
bill = self.get(project, id)
if bill:
db.session.delete(bill)
return bill
query_class = BillQuery
id = db.Column(db.Integer, primary_key=True)
payer_id = db.Column(db.Integer, db.ForeignKey("person.id"))
owers = db.relationship(Person, secondary=billowers)
amount = db.Column(db.Float)
date = db.Column(db.Date, default=datetime.now)
creation_date = db.Column(db.Date, default=datetime.now)
what = db.Column(db.UnicodeText)
external_link = db.Column(db.UnicodeText)
original_currency = db.Column(db.String(3))
original_amount = db.Column(db.Float)
archive = db.Column(db.Integer, db.ForeignKey("archive.id"))
@property
def _to_serialize(self):
return {
"id": self.id,
"payer_id": self.payer_id,
"owers": self.owers,
"amount": self.amount,
"date": self.date,
"creation_date": self.creation_date,
"what": self.what,
"external_link": self.external_link,
}
def pay_each(self):
"""Compute what each share has to pay"""
if self.owers:
# FIXME: SQL might do that more efficiently
weights = sum(i.weight for i in self.owers)
return self.amount / weights
else:
return 0
def __repr__(self):
return "<Bill of %s from %s for %s>" % (
self.amount,
self.payer,
", ".join([o.name for o in self.owers]),
)
class Archive(db.Model):
id = db.Column(db.Integer, primary_key=True)
project_id = db.Column(db.String(64), db.ForeignKey("project.id"))
name = db.Column(db.UnicodeText)
@property
def start_date(self):
pass
@property
def end_date(self):
pass
def __repr__(self):
return "<Archive>"