ihatemoney/ihatemoney/models.py
Glandos d9a4389d42 API BREAKING CHANGE: the authentication token now needs a project_id= param in the URL
This change is introduced to make it possible to invalidate auth tokens when the password changes.
The token payload is still the same, but the key is the concatenation of SECRET_KEY and the project password.
To have a clean verification, we need to have the project id before loading the payload, in order to build the serializer with the correct key (including the password).
2021-09-16 22:25:13 +02:00

617 lines
20 KiB
Python

from collections import defaultdict
from datetime import datetime
from debts import settle
from flask import current_app, g
from flask_sqlalchemy import BaseQuery, SQLAlchemy
from itsdangerous import (
BadSignature,
SignatureExpired,
URLSafeSerializer,
URLSafeTimedSerializer,
)
import sqlalchemy
from sqlalchemy import orm
from sqlalchemy.sql import func
from sqlalchemy_continuum import make_versioned, version_class
from sqlalchemy_continuum.plugins import FlaskPlugin
from werkzeug.security import generate_password_hash
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder
from ihatemoney.versioning import (
ConditionalVersioningManager,
LoggingMode,
get_ip_if_allowed,
version_privacy_predicate,
)
# Set up SQLAlchemy-Continuum versioning for the models declared below.
make_versioned(
    user_cls=None,
    manager=ConditionalVersioningManager(
        # Versioning is switched on or off per project, according to the
        # privacy preferences of that project.
        tracking_predicate=version_privacy_predicate,
        # Work around a SQLAlchemy-Continuum bug;
        # the details are in patch_sqlalchemy_continuum.py.
        builder=PatchedBuilder(),
    ),
    plugins=[
        FlaskPlugin(
            # Use our own helper so that the user's choice regarding
            # IP address collection is respected.
            remote_addr_factory=get_ip_if_allowed,
            # Never look up a user id: the plugin's default would import
            # the flask_login module, which causes an error.
            current_user_id_factory=lambda: None,
        )
    ],
)

# Database handle shared by every model of this module.
db = SQLAlchemy()
class Project(db.Model):
    """Project model: owns members (``Person``) and, through them, bills."""

    class ProjectQuery(BaseQuery):
        def get_by_name(self, name):
            """Return the project called ``name``.

            Relies on ``Query.one()``, so the lookup raises when zero or
            several projects match.
            """
            return Project.query.filter(Project.name == name).one()

    # Direct SQLAlchemy-Continuum to track changes to this model
    __versioned__ = {}

    id = db.Column(db.String(64), primary_key=True)

    name = db.Column(db.UnicodeText)
    password = db.Column(db.String(128))
    contact_email = db.Column(db.String(128))
    logging_preference = db.Column(
        db.Enum(LoggingMode),
        default=LoggingMode.default(),
        nullable=False,
        server_default=LoggingMode.default().name,
    )
    members = db.relationship("Person", backref="project")

    query_class = ProjectQuery
    default_currency = db.Column(db.String(3))

    @property
    def _to_serialize(self):
        """Return a plain dict describing the project and its members.

        Every member entry is the member's own ``_to_serialize`` dict with an
        extra ``balance`` key (0 when no balance is known for that member).
        """
        obj = {
            "id": self.id,
            "name": self.name,
            "contact_email": self.contact_email,
            "logging_preference": self.logging_preference.value,
            "members": [],
            "default_currency": self.default_currency,
        }

        # The balance property queries the database: evaluate it only once.
        balance = self.balance
        for member in self.members:
            member_obj = member._to_serialize
            member_obj["balance"] = balance.get(member.id, 0)
            obj["members"].append(member_obj)

        return obj

    @property
    def active_members(self):
        """Return the members whose ``activated`` flag is set."""
        return [m for m in self.members if m.activated]

    @property
    def balance(self):
        """Compute the balance of every member of the project.

        :return: a ``defaultdict`` mapping each member id to what the member
                 should receive (shares of others in bills the member paid)
                 minus what the member should pay (own shares in bills paid
                 by somebody else)
        """
        should_pay = defaultdict(int)
        should_receive = defaultdict(int)

        for person in self.members:
            # Bills in which this person is one of the owers
            bills = Bill.query.options(orm.subqueryload(Bill.owers)).filter(
                Bill.owers.contains(person)
            )
            for bill in bills.all():
                # A payer owes nothing to themselves
                if person != bill.payer:
                    share = bill.pay_each() * person.weight
                    should_pay[person] += share
                    should_receive[bill.payer] += share

        balances = defaultdict(int)
        for person in self.members:
            balances[person.id] = should_receive[person] - should_pay[person]

        return balances

    @property
    def members_stats(self):
        """Compute what each active member has paid and spent.

        :return: one stat dict per active member, with the keys ``member``,
                 ``paid``, ``spent`` and ``balance``
        :rtype list:
        """
        # Both values are identical for every member and cost database
        # queries, so compute them once instead of once per member.
        balance = self.balance
        bills = self.get_bills_unordered().all()

        return [
            {
                "member": member,
                "paid": sum(
                    bill.converted_amount
                    for bill in self.get_member_bills(member.id).all()
                ),
                "spent": sum(
                    bill.pay_each() * member.weight
                    for bill in bills
                    if member in bill.owers
                ),
                "balance": balance[member.id],
            }
            for member in self.active_members
        ]

    @property
    def monthly_stats(self):
        """Compute expenses by month

        :return: a dict of years mapping to a dict of months mapping to the amount
        :rtype dict:
        """
        monthly = defaultdict(lambda: defaultdict(float))
        for bill in self.get_bills_unordered().all():
            monthly[bill.date.year][bill.date.month] += bill.converted_amount
        return monthly

    @property
    def uses_weights(self):
        """Return True when at least one member has a weight other than 1."""
        return any(member.weight != 1 for member in self.members)

    def get_transactions_to_settle_bill(self, pretty_output=False):
        """Return a list of transactions that could be made to settle the bill

        :param pretty_output: when true, members are replaced by their names
                              and amounts are rounded to 2 decimals
        :return: a list of dicts with the keys ``ower``, ``receiver``,
                 ``amount`` and ``currency``
        """
        # cache value for better performance
        members = {person.id: person for person in self.members}
        settle_plan = settle(self.balance.items()) or []

        transactions = [
            {
                "ower": members[ower_id],
                "receiver": members[receiver_id],
                "amount": amount,
                "currency": self.default_currency,
            }
            for ower_id, amount, receiver_id in settle_plan
        ]
        if not pretty_output:
            return transactions

        return [
            {
                "ower": transaction["ower"].name,
                "receiver": transaction["receiver"].name,
                "amount": round(transaction["amount"], 2),
                "currency": transaction["currency"],
            }
            for transaction in transactions
        ]

    def exactmatch(self, credit, debts):
        """Recursively try and find subsets of 'debts' whose sum is equal to credit

        :param credit: the amount to reach
        :param debts: a sequence of dicts, each with a ``balance`` key
        :return: a list of matching debts, or None when no subset matches
        """
        if not debts:
            return None

        first, rest = debts[0], debts[1:]
        if first["balance"] > credit:
            # Too large to be part of a solution: skip it
            return self.exactmatch(credit, rest)
        if first["balance"] == credit:
            return [first]

        # Try with the first debt included, then fall back to skipping it
        match = self.exactmatch(credit - first["balance"], rest)
        if match:
            match.append(first)
            return match
        return self.exactmatch(credit, rest)

    def has_bills(self):
        """Return whether the project has at least one bill"""
        return self.get_bills_unordered().count() > 0

    def has_multiple_currencies(self):
        """Returns True if multiple currencies are used"""
        # It would be more efficient to do the counting in the database,
        # but this is called very rarely so we can tolerate if it's a bit
        # slow. And doing this in Python is much more readable, see #784.
        currencies = {bill.original_currency for bill in self.get_bills_unordered()}
        return len(currencies) > 1

    def get_bills_unordered(self):
        """Base query for bill list"""
        return (
            Bill.query.join(Person, Project)
            .filter(Bill.payer_id == Person.id)
            .filter(Person.project_id == Project.id)
            .filter(Project.id == self.id)
        )

    def get_bills(self):
        """Return the query of bills related to this project, newest first"""
        return (
            self.get_bills_unordered()
            .order_by(Bill.date.desc())
            .order_by(Bill.creation_date.desc())
            .order_by(Bill.id.desc())
        )

    def get_member_bills(self, member_id):
        """Return the query of bills paid by a specific member, newest first"""
        return (
            self.get_bills_unordered()
            .filter(Person.id == member_id)
            .order_by(Bill.date.desc())
            .order_by(Bill.id.desc())
        )

    def get_pretty_bills(self, export_format="json"):
        """Return a list of project's bills with pretty formatting

        :param export_format: with ``"json"`` the owers are a list of names,
                              with any other value a comma-separated string
        """
        pretty_bills = []
        for bill in self.get_bills():
            owers = [ower.name for ower in bill.owers]
            if export_format != "json":
                owers = ", ".join(owers)

            # Look the payer up once, explicitly within this project rather
            # than within whatever project the request context holds.
            payer = Person.query.get(bill.payer_id, self)

            pretty_bills.append(
                {
                    "what": bill.what,
                    "amount": round(bill.amount, 2),
                    "currency": bill.original_currency,
                    "date": str(bill.date),
                    "payer_name": payer.name,
                    "payer_weight": payer.weight,
                    "owers": owers,
                }
            )
        return pretty_bills

    def switch_currency(self, new_currency):
        """Change the default currency of the project and update its bills.

        Does nothing when ``new_currency`` already is the default currency;
        otherwise the changes are committed.

        :raises ValueError: when the currency is being removed while the
                            bills use several currencies
        """
        if new_currency == self.default_currency:
            return
        # Update converted currency
        if new_currency == CurrencyConverter.no_currency:
            if self.has_multiple_currencies():
                raise ValueError(f"Can't unset currency of project {self.id}")

            for bill in self.get_bills_unordered():
                # We are removing the currency, and we already checked that all bills
                # had the same currency: it means that we can simply strip the currency
                # without converting the amounts. We basically ignore the current default_currency

                # Reset converted amount in case it was different from the original amount
                bill.converted_amount = bill.amount
                # Strip currency
                bill.original_currency = CurrencyConverter.no_currency
                db.session.add(bill)
        else:
            for bill in self.get_bills_unordered():
                if bill.original_currency == CurrencyConverter.no_currency:
                    # Bills that were created without currency will be set to the new currency
                    bill.original_currency = new_currency
                    bill.converted_amount = bill.amount
                else:
                    # Convert amount for others, without touching original_currency
                    bill.converted_amount = CurrencyConverter().exchange_currency(
                        bill.amount, bill.original_currency, new_currency
                    )
                db.session.add(bill)

        self.default_currency = new_currency
        db.session.add(self)
        db.session.commit()

    def remove_member(self, member_id):
        """Remove a member from the project.

        If the member is not bound to a bill, then he is deleted, otherwise
        he is only deactivated.

        :return: the affected person, or None when this project has no
                 member with that id
        """
        person = Person.query.get(member_id, self)
        if person is None:
            return None

        if person.has_bills():
            person.activated = False
        else:
            db.session.delete(person)
        db.session.commit()
        return person

    def remove_project(self):
        """Delete this project and commit."""
        db.session.delete(self)
        db.session.commit()

    def generate_token(self, token_type="auth"):
        """Generate a signed, URL-safe token carrying the project id

        :param token_type: Either "auth" for authentication (invalidated when project code changed),
                        or "reset" for password reset (invalidated after expiration)
        """
        secret_key = current_app.config["SECRET_KEY"]
        if token_type == "reset":
            serializer = URLSafeTimedSerializer(secret_key, salt=token_type)
        else:
            # The password is part of the key, so changing it invalidates
            # every token issued before.
            serializer = URLSafeSerializer(secret_key + self.password, salt=token_type)
        return serializer.dumps({"project_id": self.id})

    @staticmethod
    def verify_token(token, token_type="auth", project_id=None, max_age=3600):
        """Return the project id associated to the provided token,
        None if the provided token is expired or not valid.

        :param token: Serialized token, as produced by ``generate_token``
        :param token_type: Either "auth" for authentication (invalidated when project code changed),
                        or "reset" for password reset (invalidated after expiration)
        :param project_id: Project ID. Used for token_type "auth" to use the password as serializer secret key.
        :param max_age: Token expiration time (in seconds). Only used with token_type "reset"
        """
        secret_key = current_app.config["SECRET_KEY"]
        loads_kwargs = {}
        if token_type == "reset":
            serializer = URLSafeTimedSerializer(secret_key, salt=token_type)
            loads_kwargs["max_age"] = max_age
        else:
            # Do not ask the database for a project without an identifier
            project = Project.query.get(project_id) if project_id is not None else None
            password = project.password if project is not None else ""
            serializer = URLSafeSerializer(secret_key + password, salt=token_type)

        try:
            data = serializer.loads(token, **loads_kwargs)
        except (SignatureExpired, BadSignature):
            return None

        data_project = data.get("project_id")
        # When a project id was supplied, the payload has to agree with it
        if project_id is None or data_project == project_id:
            return data_project
        return None

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"<Project {self.name}>"

    @staticmethod
    def create_demo_project():
        """Create, commit and return the "demo" project.

        The project gets three members and three bills.
        """
        project = Project(
            id="demo",
            name="demonstration",
            password=generate_password_hash("demo"),
            contact_email="demo@notmyidea.org",
            default_currency="EUR",
        )
        db.session.add(project)
        db.session.commit()

        members = {}
        for name in ("Amina", "Georg", "Alice"):
            person = Person()
            person.name = name
            person.project = project
            person.weight = 1
            db.session.add(person)

            members[name] = person

        # Members must be committed before the bills can reference their ids
        db.session.commit()

        operations = (
            ("Georg", 200, ("Amina", "Georg", "Alice"), "Food shopping"),
            ("Alice", 20, ("Amina", "Alice"), "Beer !"),
            ("Amina", 50, ("Amina", "Alice", "Georg"), "AMAP"),
        )
        for payer, amount, owers, subject in operations:
            bill = Bill()
            bill.payer_id = members[payer].id
            bill.what = subject
            bill.owers = [members[name] for name in owers]
            bill.amount = amount
            bill.original_currency = "EUR"
            bill.converted_amount = amount

            db.session.add(bill)

        db.session.commit()
        return project
class Person(db.Model):
    """Member of a project; pays bills and owes shares of bills."""

    class PersonQuery(BaseQuery):
        def get_by_name(self, name, project):
            """Return the member of ``project`` called ``name``, or None."""
            named = Person.query.filter(Person.name == name)
            return named.filter(Person.project_id == project.id).one_or_none()

        def get(self, id, project=None):
            """Return the member with this id inside ``project``, or None.

            Without an explicit project, the lookup is made in ``g.project``.
            """
            owner = project or g.project
            with_id = Person.query.filter(Person.id == id)
            return with_id.filter(Person.project_id == owner.id).one_or_none()

    query_class = PersonQuery

    # Direct SQLAlchemy-Continuum to track changes to this model
    __versioned__ = {}

    __table_args__ = {"sqlite_autoincrement": True}

    id = db.Column(db.Integer, primary_key=True)
    project_id = db.Column(db.String(64), db.ForeignKey("project.id"))
    # Bills paid by this person; every Bill gets a matching ``payer`` attribute
    bills = db.relationship("Bill", backref="payer")

    name = db.Column(db.UnicodeText)
    weight = db.Column(db.Float, default=1)
    activated = db.Column(db.Boolean, default=True)

    @property
    def _to_serialize(self):
        """Return the id, name, weight and activation flag as a dict."""
        return dict(
            id=self.id,
            name=self.name,
            weight=self.weight,
            activated=self.activated,
        )

    def has_bills(self):
        """Tell whether the person is bound to a bill, as ower or as payer"""
        ower_rows = db.session.query(billowers).filter(
            billowers.columns.get("person_id") == self.id
        )
        if ower_rows.count() != 0:
            return True
        return len(self.bills) != 0

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"<Person {self.name} for project {self.project.name}>"
# Many-to-many relations need an explicit join table: this one links each
# bill to the persons who owe a share of it.
billowers = db.Table(
    "billowers",
    db.Column("bill_id", db.Integer, db.ForeignKey("bill.id"), primary_key=True),
    db.Column("person_id", db.Integer, db.ForeignKey("person.id"), primary_key=True),
    sqlite_autoincrement=True,
)
class Bill(db.Model):
    """Bill model: an amount paid by one person and shared between owers."""

    class BillQuery(BaseQuery):
        def get(self, project, id):
            """Return the bill with this id if it belongs to ``project``, else None."""
            return (
                self.join(Person, Project)
                .filter(Bill.payer_id == Person.id)
                .filter(Person.project_id == Project.id)
                .filter(Project.id == project.id)
                .filter(Bill.id == id)
                .one_or_none()
            )

        def delete(self, project, id):
            """Mark the bill for deletion in the session, without committing.

            :return: the bill, or None when ``project`` has no such bill
            """
            bill = self.get(project, id)
            if bill:
                db.session.delete(bill)
            return bill

    query_class = BillQuery

    # Direct SQLAlchemy-Continuum to track changes to this model
    __versioned__ = {}

    __table_args__ = {"sqlite_autoincrement": True}

    id = db.Column(db.Integer, primary_key=True)

    payer_id = db.Column(db.Integer, db.ForeignKey("person.id"))
    owers = db.relationship(Person, secondary=billowers)

    amount = db.Column(db.Float)
    date = db.Column(db.Date, default=datetime.now)
    creation_date = db.Column(db.Date, default=datetime.now)
    what = db.Column(db.UnicodeText)
    external_link = db.Column(db.UnicodeText)

    original_currency = db.Column(db.String(3))
    converted_amount = db.Column(db.Float)

    archive = db.Column(db.Integer, db.ForeignKey("archive.id"))

    @property
    def _to_serialize(self):
        """Return the bill's fields as a dict; ``owers`` holds the Person objects."""
        return {
            "id": self.id,
            "payer_id": self.payer_id,
            "owers": self.owers,
            "amount": self.amount,
            "date": self.date,
            "creation_date": self.creation_date,
            "what": self.what,
            "external_link": self.external_link,
            "original_currency": self.original_currency,
            "converted_amount": self.converted_amount,
        }

    def pay_each_default(self, amount):
        """Compute what each share has to pay

        :param amount: the amount to split between the owers
        :return: ``amount`` divided by the summed weight of the bill's owers,
                 or 0 when the bill has no owers or that sum is missing or zero
        """
        if not self.owers:
            return 0

        weights = (
            db.session.query(func.sum(Person.weight))
            .join(billowers, Bill)
            .filter(Bill.id == self.id)
        ).scalar()
        # The sum can come back as None (no matching row) or as 0; neither
        # can be divided by, so answer like the "no owers" case.
        if not weights:
            return 0
        return amount / weights

    def __str__(self):
        return self.what

    def pay_each(self):
        """Return the price of one share, based on the converted amount."""
        return self.pay_each_default(self.converted_amount)

    def __repr__(self):
        return (
            f"<Bill of {self.amount} from {self.payer} for "
            f"{', '.join([o.name for o in self.owers])}>"
        )
class Archive(db.Model):
    """Named archive attached to a project; bills reference it by id."""

    id = db.Column(db.Integer, primary_key=True)
    project_id = db.Column(db.String(64), db.ForeignKey("project.id"))
    name = db.Column(db.UnicodeText)

    @property
    def start_date(self):
        """Placeholder: no computation yet, always None."""
        return None

    @property
    def end_date(self):
        """Placeholder: no computation yet, always None."""
        return None

    def __repr__(self):
        return "<Archive>"
# Configure every mapper first, then fetch the version class that
# SQLAlchemy-Continuum associates with each tracked model.
orm.configure_mappers()

PersonVersion = version_class(Person)
ProjectVersion = version_class(Project)
BillVersion = version_class(Bill)