Changed bill_type to an Enum

This commit is contained in:
Tom Roussel 2024-03-02 12:52:54 +01:00
parent 8cef9492d0
commit b0a67a5156
4 changed files with 21 additions and 24 deletions

View file

@ -39,7 +39,7 @@ from wtforms.validators import (
)
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.models import Bill, LoggingMode, Person, Project
from ihatemoney.models import Bill, BillType, LoggingMode, Person, Project
from ihatemoney.utils import (
em_surround,
eval_arithmetic_expression,
@ -81,7 +81,6 @@ def get_billform_for(project, set_default=True, **kwargs):
active_members = [(m.id, m.name) for m in project.active_members]
form.bill_type.choices = project.bill_types
form.payed_for.choices = form.payer.choices = active_members
form.payed_for.default = [m.id for m in project.active_members]
@ -365,7 +364,7 @@ class BillForm(FlaskForm):
payed_for = SelectMultipleField(
_("For whom?"), validators=[DataRequired()], coerce=int
)
bill_type = SelectField(_("Bill Type"), validators=[DataRequired()], coerce=str)
bill_type = SelectField(_("Bill Type"), validators=[DataRequired()], choices=BillType.choices())
submit = SubmitField(_("Submit"))
submit2 = SubmitField(_("Submit and add a new one"))
@ -386,7 +385,7 @@ class BillForm(FlaskForm):
bill.payer_id = self.payer.data
bill.amount = self.amount.data
bill.what = self.what.data
bill.bill_type = self.bill_type.data
bill.bill_type = BillType(self.bill_type.data)
bill.external_link = self.external_link.data
bill.date = self.date.data
bill.owers = Person.query.get_by_ids(self.payed_for.data, project)
@ -433,10 +432,6 @@ class BillForm(FlaskForm):
)
raise ValidationError(msg)
def validate_bill_type(self, field):
if (field.data, field.data) not in Project.bill_types:
raise ValidationError(_("Invalid Bill Type"))
class MemberForm(FlaskForm):
name = StringField(_("Name"), validators=[DataRequired()], filters=[strip_filter])

View file

@ -12,10 +12,11 @@ down_revision = "927ed575acbd"
from alembic import op
import sqlalchemy as sa
from ihatemoney.models import BillType
def upgrade():
op.add_column("bill", sa.Column("bill_type", sa.UnicodeText()))
op.add_column("bill", sa.Column("bill_type", sa.Enum(BillType)))
op.add_column("bill_version", sa.Column("bill_type", sa.UnicodeText()))

View file

@ -21,7 +21,7 @@ from sqlalchemy_continuum.plugins import FlaskPlugin
from ihatemoney.currency_convertor import CurrencyConverter
from ihatemoney.monkeypath_continuum import PatchedTransactionFactory
from ihatemoney.utils import generate_password_hash, get_members, same_bill
from ihatemoney.utils import generate_password_hash, get_members, same_bill, FormEnum
from ihatemoney.versioning import (
ConditionalVersioningManager,
LoggingMode,
@ -50,6 +50,12 @@ make_versioned(
],
)
class BillType(FormEnum):
EXPENSE = "Expense"
REIMBURSEMENT = "Reimbursement"
TRANSFER = "Transfer"
db = SQLAlchemy()
@ -76,11 +82,6 @@ class Project(db.Model):
query_class = ProjectQuery
default_currency = db.Column(db.String(3))
bill_types = [
("Expense", "Expense"),
("Reimbursement", "Reimbursement"),
("Transfer", "Transfer"),
]
@property
def _to_serialize(self):
@ -123,12 +124,12 @@ class Project(db.Model):
for bill in self.get_bills_unordered().all():
total_weight = sum(ower.weight for ower in bill.owers)
if bill.bill_type == "Expense":
if bill.bill_type == BillType.EXPENSE:
should_receive[bill.payer.id] += bill.converted_amount
for ower in bill.owers:
should_pay[ower.id] += (ower.weight * bill.converted_amount / total_weight)
if bill.bill_type == "Reimbursement":
if bill.bill_type == BillType.REIMBURSEMENT:
should_receive[bill.payer.id] += bill.converted_amount
for ower in bill.owers:
should_receive[ower.id] -= bill.converted_amount
@ -174,7 +175,7 @@ class Project(db.Model):
"""
monthly = defaultdict(lambda: defaultdict(float))
for bill in self.get_bills_unordered().all():
if bill.bill_type == "Expense":
if bill.bill_type == BillType.EXPENSE:
monthly[bill.date.year][bill.date.month] += bill.converted_amount
return monthly
@ -351,7 +352,7 @@ class Project(db.Model):
pretty_bills.append(
{
"what": bill.what,
"bill_type": bill.bill_type,
"bill_type": bill.bill_type.value,
"amount": round(bill.amount, 2),
"currency": bill.original_currency,
"date": str(bill.date),
@ -695,7 +696,7 @@ class Bill(db.Model):
date = db.Column(db.Date, default=datetime.datetime.now)
creation_date = db.Column(db.Date, default=datetime.datetime.now)
what = db.Column(db.UnicodeText)
bill_type = db.Column(db.UnicodeText)
bill_type = db.Column(db.Enum(BillType))
external_link = db.Column(db.UnicodeText)
original_currency = db.Column(db.String(3))
@ -715,7 +716,7 @@ class Bill(db.Model):
payer_id: int = None,
project_default_currency: str = "",
what: str = "",
bill_type: str = "",
bill_type: str = "Expense",
):
super().__init__()
self.amount = amount
@ -725,7 +726,7 @@ class Bill(db.Model):
self.owers = owers
self.payer_id = payer_id
self.what = what
self.bill_type = bill_type
self.bill_type = BillType(bill_type)
self.converted_amount = self.currency_helper.exchange_currency(
self.amount, self.original_currency, project_default_currency
)
@ -740,7 +741,7 @@ class Bill(db.Model):
"date": self.date,
"creation_date": self.creation_date,
"what": self.what,
"bill_type": self.bill_type,
"bill_type": self.bill_type.value,
"external_link": self.external_link,
"original_currency": self.original_currency,
"converted_amount": self.converted_amount,

View file

@ -1493,7 +1493,7 @@ class TestBudget(IhatemoneyTestCase):
#test if theres a new one with bill_type reimbursement
bill = project.get_newest_bill()
assert bill.bill_type == "Reimbursement"
assert bill.bill_type == models.BillType.REIMBURSEMENT
return
def test_settle_zero(self):