mirror of
https://github.com/spiral-project/ihatemoney.git
synced 2025-04-28 17:32:38 +02:00
Changed bill_type to an Enum
This commit is contained in:
parent
8cef9492d0
commit
b0a67a5156
4 changed files with 21 additions and 24 deletions
|
@ -39,7 +39,7 @@ from wtforms.validators import (
|
|||
)
|
||||
|
||||
from ihatemoney.currency_convertor import CurrencyConverter
|
||||
from ihatemoney.models import Bill, LoggingMode, Person, Project
|
||||
from ihatemoney.models import Bill, BillType, LoggingMode, Person, Project
|
||||
from ihatemoney.utils import (
|
||||
em_surround,
|
||||
eval_arithmetic_expression,
|
||||
|
@ -81,7 +81,6 @@ def get_billform_for(project, set_default=True, **kwargs):
|
|||
|
||||
active_members = [(m.id, m.name) for m in project.active_members]
|
||||
|
||||
form.bill_type.choices = project.bill_types
|
||||
form.payed_for.choices = form.payer.choices = active_members
|
||||
form.payed_for.default = [m.id for m in project.active_members]
|
||||
|
||||
|
@ -365,7 +364,7 @@ class BillForm(FlaskForm):
|
|||
payed_for = SelectMultipleField(
|
||||
_("For whom?"), validators=[DataRequired()], coerce=int
|
||||
)
|
||||
bill_type = SelectField(_("Bill Type"), validators=[DataRequired()], coerce=str)
|
||||
bill_type = SelectField(_("Bill Type"), validators=[DataRequired()], choices=BillType.choices())
|
||||
submit = SubmitField(_("Submit"))
|
||||
submit2 = SubmitField(_("Submit and add a new one"))
|
||||
|
||||
|
@ -386,7 +385,7 @@ class BillForm(FlaskForm):
|
|||
bill.payer_id = self.payer.data
|
||||
bill.amount = self.amount.data
|
||||
bill.what = self.what.data
|
||||
bill.bill_type = self.bill_type.data
|
||||
bill.bill_type = BillType(self.bill_type.data)
|
||||
bill.external_link = self.external_link.data
|
||||
bill.date = self.date.data
|
||||
bill.owers = Person.query.get_by_ids(self.payed_for.data, project)
|
||||
|
@ -433,10 +432,6 @@ class BillForm(FlaskForm):
|
|||
)
|
||||
raise ValidationError(msg)
|
||||
|
||||
def validate_bill_type(self, field):
|
||||
if (field.data, field.data) not in Project.bill_types:
|
||||
raise ValidationError(_("Invalid Bill Type"))
|
||||
|
||||
|
||||
class MemberForm(FlaskForm):
|
||||
name = StringField(_("Name"), validators=[DataRequired()], filters=[strip_filter])
|
||||
|
|
|
@ -12,10 +12,11 @@ down_revision = "927ed575acbd"
|
|||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from ihatemoney.models import BillType
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column("bill", sa.Column("bill_type", sa.UnicodeText()))
|
||||
op.add_column("bill", sa.Column("bill_type", sa.Enum(BillType)))
|
||||
op.add_column("bill_version", sa.Column("bill_type", sa.UnicodeText()))
|
||||
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from sqlalchemy_continuum.plugins import FlaskPlugin
|
|||
|
||||
from ihatemoney.currency_convertor import CurrencyConverter
|
||||
from ihatemoney.monkeypath_continuum import PatchedTransactionFactory
|
||||
from ihatemoney.utils import generate_password_hash, get_members, same_bill
|
||||
from ihatemoney.utils import generate_password_hash, get_members, same_bill, FormEnum
|
||||
from ihatemoney.versioning import (
|
||||
ConditionalVersioningManager,
|
||||
LoggingMode,
|
||||
|
@ -50,6 +50,12 @@ make_versioned(
|
|||
],
|
||||
)
|
||||
|
||||
class BillType(FormEnum):
|
||||
EXPENSE = "Expense"
|
||||
REIMBURSEMENT = "Reimbursement"
|
||||
TRANSFER = "Transfer"
|
||||
|
||||
|
||||
db = SQLAlchemy()
|
||||
|
||||
|
||||
|
@ -76,11 +82,6 @@ class Project(db.Model):
|
|||
|
||||
query_class = ProjectQuery
|
||||
default_currency = db.Column(db.String(3))
|
||||
bill_types = [
|
||||
("Expense", "Expense"),
|
||||
("Reimbursement", "Reimbursement"),
|
||||
("Transfer", "Transfer"),
|
||||
]
|
||||
|
||||
@property
|
||||
def _to_serialize(self):
|
||||
|
@ -123,12 +124,12 @@ class Project(db.Model):
|
|||
for bill in self.get_bills_unordered().all():
|
||||
total_weight = sum(ower.weight for ower in bill.owers)
|
||||
|
||||
if bill.bill_type == "Expense":
|
||||
if bill.bill_type == BillType.EXPENSE:
|
||||
should_receive[bill.payer.id] += bill.converted_amount
|
||||
for ower in bill.owers:
|
||||
should_pay[ower.id] += (ower.weight * bill.converted_amount / total_weight)
|
||||
|
||||
if bill.bill_type == "Reimbursement":
|
||||
if bill.bill_type == BillType.REIMBURSEMENT:
|
||||
should_receive[bill.payer.id] += bill.converted_amount
|
||||
for ower in bill.owers:
|
||||
should_receive[ower.id] -= bill.converted_amount
|
||||
|
@ -174,7 +175,7 @@ class Project(db.Model):
|
|||
"""
|
||||
monthly = defaultdict(lambda: defaultdict(float))
|
||||
for bill in self.get_bills_unordered().all():
|
||||
if bill.bill_type == "Expense":
|
||||
if bill.bill_type == BillType.EXPENSE:
|
||||
monthly[bill.date.year][bill.date.month] += bill.converted_amount
|
||||
return monthly
|
||||
|
||||
|
@ -351,7 +352,7 @@ class Project(db.Model):
|
|||
pretty_bills.append(
|
||||
{
|
||||
"what": bill.what,
|
||||
"bill_type": bill.bill_type,
|
||||
"bill_type": bill.bill_type.value,
|
||||
"amount": round(bill.amount, 2),
|
||||
"currency": bill.original_currency,
|
||||
"date": str(bill.date),
|
||||
|
@ -695,7 +696,7 @@ class Bill(db.Model):
|
|||
date = db.Column(db.Date, default=datetime.datetime.now)
|
||||
creation_date = db.Column(db.Date, default=datetime.datetime.now)
|
||||
what = db.Column(db.UnicodeText)
|
||||
bill_type = db.Column(db.UnicodeText)
|
||||
bill_type = db.Column(db.Enum(BillType))
|
||||
external_link = db.Column(db.UnicodeText)
|
||||
|
||||
original_currency = db.Column(db.String(3))
|
||||
|
@ -715,7 +716,7 @@ class Bill(db.Model):
|
|||
payer_id: int = None,
|
||||
project_default_currency: str = "",
|
||||
what: str = "",
|
||||
bill_type: str = "",
|
||||
bill_type: str = "Expense",
|
||||
):
|
||||
super().__init__()
|
||||
self.amount = amount
|
||||
|
@ -725,7 +726,7 @@ class Bill(db.Model):
|
|||
self.owers = owers
|
||||
self.payer_id = payer_id
|
||||
self.what = what
|
||||
self.bill_type = bill_type
|
||||
self.bill_type = BillType(bill_type)
|
||||
self.converted_amount = self.currency_helper.exchange_currency(
|
||||
self.amount, self.original_currency, project_default_currency
|
||||
)
|
||||
|
@ -740,7 +741,7 @@ class Bill(db.Model):
|
|||
"date": self.date,
|
||||
"creation_date": self.creation_date,
|
||||
"what": self.what,
|
||||
"bill_type": self.bill_type,
|
||||
"bill_type": self.bill_type.value,
|
||||
"external_link": self.external_link,
|
||||
"original_currency": self.original_currency,
|
||||
"converted_amount": self.converted_amount,
|
||||
|
|
|
@ -1493,7 +1493,7 @@ class TestBudget(IhatemoneyTestCase):
|
|||
|
||||
#test if theres a new one with bill_type reimbursement
|
||||
bill = project.get_newest_bill()
|
||||
assert bill.bill_type == "Reimbursement"
|
||||
assert bill.bill_type == models.BillType.REIMBURSEMENT
|
||||
return
|
||||
|
||||
def test_settle_zero(self):
|
||||
|
|
Loading…
Reference in a new issue