From b0a67a5156e75daab413b6a06c17545b57a7410c Mon Sep 17 00:00:00 2001 From: Tom Roussel <21120212+TomRoussel@users.noreply.github.com> Date: Sat, 2 Mar 2024 12:52:54 +0100 Subject: [PATCH] Changed bill_type to an Enum --- ihatemoney/forms.py | 11 ++----- ...b38559992_new_bill_type_attribute_added.py | 3 +- ihatemoney/models.py | 29 ++++++++++--------- ihatemoney/tests/budget_test.py | 2 +- 4 files changed, 21 insertions(+), 24 deletions(-) diff --git a/ihatemoney/forms.py b/ihatemoney/forms.py index 6be91224..9f62688d 100644 --- a/ihatemoney/forms.py +++ b/ihatemoney/forms.py @@ -39,7 +39,7 @@ from wtforms.validators import ( ) from ihatemoney.currency_convertor import CurrencyConverter -from ihatemoney.models import Bill, LoggingMode, Person, Project +from ihatemoney.models import Bill, BillType, LoggingMode, Person, Project from ihatemoney.utils import ( em_surround, eval_arithmetic_expression, @@ -81,7 +81,6 @@ def get_billform_for(project, set_default=True, **kwargs): active_members = [(m.id, m.name) for m in project.active_members] - form.bill_type.choices = project.bill_types form.payed_for.choices = form.payer.choices = active_members form.payed_for.default = [m.id for m in project.active_members] @@ -365,7 +364,7 @@ class BillForm(FlaskForm): payed_for = SelectMultipleField( _("For whom?"), validators=[DataRequired()], coerce=int ) - bill_type = SelectField(_("Bill Type"), validators=[DataRequired()], coerce=str) + bill_type = SelectField(_("Bill Type"), validators=[DataRequired()], choices=BillType.choices()) submit = SubmitField(_("Submit")) submit2 = SubmitField(_("Submit and add a new one")) @@ -386,7 +385,7 @@ class BillForm(FlaskForm): bill.payer_id = self.payer.data bill.amount = self.amount.data bill.what = self.what.data - bill.bill_type = self.bill_type.data + bill.bill_type = BillType(self.bill_type.data) bill.external_link = self.external_link.data bill.date = self.date.data bill.owers = Person.query.get_by_ids(self.payed_for.data, project) @@ -433,10 +432,6 @@ class BillForm(FlaskForm): ) raise ValidationError(msg) - def validate_bill_type(self, field): - if (field.data, field.data) not in Project.bill_types: - raise ValidationError(_("Invalid Bill Type")) - class MemberForm(FlaskForm): name = StringField(_("Name"), validators=[DataRequired()], filters=[strip_filter]) diff --git a/ihatemoney/migrations/versions/7a9b38559992_new_bill_type_attribute_added.py b/ihatemoney/migrations/versions/7a9b38559992_new_bill_type_attribute_added.py index 14c73722..4638aba1 100644 --- a/ihatemoney/migrations/versions/7a9b38559992_new_bill_type_attribute_added.py +++ b/ihatemoney/migrations/versions/7a9b38559992_new_bill_type_attribute_added.py @@ -12,10 +12,11 @@ down_revision = "927ed575acbd" from alembic import op import sqlalchemy as sa +from ihatemoney.models import BillType def upgrade(): - op.add_column("bill", sa.Column("bill_type", sa.UnicodeText())) + op.add_column("bill", sa.Column("bill_type", sa.Enum(BillType))) op.add_column("bill_version", sa.Column("bill_type", sa.UnicodeText())) diff --git a/ihatemoney/models.py b/ihatemoney/models.py index 97a67cb3..a868772d 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -21,7 +21,7 @@ from sqlalchemy_continuum.plugins import FlaskPlugin from ihatemoney.currency_convertor import CurrencyConverter from ihatemoney.monkeypath_continuum import PatchedTransactionFactory -from ihatemoney.utils import generate_password_hash, get_members, same_bill +from ihatemoney.utils import generate_password_hash, get_members, same_bill, FormEnum from ihatemoney.versioning import ( ConditionalVersioningManager, LoggingMode, @@ -50,6 +50,12 @@ make_versioned( ], ) +class BillType(FormEnum): + EXPENSE = "Expense" + REIMBURSEMENT = "Reimbursement" + TRANSFER = "Transfer" + + db = SQLAlchemy() @@ -76,11 +82,6 @@ class Project(db.Model): query_class = ProjectQuery default_currency = db.Column(db.String(3)) - bill_types = [ - ("Expense", "Expense"), - ("Reimbursement", "Reimbursement"), - ("Transfer", "Transfer"), - ] @property def _to_serialize(self): @@ -123,12 +124,12 @@ class Project(db.Model): for bill in self.get_bills_unordered().all(): total_weight = sum(ower.weight for ower in bill.owers) - if bill.bill_type == "Expense": + if bill.bill_type == BillType.EXPENSE: should_receive[bill.payer.id] += bill.converted_amount for ower in bill.owers: should_pay[ower.id] += (ower.weight * bill.converted_amount / total_weight) - if bill.bill_type == "Reimbursement": + if bill.bill_type == BillType.REIMBURSEMENT: should_receive[bill.payer.id] += bill.converted_amount for ower in bill.owers: should_receive[ower.id] -= bill.converted_amount @@ -174,7 +175,7 @@ class Project(db.Model): """ monthly = defaultdict(lambda: defaultdict(float)) for bill in self.get_bills_unordered().all(): - if bill.bill_type == "Expense": + if bill.bill_type == BillType.EXPENSE: monthly[bill.date.year][bill.date.month] += bill.converted_amount return monthly @@ -351,7 +352,7 @@ class Project(db.Model): pretty_bills.append( { "what": bill.what, - "bill_type": bill.bill_type, + "bill_type": bill.bill_type.value, "amount": round(bill.amount, 2), "currency": bill.original_currency, "date": str(bill.date), @@ -695,7 +696,7 @@ class Bill(db.Model): date = db.Column(db.Date, default=datetime.datetime.now) creation_date = db.Column(db.Date, default=datetime.datetime.now) what = db.Column(db.UnicodeText) - bill_type = db.Column(db.UnicodeText) + bill_type = db.Column(db.Enum(BillType)) external_link = db.Column(db.UnicodeText) original_currency = db.Column(db.String(3)) @@ -715,7 +716,7 @@ class Bill(db.Model): payer_id: int = None, project_default_currency: str = "", what: str = "", - bill_type: str = "", + bill_type: str = "Expense", ): super().__init__() self.amount = amount @@ -725,7 +726,7 @@ class Bill(db.Model): self.owers = owers self.payer_id = payer_id self.what = what - self.bill_type = bill_type + self.bill_type = BillType(bill_type) self.converted_amount = self.currency_helper.exchange_currency( self.amount, self.original_currency, project_default_currency ) @@ -740,7 +741,7 @@ class Bill(db.Model): "date": self.date, "creation_date": self.creation_date, "what": self.what, - "bill_type": self.bill_type, + "bill_type": self.bill_type.value, "external_link": self.external_link, "original_currency": self.original_currency, "converted_amount": self.converted_amount, diff --git a/ihatemoney/tests/budget_test.py b/ihatemoney/tests/budget_test.py index 3fa4d5ef..9ddeec8a 100644 --- a/ihatemoney/tests/budget_test.py +++ b/ihatemoney/tests/budget_test.py @@ -1493,7 +1493,7 @@ class TestBudget(IhatemoneyTestCase): #test if theres a new one with bill_type reimbursement bill = project.get_newest_bill() - assert bill.bill_type == "Reimbursement" + assert bill.bill_type == models.BillType.REIMBURSEMENT return def test_settle_zero(self):