Integrate SQLAlchemy-Continuum to support version tracking

This commit is contained in:
Andrew Dickinson 2020-04-11 13:46:55 -04:00
parent edaab5501b
commit 5ebcc8bd51
3 changed files with 296 additions and 4 deletions

View file

@ -0,0 +1,212 @@
"""autologger
Revision ID: 2dcb0c0048dc
Revises: 6c6fb2b7f229
Create Date: 2020-04-10 18:12:41.285590
"""
# revision identifiers, used by Alembic.
revision = "2dcb0c0048dc"
down_revision = "6c6fb2b7f229"
from alembic import op
import sqlalchemy as sa
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"bill_version",
sa.Column("id", sa.Integer(), autoincrement=False, nullable=False),
sa.Column("payer_id", sa.Integer(), autoincrement=False, nullable=True),
sa.Column("amount", sa.Float(), autoincrement=False, nullable=True),
sa.Column("date", sa.Date(), autoincrement=False, nullable=True),
sa.Column("creation_date", sa.Date(), autoincrement=False, nullable=True),
sa.Column("what", sa.UnicodeText(), autoincrement=False, nullable=True),
sa.Column(
"external_link", sa.UnicodeText(), autoincrement=False, nullable=True
),
sa.Column("archive", sa.Integer(), autoincrement=False, nullable=True),
sa.Column(
"transaction_id", sa.BigInteger(), autoincrement=False, nullable=False
),
sa.Column("end_transaction_id", sa.BigInteger(), nullable=True),
sa.Column("operation_type", sa.SmallInteger(), nullable=False),
sa.PrimaryKeyConstraint("id", "transaction_id"),
)
op.create_index(
op.f("ix_bill_version_end_transaction_id"),
"bill_version",
["end_transaction_id"],
unique=False,
)
op.create_index(
op.f("ix_bill_version_operation_type"),
"bill_version",
["operation_type"],
unique=False,
)
op.create_index(
op.f("ix_bill_version_transaction_id"),
"bill_version",
["transaction_id"],
unique=False,
)
op.create_table(
"billowers_version",
sa.Column("bill_id", sa.Integer(), autoincrement=False, nullable=False),
sa.Column("person_id", sa.Integer(), autoincrement=False, nullable=False),
sa.Column(
"transaction_id", sa.BigInteger(), autoincrement=False, nullable=False
),
sa.Column("end_transaction_id", sa.BigInteger(), nullable=True),
sa.Column("operation_type", sa.SmallInteger(), nullable=False),
sa.PrimaryKeyConstraint("bill_id", "person_id", "transaction_id"),
)
op.create_index(
op.f("ix_billowers_version_end_transaction_id"),
"billowers_version",
["end_transaction_id"],
unique=False,
)
op.create_index(
op.f("ix_billowers_version_operation_type"),
"billowers_version",
["operation_type"],
unique=False,
)
op.create_index(
op.f("ix_billowers_version_transaction_id"),
"billowers_version",
["transaction_id"],
unique=False,
)
op.create_table(
"person_version",
sa.Column("id", sa.Integer(), autoincrement=False, nullable=False),
sa.Column(
"project_id", sa.String(length=64), autoincrement=False, nullable=True
),
sa.Column("name", sa.UnicodeText(), autoincrement=False, nullable=True),
sa.Column("weight", sa.Float(), autoincrement=False, nullable=True),
sa.Column("activated", sa.Boolean(), autoincrement=False, nullable=True),
sa.Column(
"transaction_id", sa.BigInteger(), autoincrement=False, nullable=False
),
sa.Column("end_transaction_id", sa.BigInteger(), nullable=True),
sa.Column("operation_type", sa.SmallInteger(), nullable=False),
sa.PrimaryKeyConstraint("id", "transaction_id"),
)
op.create_index(
op.f("ix_person_version_end_transaction_id"),
"person_version",
["end_transaction_id"],
unique=False,
)
op.create_index(
op.f("ix_person_version_operation_type"),
"person_version",
["operation_type"],
unique=False,
)
op.create_index(
op.f("ix_person_version_transaction_id"),
"person_version",
["transaction_id"],
unique=False,
)
op.create_table(
"project_version",
sa.Column("id", sa.String(length=64), autoincrement=False, nullable=False),
sa.Column("name", sa.UnicodeText(), autoincrement=False, nullable=True),
sa.Column(
"password", sa.String(length=128), autoincrement=False, nullable=True
),
sa.Column(
"contact_email", sa.String(length=128), autoincrement=False, nullable=True
),
sa.Column(
"logging_preference",
sa.Enum("DISABLED", "ENABLED", "RECORD_IP", name="loggingmode"),
autoincrement=False,
nullable=True,
),
sa.Column(
"transaction_id", sa.BigInteger(), autoincrement=False, nullable=False
),
sa.Column("end_transaction_id", sa.BigInteger(), nullable=True),
sa.Column("operation_type", sa.SmallInteger(), nullable=False),
sa.PrimaryKeyConstraint("id", "transaction_id"),
)
op.create_index(
op.f("ix_project_version_end_transaction_id"),
"project_version",
["end_transaction_id"],
unique=False,
)
op.create_index(
op.f("ix_project_version_operation_type"),
"project_version",
["operation_type"],
unique=False,
)
op.create_index(
op.f("ix_project_version_transaction_id"),
"project_version",
["transaction_id"],
unique=False,
)
op.create_table(
"transaction",
sa.Column("issued_at", sa.DateTime(), nullable=True),
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
sa.Column("remote_addr", sa.String(length=50), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.add_column(
"project",
sa.Column(
"logging_preference",
sa.Enum("DISABLED", "ENABLED", "RECORD_IP", name="loggingmode"),
nullable=True,
),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("project", "logging_preference")
op.drop_table("transaction")
op.drop_index(
op.f("ix_project_version_transaction_id"), table_name="project_version"
)
op.drop_index(
op.f("ix_project_version_operation_type"), table_name="project_version"
)
op.drop_index(
op.f("ix_project_version_end_transaction_id"), table_name="project_version"
)
op.drop_table("project_version")
op.drop_index(op.f("ix_person_version_transaction_id"), table_name="person_version")
op.drop_index(op.f("ix_person_version_operation_type"), table_name="person_version")
op.drop_index(
op.f("ix_person_version_end_transaction_id"), table_name="person_version"
)
op.drop_table("person_version")
op.drop_index(
op.f("ix_billowers_version_transaction_id"), table_name="billowers_version"
)
op.drop_index(
op.f("ix_billowers_version_operation_type"), table_name="billowers_version"
)
op.drop_index(
op.f("ix_billowers_version_end_transaction_id"), table_name="billowers_version"
)
op.drop_table("billowers_version")
op.drop_index(op.f("ix_bill_version_transaction_id"), table_name="bill_version")
op.drop_index(op.f("ix_bill_version_operation_type"), table_name="bill_version")
op.drop_index(op.f("ix_bill_version_end_transaction_id"), table_name="bill_version")
op.drop_table("bill_version")
# ### end Alembic commands ###

View file

@ -1,19 +1,72 @@
from collections import defaultdict
from datetime import datetime
import sqlalchemy
from flask_sqlalchemy import SQLAlchemy, BaseQuery
from flask import g, current_app
from debts import settle
from sqlalchemy import orm
from sqlalchemy.sql import func
from ihatemoney.utils import LoggingMode
from ihatemoney.utils import LoggingMode, get_ip_if_allowed
from itsdangerous import (
TimedJSONWebSignatureSerializer,
URLSafeSerializer,
BadSignature,
SignatureExpired,
)
from sqlalchemy_continuum import make_versioned
from sqlalchemy_continuum import VersioningManager as VersioningManager
from sqlalchemy_continuum.plugins import FlaskPlugin
def version_privacy_predicate():
"""Evaluate if the project of the current session has enabled logging."""
return g.project.logging_preference != LoggingMode.DISABLED
class ConditionalVersioningManager(VersioningManager):
"""Conditionally enable version tracking based on the given predicate."""
def __init__(self, tracking_predicate, *args, **kwargs):
"""Create version entry iff tracking_predicate() returns True."""
super().__init__(*args, **kwargs)
self.tracking_predicate = tracking_predicate
def before_flush(self, session, flush_context, instances):
if self.tracking_predicate():
return super().before_flush(session, flush_context, instances)
else:
# At least one call to unit_of_work() needs to be made against the
# session object to prevent a KeyError later. This doesn't create
# a version or transaction entry
self.unit_of_work(session)
def after_flush(self, session, flush_context):
if self.tracking_predicate():
return super().after_flush(session, flush_context)
else:
# At least one call to unit_of_work() needs to be made against the
# session object to prevent a KeyError later. This doesn't create
# a version or transaction entry
self.unit_of_work(session)
make_versioned(
user_cls=None,
manager=ConditionalVersioningManager(tracking_predicate=version_privacy_predicate),
plugins=[
FlaskPlugin(
# Redirect to our own function, which respects user preferences
# on IP address collection
remote_addr_factory=get_ip_if_allowed,
# Suppress the plugin's attempt to grab a user id,
# which imports the flask_login module (causing an error)
current_user_id_factory=lambda: None,
)
],
)
db = SQLAlchemy()
@ -23,6 +76,9 @@ class Project(db.Model):
def get_by_name(self, name):
return Project.query.filter(Project.name == name).one()
# Direct SQLAlchemy-Continuum to track changes to this model
__versioned__ = {}
id = db.Column(db.String(64), primary_key=True)
name = db.Column(db.UnicodeText)
@ -304,6 +360,9 @@ class Person(db.Model):
query_class = PersonQuery
# Direct SQLAlchemy-Continuum to track changes to this model
__versioned__ = {}
id = db.Column(db.Integer, primary_key=True)
project_id = db.Column(db.String(64), db.ForeignKey("project.id"))
bills = db.relationship("Bill", backref="payer")
@ -340,8 +399,8 @@ class Person(db.Model):
# We need to manually define a join table for m2m relations
billowers = db.Table(
"billowers",
db.Column("bill_id", db.Integer, db.ForeignKey("bill.id")),
db.Column("person_id", db.Integer, db.ForeignKey("person.id")),
db.Column("bill_id", db.Integer, db.ForeignKey("bill.id"), primary_key=True),
db.Column("person_id", db.Integer, db.ForeignKey("person.id"), primary_key=True),
)
@ -368,6 +427,9 @@ class Bill(db.Model):
query_class = BillQuery
# Direct SQLAlchemy-Continuum to track changes to this model
__versioned__ = {}
id = db.Column(db.Integer, primary_key=True)
payer_id = db.Column(db.Integer, db.ForeignKey("person.id"))
@ -429,3 +491,6 @@ class Archive(db.Model):
def __repr__(self):
return "<Archive>"
sqlalchemy.orm.configure_mappers()

View file

@ -10,9 +10,10 @@ import jinja2
from json import dumps, JSONEncoder
from flask import redirect, current_app
from babel import Locale
from sqlalchemy_continuum.plugins.flask import fetch_remote_addr
from werkzeug.routing import HTTPException, RoutingException
from datetime import datetime, timedelta
from flask import g
import csv
@ -284,6 +285,8 @@ class FormEnum(Enum):
class LoggingMode(FormEnum):
"""Represents a project's history preferences."""
DISABLED = 0
ENABLED = 1
RECORD_IP = 2
@ -291,3 +294,15 @@ class LoggingMode(FormEnum):
@classmethod
def default(cls):
return cls.ENABLED
def get_ip_if_allowed():
"""
Get the remote address (IP address) of the current Flask context, if the
project's privacy settings allow it. Behind the scenes, this calls back to
the FlaskPlugin from SQLAlchemy-Continuum in order to maintain forward
compatibility
"""
if g.project and g.project.logging_preference == LoggingMode.RECORD_IP:
return fetch_remote_addr()
return None