From 5ebcc8bd512d0ffa457f910f71bbeda0e12739a7 Mon Sep 17 00:00:00 2001 From: Andrew Dickinson Date: Sat, 11 Apr 2020 13:46:55 -0400 Subject: [PATCH] Integrate SQLAlchemy-Continuum to support version tracking --- .../versions/2dcb0c0048dc_autologger.py | 212 ++++++++++++++++++ ihatemoney/models.py | 71 +++++- ihatemoney/utils.py | 17 +- 3 files changed, 296 insertions(+), 4 deletions(-) create mode 100644 ihatemoney/migrations/versions/2dcb0c0048dc_autologger.py diff --git a/ihatemoney/migrations/versions/2dcb0c0048dc_autologger.py b/ihatemoney/migrations/versions/2dcb0c0048dc_autologger.py new file mode 100644 index 00000000..f17025ec --- /dev/null +++ b/ihatemoney/migrations/versions/2dcb0c0048dc_autologger.py @@ -0,0 +1,212 @@ +"""autologger + +Revision ID: 2dcb0c0048dc +Revises: 6c6fb2b7f229 +Create Date: 2020-04-10 18:12:41.285590 + +""" + +# revision identifiers, used by Alembic. +revision = "2dcb0c0048dc" +down_revision = "6c6fb2b7f229" + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "bill_version", + sa.Column("id", sa.Integer(), autoincrement=False, nullable=False), + sa.Column("payer_id", sa.Integer(), autoincrement=False, nullable=True), + sa.Column("amount", sa.Float(), autoincrement=False, nullable=True), + sa.Column("date", sa.Date(), autoincrement=False, nullable=True), + sa.Column("creation_date", sa.Date(), autoincrement=False, nullable=True), + sa.Column("what", sa.UnicodeText(), autoincrement=False, nullable=True), + sa.Column( + "external_link", sa.UnicodeText(), autoincrement=False, nullable=True + ), + sa.Column("archive", sa.Integer(), autoincrement=False, nullable=True), + sa.Column( + "transaction_id", sa.BigInteger(), autoincrement=False, nullable=False + ), + sa.Column("end_transaction_id", sa.BigInteger(), nullable=True), + sa.Column("operation_type", sa.SmallInteger(), nullable=False), + sa.PrimaryKeyConstraint("id", "transaction_id"), + ) + op.create_index( + op.f("ix_bill_version_end_transaction_id"), + "bill_version", + ["end_transaction_id"], + unique=False, + ) + op.create_index( + op.f("ix_bill_version_operation_type"), + "bill_version", + ["operation_type"], + unique=False, + ) + op.create_index( + op.f("ix_bill_version_transaction_id"), + "bill_version", + ["transaction_id"], + unique=False, + ) + op.create_table( + "billowers_version", + sa.Column("bill_id", sa.Integer(), autoincrement=False, nullable=False), + sa.Column("person_id", sa.Integer(), autoincrement=False, nullable=False), + sa.Column( + "transaction_id", sa.BigInteger(), autoincrement=False, nullable=False + ), + sa.Column("end_transaction_id", sa.BigInteger(), nullable=True), + sa.Column("operation_type", sa.SmallInteger(), nullable=False), + sa.PrimaryKeyConstraint("bill_id", "person_id", "transaction_id"), + ) + op.create_index( + op.f("ix_billowers_version_end_transaction_id"), + "billowers_version", + ["end_transaction_id"], + unique=False, + ) + op.create_index( + op.f("ix_billowers_version_operation_type"), + "billowers_version", + ["operation_type"], + unique=False, + ) + op.create_index( + op.f("ix_billowers_version_transaction_id"), + "billowers_version", + ["transaction_id"], + unique=False, + ) + op.create_table( + "person_version", + sa.Column("id", sa.Integer(), autoincrement=False, nullable=False), + sa.Column( + "project_id", sa.String(length=64), autoincrement=False, nullable=True + ), + sa.Column("name", sa.UnicodeText(), autoincrement=False, nullable=True), + sa.Column("weight", sa.Float(), autoincrement=False, nullable=True), + sa.Column("activated", sa.Boolean(), autoincrement=False, nullable=True), + sa.Column( + "transaction_id", sa.BigInteger(), autoincrement=False, nullable=False + ), + sa.Column("end_transaction_id", sa.BigInteger(), nullable=True), + sa.Column("operation_type", sa.SmallInteger(), nullable=False), + sa.PrimaryKeyConstraint("id", "transaction_id"), + ) + op.create_index( + op.f("ix_person_version_end_transaction_id"), + "person_version", + ["end_transaction_id"], + unique=False, + ) + op.create_index( + op.f("ix_person_version_operation_type"), + "person_version", + ["operation_type"], + unique=False, + ) + op.create_index( + op.f("ix_person_version_transaction_id"), + "person_version", + ["transaction_id"], + unique=False, + ) + op.create_table( + "project_version", + sa.Column("id", sa.String(length=64), autoincrement=False, nullable=False), + sa.Column("name", sa.UnicodeText(), autoincrement=False, nullable=True), + sa.Column( + "password", sa.String(length=128), autoincrement=False, nullable=True + ), + sa.Column( + "contact_email", sa.String(length=128), autoincrement=False, nullable=True + ), + sa.Column( + "logging_preference", + sa.Enum("DISABLED", "ENABLED", "RECORD_IP", name="loggingmode"), + autoincrement=False, + nullable=True, + ), + sa.Column( + "transaction_id", sa.BigInteger(), autoincrement=False, nullable=False + ), + sa.Column("end_transaction_id", sa.BigInteger(), nullable=True), + sa.Column("operation_type", sa.SmallInteger(), nullable=False), + sa.PrimaryKeyConstraint("id", "transaction_id"), + ) + op.create_index( + op.f("ix_project_version_end_transaction_id"), + "project_version", + ["end_transaction_id"], + unique=False, + ) + op.create_index( + op.f("ix_project_version_operation_type"), + "project_version", + ["operation_type"], + unique=False, + ) + op.create_index( + op.f("ix_project_version_transaction_id"), + "project_version", + ["transaction_id"], + unique=False, + ) + op.create_table( + "transaction", + sa.Column("issued_at", sa.DateTime(), nullable=True), + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column("remote_addr", sa.String(length=50), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.add_column( + "project", + sa.Column( + "logging_preference", + sa.Enum("DISABLED", "ENABLED", "RECORD_IP", name="loggingmode"), + nullable=True, + ), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("project", "logging_preference") + op.drop_table("transaction") + op.drop_index( + op.f("ix_project_version_transaction_id"), table_name="project_version" + ) + op.drop_index( + op.f("ix_project_version_operation_type"), table_name="project_version" + ) + op.drop_index( + op.f("ix_project_version_end_transaction_id"), table_name="project_version" + ) + op.drop_table("project_version") + op.drop_index(op.f("ix_person_version_transaction_id"), table_name="person_version") + op.drop_index(op.f("ix_person_version_operation_type"), table_name="person_version") + op.drop_index( + op.f("ix_person_version_end_transaction_id"), table_name="person_version" + ) + op.drop_table("person_version") + op.drop_index( + op.f("ix_billowers_version_transaction_id"), table_name="billowers_version" + ) + op.drop_index( + op.f("ix_billowers_version_operation_type"), table_name="billowers_version" + ) + op.drop_index( + op.f("ix_billowers_version_end_transaction_id"), table_name="billowers_version" + ) + op.drop_table("billowers_version") + op.drop_index(op.f("ix_bill_version_transaction_id"), table_name="bill_version") + op.drop_index(op.f("ix_bill_version_operation_type"), table_name="bill_version") + op.drop_index(op.f("ix_bill_version_end_transaction_id"), table_name="bill_version") + op.drop_table("bill_version") + # ### end Alembic commands ### diff --git a/ihatemoney/models.py b/ihatemoney/models.py index 8a823f2e..00621091 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -1,19 +1,72 @@ from collections import defaultdict from datetime import datetime + +import sqlalchemy from flask_sqlalchemy import SQLAlchemy, BaseQuery from flask import g, current_app from debts import settle from sqlalchemy import orm from sqlalchemy.sql import func -from ihatemoney.utils import LoggingMode +from ihatemoney.utils import LoggingMode, get_ip_if_allowed from itsdangerous import ( TimedJSONWebSignatureSerializer, URLSafeSerializer, BadSignature, SignatureExpired, ) +from sqlalchemy_continuum import make_versioned +from sqlalchemy_continuum import VersioningManager as VersioningManager +from sqlalchemy_continuum.plugins import FlaskPlugin + + +def version_privacy_predicate(): + """Evaluate if the project of the current session has enabled logging.""" + return g.project.logging_preference != LoggingMode.DISABLED + + +class ConditionalVersioningManager(VersioningManager): + """Conditionally enable version tracking based on the given predicate.""" + + def __init__(self, tracking_predicate, *args, **kwargs): + """Create version entry iff tracking_predicate() returns True.""" + super().__init__(*args, **kwargs) + self.tracking_predicate = tracking_predicate + + def before_flush(self, session, flush_context, instances): + if self.tracking_predicate(): + return super().before_flush(session, flush_context, instances) + else: + # At least one call to unit_of_work() needs to be made against the + # session object to prevent a KeyError later. This doesn't create + # a version or transaction entry + self.unit_of_work(session) + + def after_flush(self, session, flush_context): + if self.tracking_predicate(): + return super().after_flush(session, flush_context) + else: + # At least one call to unit_of_work() needs to be made against the + # session object to prevent a KeyError later. This doesn't create + # a version or transaction entry + self.unit_of_work(session) + + +make_versioned( + user_cls=None, + manager=ConditionalVersioningManager(tracking_predicate=version_privacy_predicate), + plugins=[ + FlaskPlugin( + # Redirect to our own function, which respects user preferences + # on IP address collection + remote_addr_factory=get_ip_if_allowed, + # Suppress the plugin's attempt to grab a user id, + # which imports the flask_login module (causing an error) + current_user_id_factory=lambda: None, + ) + ], +) db = SQLAlchemy() @@ -23,6 +76,9 @@ class Project(db.Model): def get_by_name(self, name): return Project.query.filter(Project.name == name).one() + # Direct SQLAlchemy-Continuum to track changes to this model + __versioned__ = {} + id = db.Column(db.String(64), primary_key=True) name = db.Column(db.UnicodeText) @@ -304,6 +360,9 @@ class Person(db.Model): query_class = PersonQuery + # Direct SQLAlchemy-Continuum to track changes to this model + __versioned__ = {} + id = db.Column(db.Integer, primary_key=True) project_id = db.Column(db.String(64), db.ForeignKey("project.id")) bills = db.relationship("Bill", backref="payer") @@ -340,8 +399,8 @@ class Person(db.Model): # We need to manually define a join table for m2m relations billowers = db.Table( "billowers", - db.Column("bill_id", db.Integer, db.ForeignKey("bill.id")), - db.Column("person_id", db.Integer, db.ForeignKey("person.id")), + db.Column("bill_id", db.Integer, db.ForeignKey("bill.id"), primary_key=True), + db.Column("person_id", db.Integer, db.ForeignKey("person.id"), primary_key=True), ) @@ -368,6 +427,9 @@ class Bill(db.Model): query_class = BillQuery + # Direct SQLAlchemy-Continuum to track changes to this model + __versioned__ = {} + id = db.Column(db.Integer, primary_key=True) payer_id = db.Column(db.Integer, db.ForeignKey("person.id")) @@ -429,3 +491,6 @@ class Archive(db.Model): def __repr__(self): return "" + + +sqlalchemy.orm.configure_mappers() diff --git a/ihatemoney/utils.py b/ihatemoney/utils.py index ad2a40b8..9d6f7929 100644 --- a/ihatemoney/utils.py +++ b/ihatemoney/utils.py @@ -10,9 +10,10 @@ import jinja2 from json import dumps, JSONEncoder from flask import redirect, current_app from babel import Locale +from sqlalchemy_continuum.plugins.flask import fetch_remote_addr from werkzeug.routing import HTTPException, RoutingException from datetime import datetime, timedelta - +from flask import g import csv @@ -284,6 +285,8 @@ class FormEnum(Enum): class LoggingMode(FormEnum): + """Represents a project's history preferences.""" + DISABLED = 0 ENABLED = 1 RECORD_IP = 2 @@ -291,3 +294,15 @@ class LoggingMode(FormEnum): @classmethod def default(cls): return cls.ENABLED + + +def get_ip_if_allowed(): + """ + Get the remote address (IP address) of the current Flask context, if the + project's privacy settings allow it. Behind the scenes, this calls back to + the FlaskPlugin from SQLAlchemy-Continuum in order to maintain forward + compatibility + """ + if g.project and g.project.logging_preference == LoggingMode.RECORD_IP: + return fetch_remote_addr() + return None