Integrate SQLAlchemy-Continuum to support version tracking

This commit is contained in:
Andrew Dickinson 2020-04-11 13:46:55 -04:00
parent edaab5501b
commit 5ebcc8bd51
3 changed files with 296 additions and 4 deletions

View file

@ -0,0 +1,212 @@
"""autologger
Revision ID: 2dcb0c0048dc
Revises: 6c6fb2b7f229
Create Date: 2020-04-10 18:12:41.285590
"""
# revision identifiers, used by Alembic.
revision = "2dcb0c0048dc"
down_revision = "6c6fb2b7f229"
from alembic import op
import sqlalchemy as sa
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"bill_version",
sa.Column("id", sa.Integer(), autoincrement=False, nullable=False),
sa.Column("payer_id", sa.Integer(), autoincrement=False, nullable=True),
sa.Column("amount", sa.Float(), autoincrement=False, nullable=True),
sa.Column("date", sa.Date(), autoincrement=False, nullable=True),
sa.Column("creation_date", sa.Date(), autoincrement=False, nullable=True),
sa.Column("what", sa.UnicodeText(), autoincrement=False, nullable=True),
sa.Column(
"external_link", sa.UnicodeText(), autoincrement=False, nullable=True
),
sa.Column("archive", sa.Integer(), autoincrement=False, nullable=True),
sa.Column(
"transaction_id", sa.BigInteger(), autoincrement=False, nullable=False
),
sa.Column("end_transaction_id", sa.BigInteger(), nullable=True),
sa.Column("operation_type", sa.SmallInteger(), nullable=False),
sa.PrimaryKeyConstraint("id", "transaction_id"),
)
op.create_index(
op.f("ix_bill_version_end_transaction_id"),
"bill_version",
["end_transaction_id"],
unique=False,
)
op.create_index(
op.f("ix_bill_version_operation_type"),
"bill_version",
["operation_type"],
unique=False,
)
op.create_index(
op.f("ix_bill_version_transaction_id"),
"bill_version",
["transaction_id"],
unique=False,
)
op.create_table(
"billowers_version",
sa.Column("bill_id", sa.Integer(), autoincrement=False, nullable=False),
sa.Column("person_id", sa.Integer(), autoincrement=False, nullable=False),
sa.Column(
"transaction_id", sa.BigInteger(), autoincrement=False, nullable=False
),
sa.Column("end_transaction_id", sa.BigInteger(), nullable=True),
sa.Column("operation_type", sa.SmallInteger(), nullable=False),
sa.PrimaryKeyConstraint("bill_id", "person_id", "transaction_id"),
)
op.create_index(
op.f("ix_billowers_version_end_transaction_id"),
"billowers_version",
["end_transaction_id"],
unique=False,
)
op.create_index(
op.f("ix_billowers_version_operation_type"),
"billowers_version",
["operation_type"],
unique=False,
)
op.create_index(
op.f("ix_billowers_version_transaction_id"),
"billowers_version",
["transaction_id"],
unique=False,
)
op.create_table(
"person_version",
sa.Column("id", sa.Integer(), autoincrement=False, nullable=False),
sa.Column(
"project_id", sa.String(length=64), autoincrement=False, nullable=True
),
sa.Column("name", sa.UnicodeText(), autoincrement=False, nullable=True),
sa.Column("weight", sa.Float(), autoincrement=False, nullable=True),
sa.Column("activated", sa.Boolean(), autoincrement=False, nullable=True),
sa.Column(
"transaction_id", sa.BigInteger(), autoincrement=False, nullable=False
),
sa.Column("end_transaction_id", sa.BigInteger(), nullable=True),
sa.Column("operation_type", sa.SmallInteger(), nullable=False),
sa.PrimaryKeyConstraint("id", "transaction_id"),
)
op.create_index(
op.f("ix_person_version_end_transaction_id"),
"person_version",
["end_transaction_id"],
unique=False,
)
op.create_index(
op.f("ix_person_version_operation_type"),
"person_version",
["operation_type"],
unique=False,
)
op.create_index(
op.f("ix_person_version_transaction_id"),
"person_version",
["transaction_id"],
unique=False,
)
op.create_table(
"project_version",
sa.Column("id", sa.String(length=64), autoincrement=False, nullable=False),
sa.Column("name", sa.UnicodeText(), autoincrement=False, nullable=True),
sa.Column(
"password", sa.String(length=128), autoincrement=False, nullable=True
),
sa.Column(
"contact_email", sa.String(length=128), autoincrement=False, nullable=True
),
sa.Column(
"logging_preference",
sa.Enum("DISABLED", "ENABLED", "RECORD_IP", name="loggingmode"),
autoincrement=False,
nullable=True,
),
sa.Column(
"transaction_id", sa.BigInteger(), autoincrement=False, nullable=False
),
sa.Column("end_transaction_id", sa.BigInteger(), nullable=True),
sa.Column("operation_type", sa.SmallInteger(), nullable=False),
sa.PrimaryKeyConstraint("id", "transaction_id"),
)
op.create_index(
op.f("ix_project_version_end_transaction_id"),
"project_version",
["end_transaction_id"],
unique=False,
)
op.create_index(
op.f("ix_project_version_operation_type"),
"project_version",
["operation_type"],
unique=False,
)
op.create_index(
op.f("ix_project_version_transaction_id"),
"project_version",
["transaction_id"],
unique=False,
)
op.create_table(
"transaction",
sa.Column("issued_at", sa.DateTime(), nullable=True),
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
sa.Column("remote_addr", sa.String(length=50), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.add_column(
"project",
sa.Column(
"logging_preference",
sa.Enum("DISABLED", "ENABLED", "RECORD_IP", name="loggingmode"),
nullable=True,
),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("project", "logging_preference")
op.drop_table("transaction")
op.drop_index(
op.f("ix_project_version_transaction_id"), table_name="project_version"
)
op.drop_index(
op.f("ix_project_version_operation_type"), table_name="project_version"
)
op.drop_index(
op.f("ix_project_version_end_transaction_id"), table_name="project_version"
)
op.drop_table("project_version")
op.drop_index(op.f("ix_person_version_transaction_id"), table_name="person_version")
op.drop_index(op.f("ix_person_version_operation_type"), table_name="person_version")
op.drop_index(
op.f("ix_person_version_end_transaction_id"), table_name="person_version"
)
op.drop_table("person_version")
op.drop_index(
op.f("ix_billowers_version_transaction_id"), table_name="billowers_version"
)
op.drop_index(
op.f("ix_billowers_version_operation_type"), table_name="billowers_version"
)
op.drop_index(
op.f("ix_billowers_version_end_transaction_id"), table_name="billowers_version"
)
op.drop_table("billowers_version")
op.drop_index(op.f("ix_bill_version_transaction_id"), table_name="bill_version")
op.drop_index(op.f("ix_bill_version_operation_type"), table_name="bill_version")
op.drop_index(op.f("ix_bill_version_end_transaction_id"), table_name="bill_version")
op.drop_table("bill_version")
# ### end Alembic commands ###

View file

@ -1,19 +1,72 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
import sqlalchemy
from flask_sqlalchemy import SQLAlchemy, BaseQuery from flask_sqlalchemy import SQLAlchemy, BaseQuery
from flask import g, current_app from flask import g, current_app
from debts import settle from debts import settle
from sqlalchemy import orm from sqlalchemy import orm
from sqlalchemy.sql import func from sqlalchemy.sql import func
from ihatemoney.utils import LoggingMode from ihatemoney.utils import LoggingMode, get_ip_if_allowed
from itsdangerous import ( from itsdangerous import (
TimedJSONWebSignatureSerializer, TimedJSONWebSignatureSerializer,
URLSafeSerializer, URLSafeSerializer,
BadSignature, BadSignature,
SignatureExpired, SignatureExpired,
) )
from sqlalchemy_continuum import make_versioned
from sqlalchemy_continuum import VersioningManager as VersioningManager
from sqlalchemy_continuum.plugins import FlaskPlugin
def version_privacy_predicate():
"""Evaluate if the project of the current session has enabled logging."""
return g.project.logging_preference != LoggingMode.DISABLED
class ConditionalVersioningManager(VersioningManager):
"""Conditionally enable version tracking based on the given predicate."""
def __init__(self, tracking_predicate, *args, **kwargs):
"""Create version entry iff tracking_predicate() returns True."""
super().__init__(*args, **kwargs)
self.tracking_predicate = tracking_predicate
def before_flush(self, session, flush_context, instances):
if self.tracking_predicate():
return super().before_flush(session, flush_context, instances)
else:
# At least one call to unit_of_work() needs to be made against the
# session object to prevent a KeyError later. This doesn't create
# a version or transaction entry
self.unit_of_work(session)
def after_flush(self, session, flush_context):
if self.tracking_predicate():
return super().after_flush(session, flush_context)
else:
# At least one call to unit_of_work() needs to be made against the
# session object to prevent a KeyError later. This doesn't create
# a version or transaction entry
self.unit_of_work(session)
make_versioned(
user_cls=None,
manager=ConditionalVersioningManager(tracking_predicate=version_privacy_predicate),
plugins=[
FlaskPlugin(
# Redirect to our own function, which respects user preferences
# on IP address collection
remote_addr_factory=get_ip_if_allowed,
# Suppress the plugin's attempt to grab a user id,
# which imports the flask_login module (causing an error)
current_user_id_factory=lambda: None,
)
],
)
db = SQLAlchemy() db = SQLAlchemy()
@ -23,6 +76,9 @@ class Project(db.Model):
def get_by_name(self, name): def get_by_name(self, name):
return Project.query.filter(Project.name == name).one() return Project.query.filter(Project.name == name).one()
# Direct SQLAlchemy-Continuum to track changes to this model
__versioned__ = {}
id = db.Column(db.String(64), primary_key=True) id = db.Column(db.String(64), primary_key=True)
name = db.Column(db.UnicodeText) name = db.Column(db.UnicodeText)
@ -304,6 +360,9 @@ class Person(db.Model):
query_class = PersonQuery query_class = PersonQuery
# Direct SQLAlchemy-Continuum to track changes to this model
__versioned__ = {}
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
project_id = db.Column(db.String(64), db.ForeignKey("project.id")) project_id = db.Column(db.String(64), db.ForeignKey("project.id"))
bills = db.relationship("Bill", backref="payer") bills = db.relationship("Bill", backref="payer")
@ -340,8 +399,8 @@ class Person(db.Model):
# We need to manually define a join table for m2m relations # We need to manually define a join table for m2m relations
billowers = db.Table( billowers = db.Table(
"billowers", "billowers",
db.Column("bill_id", db.Integer, db.ForeignKey("bill.id")), db.Column("bill_id", db.Integer, db.ForeignKey("bill.id"), primary_key=True),
db.Column("person_id", db.Integer, db.ForeignKey("person.id")), db.Column("person_id", db.Integer, db.ForeignKey("person.id"), primary_key=True),
) )
@ -368,6 +427,9 @@ class Bill(db.Model):
query_class = BillQuery query_class = BillQuery
# Direct SQLAlchemy-Continuum to track changes to this model
__versioned__ = {}
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
payer_id = db.Column(db.Integer, db.ForeignKey("person.id")) payer_id = db.Column(db.Integer, db.ForeignKey("person.id"))
@ -429,3 +491,6 @@ class Archive(db.Model):
def __repr__(self): def __repr__(self):
return "<Archive>" return "<Archive>"
sqlalchemy.orm.configure_mappers()

View file

@ -10,9 +10,10 @@ import jinja2
from json import dumps, JSONEncoder from json import dumps, JSONEncoder
from flask import redirect, current_app from flask import redirect, current_app
from babel import Locale from babel import Locale
from sqlalchemy_continuum.plugins.flask import fetch_remote_addr
from werkzeug.routing import HTTPException, RoutingException from werkzeug.routing import HTTPException, RoutingException
from datetime import datetime, timedelta from datetime import datetime, timedelta
from flask import g
import csv import csv
@ -284,6 +285,8 @@ class FormEnum(Enum):
class LoggingMode(FormEnum): class LoggingMode(FormEnum):
"""Represents a project's history preferences."""
DISABLED = 0 DISABLED = 0
ENABLED = 1 ENABLED = 1
RECORD_IP = 2 RECORD_IP = 2
@ -291,3 +294,15 @@ class LoggingMode(FormEnum):
@classmethod @classmethod
def default(cls): def default(cls):
return cls.ENABLED return cls.ENABLED
def get_ip_if_allowed():
"""
Get the remote address (IP address) of the current Flask context, if the
project's privacy settings allow it. Behind the scenes, this calls back to
the FlaskPlugin from SQLAlchemy-Continuum in order to maintain forward
compatibility
"""
if g.project and g.project.logging_preference == LoggingMode.RECORD_IP:
return fetch_remote_addr()
return None