Refactor versioning classes to avoid circular dependencies

This commit is contained in:
Andrew Dickinson 2020-04-12 21:03:40 -04:00
parent ade3e26b72
commit 9b3dfa506c
4 changed files with 107 additions and 61 deletions

View file

@ -22,8 +22,8 @@ from jinja2 import Markup
import email_validator import email_validator
from ihatemoney.models import Project, Person from ihatemoney.models import Project, Person, LoggingMode
from ihatemoney.utils import slugify, eval_arithmetic_expression, LoggingMode from ihatemoney.utils import slugify, eval_arithmetic_expression
def strip_filter(string): def strip_filter(string):

View file

@ -9,7 +9,6 @@ from flask import g, current_app
from debts import settle from debts import settle
from sqlalchemy import orm from sqlalchemy import orm
from sqlalchemy.sql import func from sqlalchemy.sql import func
from ihatemoney.utils import LoggingMode, get_ip_if_allowed
from itsdangerous import ( from itsdangerous import (
TimedJSONWebSignatureSerializer, TimedJSONWebSignatureSerializer,
URLSafeSerializer, URLSafeSerializer,
@ -17,40 +16,15 @@ from itsdangerous import (
SignatureExpired, SignatureExpired,
) )
from sqlalchemy_continuum import make_versioned from sqlalchemy_continuum import make_versioned
from sqlalchemy_continuum import VersioningManager as VersioningManager
from sqlalchemy_continuum.plugins import FlaskPlugin from sqlalchemy_continuum.plugins import FlaskPlugin
from sqlalchemy_continuum import version_class
from ihatemoney.versioning import (
def version_privacy_predicate(): LoggingMode,
"""Evaluate if the project of the current session has enabled logging.""" ConditionalVersioningManager,
return g.project.logging_preference != LoggingMode.DISABLED version_privacy_predicate,
get_ip_if_allowed,
)
class ConditionalVersioningManager(VersioningManager):
"""Conditionally enable version tracking based on the given predicate."""
def __init__(self, tracking_predicate, *args, **kwargs):
"""Create version entry iff tracking_predicate() returns True."""
super().__init__(*args, **kwargs)
self.tracking_predicate = tracking_predicate
def before_flush(self, session, flush_context, instances):
if self.tracking_predicate():
return super().before_flush(session, flush_context, instances)
else:
# At least one call to unit_of_work() needs to be made against the
# session object to prevent a KeyError later. This doesn't create
# a version or transaction entry
self.unit_of_work(session)
def after_flush(self, session, flush_context):
if self.tracking_predicate():
return super().after_flush(session, flush_context)
else:
# At least one call to unit_of_work() needs to be made against the
# session object to prevent a KeyError later. This doesn't create
# a version or transaction entry
self.unit_of_work(session)
make_versioned( make_versioned(
@ -500,3 +474,7 @@ class Archive(db.Model):
sqlalchemy.orm.configure_mappers() sqlalchemy.orm.configure_mappers()
PersonVersion = version_class(Person)
ProjectVersion = version_class(Project)
BillVersion = version_class(Bill)

View file

@ -10,10 +10,8 @@ import jinja2
from json import dumps, JSONEncoder from json import dumps, JSONEncoder
from flask import redirect, current_app from flask import redirect, current_app
from babel import Locale from babel import Locale
from sqlalchemy_continuum.plugins.flask import fetch_remote_addr
from werkzeug.routing import HTTPException, RoutingException from werkzeug.routing import HTTPException, RoutingException
from datetime import datetime, timedelta from datetime import datetime, timedelta
from flask import g
import csv import csv
@ -282,27 +280,3 @@ class FormEnum(Enum):
def __str__(self): def __str__(self):
return str(self.value) return str(self.value)
class LoggingMode(FormEnum):
"""Represents a project's history preferences."""
DISABLED = 0
ENABLED = 1
RECORD_IP = 2
@classmethod
def default(cls):
return cls.ENABLED
def get_ip_if_allowed():
"""
Get the remote address (IP address) of the current Flask context, if the
project's privacy settings allow it. Behind the scenes, this calls back to
the FlaskPlugin from SQLAlchemy-Continuum in order to maintain forward
compatibility
"""
if g.project and g.project.logging_preference == LoggingMode.RECORD_IP:
return fetch_remote_addr()
return None

94
ihatemoney/versioning.py Normal file
View file

@ -0,0 +1,94 @@
from flask import g
from sqlalchemy.orm.attributes import get_history
from sqlalchemy_continuum import VersioningManager
from sqlalchemy_continuum.plugins.flask import fetch_remote_addr
from ihatemoney.utils import FormEnum
class LoggingMode(FormEnum):
"""Represents a project's history preferences."""
DISABLED = 0
ENABLED = 1
RECORD_IP = 2
@classmethod
def default(cls):
return cls.ENABLED
class ConditionalVersioningManager(VersioningManager):
"""Conditionally enable version tracking based on the given predicate."""
def __init__(self, tracking_predicate, *args, **kwargs):
"""Create version entry iff tracking_predicate() returns True."""
super().__init__(*args, **kwargs)
self.tracking_predicate = tracking_predicate
def before_flush(self, session, flush_context, instances):
if self.tracking_predicate():
return super().before_flush(session, flush_context, instances)
else:
# At least one call to unit_of_work() needs to be made against the
# session object to prevent a KeyError later. This doesn't create
# a version or transaction entry
self.unit_of_work(session)
def after_flush(self, session, flush_context):
if self.tracking_predicate():
return super().after_flush(session, flush_context)
else:
# At least one call to unit_of_work() needs to be made against the
# session object to prevent a KeyError later. This doesn't create
# a version or transaction entry
self.unit_of_work(session)
def version_privacy_predicate():
"""Evaluate if the project of the current session has enabled logging."""
logging_enabled = False
try:
if g.project.logging_preference != LoggingMode.DISABLED:
logging_enabled = True
# If logging WAS enabled prior to this transaction,
# we log this one last transaction
old_logging_mode = get_history(g.project, "logging_preference")[2]
if old_logging_mode and old_logging_mode[0] != LoggingMode.DISABLED:
logging_enabled = True
except AttributeError:
# g.project doesn't exist, it's being created or this action is outside
# the scope of a project. Use the default logging mode to decide
if LoggingMode.default() != LoggingMode.DISABLED:
logging_enabled = True
return logging_enabled
def get_ip_if_allowed():
"""
Get the remote address (IP address) of the current Flask context, if the
project's privacy settings allow it. Behind the scenes, this calls back to
the FlaskPlugin from SQLAlchemy-Continuum in order to maintain forward
compatibility
"""
ip_logging_allowed = False
try:
if g.project.logging_preference == LoggingMode.RECORD_IP:
ip_logging_allowed = True
# If ip recording WAS enabled prior to this transaction,
# we record the IP for this one last transaction
old_logging_mode = get_history(g.project, "logging_preference")[2]
if old_logging_mode and old_logging_mode[0] == LoggingMode.RECORD_IP:
ip_logging_allowed = True
except AttributeError:
# g.project doesn't exist, it's being created or this action is outside
# the scope of a project. Use the default logging mode to decide
if LoggingMode.default() == LoggingMode.RECORD_IP:
ip_logging_allowed = True
if ip_logging_allowed:
return fetch_remote_addr()
else:
return None