From 9b3dfa506ccf775a9a0927aed0e3e5f466dc4352 Mon Sep 17 00:00:00 2001 From: Andrew Dickinson Date: Sun, 12 Apr 2020 21:03:40 -0400 Subject: [PATCH] Refactor versioning classes to avoid circular dependencies --- ihatemoney/forms.py | 4 +- ihatemoney/models.py | 44 +++++-------------- ihatemoney/utils.py | 26 ----------- ihatemoney/versioning.py | 94 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 107 insertions(+), 61 deletions(-) create mode 100644 ihatemoney/versioning.py diff --git a/ihatemoney/forms.py b/ihatemoney/forms.py index f3aa6162..9d4b188c 100644 --- a/ihatemoney/forms.py +++ b/ihatemoney/forms.py @@ -22,8 +22,8 @@ from jinja2 import Markup import email_validator -from ihatemoney.models import Project, Person -from ihatemoney.utils import slugify, eval_arithmetic_expression, LoggingMode +from ihatemoney.models import Project, Person, LoggingMode +from ihatemoney.utils import slugify, eval_arithmetic_expression def strip_filter(string): diff --git a/ihatemoney/models.py b/ihatemoney/models.py index 25317d80..a0475bdb 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -9,7 +9,6 @@ from flask import g, current_app from debts import settle from sqlalchemy import orm from sqlalchemy.sql import func -from ihatemoney.utils import LoggingMode, get_ip_if_allowed from itsdangerous import ( TimedJSONWebSignatureSerializer, URLSafeSerializer, @@ -17,40 +16,15 @@ from itsdangerous import ( SignatureExpired, ) from sqlalchemy_continuum import make_versioned -from sqlalchemy_continuum import VersioningManager as VersioningManager from sqlalchemy_continuum.plugins import FlaskPlugin +from sqlalchemy_continuum import version_class - -def version_privacy_predicate(): - """Evaluate if the project of the current session has enabled logging.""" - return g.project.logging_preference != LoggingMode.DISABLED - - -class ConditionalVersioningManager(VersioningManager): - """Conditionally enable version tracking based on the given predicate.""" - - def __init__(self, tracking_predicate, *args, **kwargs): - """Create version entry iff tracking_predicate() returns True.""" - super().__init__(*args, **kwargs) - self.tracking_predicate = tracking_predicate - - def before_flush(self, session, flush_context, instances): - if self.tracking_predicate(): - return super().before_flush(session, flush_context, instances) - else: - # At least one call to unit_of_work() needs to be made against the - # session object to prevent a KeyError later. This doesn't create - # a version or transaction entry - self.unit_of_work(session) - - def after_flush(self, session, flush_context): - if self.tracking_predicate(): - return super().after_flush(session, flush_context) - else: - # At least one call to unit_of_work() needs to be made against the - # session object to prevent a KeyError later. This doesn't create - # a version or transaction entry - self.unit_of_work(session) +from ihatemoney.versioning import ( + LoggingMode, + ConditionalVersioningManager, + version_privacy_predicate, + get_ip_if_allowed, +) make_versioned( @@ -500,3 +474,7 @@ class Archive(db.Model): sqlalchemy.orm.configure_mappers() + +PersonVersion = version_class(Person) +ProjectVersion = version_class(Project) +BillVersion = version_class(Bill) diff --git a/ihatemoney/utils.py b/ihatemoney/utils.py index 9d6f7929..0641d1c4 100644 --- a/ihatemoney/utils.py +++ b/ihatemoney/utils.py @@ -10,10 +10,8 @@ import jinja2 from json import dumps, JSONEncoder from flask import redirect, current_app from babel import Locale -from sqlalchemy_continuum.plugins.flask import fetch_remote_addr from werkzeug.routing import HTTPException, RoutingException from datetime import datetime, timedelta -from flask import g import csv @@ -282,27 +280,3 @@ class FormEnum(Enum): def __str__(self): return str(self.value) - - -class LoggingMode(FormEnum): - """Represents a project's history preferences.""" - - DISABLED = 0 - ENABLED = 1 - RECORD_IP = 2 - - @classmethod - def default(cls): - return cls.ENABLED - - -def get_ip_if_allowed(): - """ - Get the remote address (IP address) of the current Flask context, if the - project's privacy settings allow it. Behind the scenes, this calls back to - the FlaskPlugin from SQLAlchemy-Continuum in order to maintain forward - compatibility - """ - if g.project and g.project.logging_preference == LoggingMode.RECORD_IP: - return fetch_remote_addr() - return None diff --git a/ihatemoney/versioning.py b/ihatemoney/versioning.py new file mode 100644 index 00000000..50ad6ec8 --- /dev/null +++ b/ihatemoney/versioning.py @@ -0,0 +1,94 @@ +from flask import g +from sqlalchemy.orm.attributes import get_history +from sqlalchemy_continuum import VersioningManager +from sqlalchemy_continuum.plugins.flask import fetch_remote_addr + +from ihatemoney.utils import FormEnum + + +class LoggingMode(FormEnum): + """Represents a project's history preferences.""" + + DISABLED = 0 + ENABLED = 1 + RECORD_IP = 2 + + @classmethod + def default(cls): + return cls.ENABLED + + +class ConditionalVersioningManager(VersioningManager): + """Conditionally enable version tracking based on the given predicate.""" + + def __init__(self, tracking_predicate, *args, **kwargs): + """Create version entry iff tracking_predicate() returns True.""" + super().__init__(*args, **kwargs) + self.tracking_predicate = tracking_predicate + + def before_flush(self, session, flush_context, instances): + if self.tracking_predicate(): + return super().before_flush(session, flush_context, instances) + else: + # At least one call to unit_of_work() needs to be made against the + # session object to prevent a KeyError later. This doesn't create + # a version or transaction entry + self.unit_of_work(session) + + def after_flush(self, session, flush_context): + if self.tracking_predicate(): + return super().after_flush(session, flush_context) + else: + # At least one call to unit_of_work() needs to be made against the + # session object to prevent a KeyError later. This doesn't create + # a version or transaction entry + self.unit_of_work(session) + + +def version_privacy_predicate(): + """Evaluate if the project of the current session has enabled logging.""" + logging_enabled = False + try: + if g.project.logging_preference != LoggingMode.DISABLED: + logging_enabled = True + + # If logging WAS enabled prior to this transaction, + # we log this one last transaction + old_logging_mode = get_history(g.project, "logging_preference")[2] + if old_logging_mode and old_logging_mode[0] != LoggingMode.DISABLED: + logging_enabled = True + except AttributeError: + # g.project doesn't exist, it's being created or this action is outside + # the scope of a project. Use the default logging mode to decide + if LoggingMode.default() != LoggingMode.DISABLED: + logging_enabled = True + return logging_enabled + + +def get_ip_if_allowed(): + """ + Get the remote address (IP address) of the current Flask context, if the + project's privacy settings allow it. Behind the scenes, this calls back to + the FlaskPlugin from SQLAlchemy-Continuum in order to maintain forward + compatibility + """ + ip_logging_allowed = False + try: + if g.project.logging_preference == LoggingMode.RECORD_IP: + ip_logging_allowed = True + + # If ip recording WAS enabled prior to this transaction, + # we record the IP for this one last transaction + old_logging_mode = get_history(g.project, "logging_preference")[2] + if old_logging_mode and old_logging_mode[0] == LoggingMode.RECORD_IP: + ip_logging_allowed = True + except AttributeError: + # g.project doesn't exist, it's being created or this action is outside + # the scope of a project. Use the default logging mode to decide + if LoggingMode.default() == LoggingMode.RECORD_IP: + ip_logging_allowed = True + + if ip_logging_allowed: + return fetch_remote_addr() + else: + return None