mirror of
https://github.com/spiral-project/ihatemoney.git
synced 2025-04-29 01:42:37 +02:00
Have filter by date running properly
This commit is contained in:
parent
417e144455
commit
dce7bd7815
5 changed files with 67 additions and 24 deletions
BIN
ihatemoney.db
BIN
ihatemoney.db
Binary file not shown.
|
@ -265,12 +265,13 @@ class Project(db.Model):
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def filter_by_date(query, start_date, end_date):
|
def filter_by_date(query, start, end):
|
||||||
if start_date and end_date:
|
if start and end:
|
||||||
return query.filter(Bill.date >= start_date, Bill.date <= end_date)
|
return query.filter(Bill.date.between(start, end))
|
||||||
else:
|
else:
|
||||||
return query
|
return query
|
||||||
|
|
||||||
|
|
||||||
def get_bill_weights(self):
|
def get_bill_weights(self):
|
||||||
"""
|
"""
|
||||||
Return all bills for this project, along with the sum of weight for each bill.
|
Return all bills for this project, along with the sum of weight for each bill.
|
||||||
|
@ -292,9 +293,9 @@ class Project(db.Model):
|
||||||
"""Ordered version of get_bill_weights"""
|
"""Ordered version of get_bill_weights"""
|
||||||
return self.order_bills(self.get_bill_weights())
|
return self.order_bills(self.get_bill_weights())
|
||||||
|
|
||||||
def get_filtered_date_bill_weights_ordered(self, start_date, end_date):
|
def get_filtered_date_bill_weights_ordered(self, start, end):
|
||||||
bill_weights_ordered = self.get_bill_weights_ordered()
|
bill_weights_ordered = self.get_bill_weights_ordered()
|
||||||
filtered_bill_weights = self.filter_by_date(bill_weights_ordered, start_date, end_date)
|
filtered_bill_weights = self.filter_by_date(bill_weights_ordered, start,end )
|
||||||
return filtered_bill_weights
|
return filtered_bill_weights
|
||||||
|
|
||||||
def get_member_bills(self, member_id):
|
def get_member_bills(self, member_id):
|
||||||
|
|
|
@ -107,12 +107,10 @@
|
||||||
{% endif %}
|
{% endif %}
|
||||||
<form action="{{ url_for(".list_bills") }}" method="post">
|
<form action="{{ url_for(".list_bills") }}" method="post">
|
||||||
{{ csrf_form.csrf_token }}
|
{{ csrf_form.csrf_token }}
|
||||||
<label for="start_date">Start Date:</label>
|
<label for="start">Start Date:</label>
|
||||||
<input type="date" id="start_date" name="start_date" value="{{ start_date if start_date else '' }}">
|
<input type="date" id="start" name="start" value="{{ start if start else '' }}">
|
||||||
|
<label for="end">End Date:</label>
|
||||||
<label for="end_date">End Date:</label>
|
<input type="date" id="end" name="end" value="{{ end if end else '' }}">
|
||||||
<input type="date" id="end_date" name="end_date" value="{{ end_date if end_date else '' }}">
|
|
||||||
|
|
||||||
<input type="submit" value="Enter">
|
<input type="submit" value="Enter">
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
|
|
41
ihatemoney/tests/filterbydate_test.py
Normal file
41
ihatemoney/tests/filterbydate_test.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock
|
||||||
|
from ihatemoney.models import Project, Bill
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_filter_by_date(Project):
|
||||||
|
# Prepare mock data
|
||||||
|
mock_query = Mock()
|
||||||
|
start_date = '2024-01-01'
|
||||||
|
end_date = '2024-12-31'
|
||||||
|
|
||||||
|
# Mock the methods being called inside filter_by_date
|
||||||
|
Project.query.filter.return_value = Mock() # Assuming you're using SQLAlchemy's Query object
|
||||||
|
|
||||||
|
# Call the method to test
|
||||||
|
result = Project.filter_by_date(mock_query, start_date, end_date)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert result == Project.query.filter.return_value # Check if the method returns the expected result
|
||||||
|
Project.query.filter.assert_called_once_with(Bill.date >= start_date,
|
||||||
|
Bill.date <= end_date) # Check if filter was called with the correct arguments
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_filtered_date_bill_weights_ordered(Project):
|
||||||
|
# Prepare mock data
|
||||||
|
start_date = '2024-01-01'
|
||||||
|
end_date = '2024-12-31'
|
||||||
|
|
||||||
|
|
||||||
|
Project.get_bill_weights_ordered.return_value = Mock()
|
||||||
|
Project.filter_by_date.return_value = Mock()
|
||||||
|
|
||||||
|
# Call the method to test
|
||||||
|
result = Project.get_filtered_date_bill_weights_ordered(start_date, end_date)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert result == Project.filter_by_date.return_value # Check if the method returns the expected result
|
||||||
|
Project.filter_by_date.assert_called_once_with(
|
||||||
|
Project.get_bill_weights_ordered.return_value, start_date,
|
||||||
|
end_date) # Check if filter_by_date was called with the correct arguments
|
|
@ -666,20 +666,23 @@ def list_bills():
|
||||||
# Each item will be a (weight_sum, Bill) tuple.
|
# Each item will be a (weight_sum, Bill) tuple.
|
||||||
# TODO: improve this awkward result using column_property:
|
# TODO: improve this awkward result using column_property:
|
||||||
# https://docs.sqlalchemy.org/en/14/orm/mapped_sql_expr.html.
|
# https://docs.sqlalchemy.org/en/14/orm/mapped_sql_expr.html.
|
||||||
|
|
||||||
if request.method == "GET":
|
if request.method == "GET":
|
||||||
|
# Retrieve ordered bill weights for the project
|
||||||
weighted_bills = g.project.get_bill_weights_ordered().paginate(
|
weighted_bills = g.project.get_bill_weights_ordered().paginate(
|
||||||
per_page=100, error_out=True
|
per_page=100, error_out=True
|
||||||
)
|
)
|
||||||
elif request.method == "POST":
|
elif request.method == "POST":
|
||||||
# Retrieve start_date and end_date from form data
|
# Retrieve start_date and end_date from form data
|
||||||
start_date = request.form.get('start_date')
|
start = request.form.get('start')
|
||||||
end_date = request.form.get('end_date')
|
end = request.form.get('end')
|
||||||
|
|
||||||
|
# Retrieve filtered bill weights by date
|
||||||
weighted_bills = g.project.get_filtered_date_bill_weights_ordered(start_date, end_date).paginate(
|
weighted_bills = g.project.get_filtered_date_bill_weights_ordered(start, end).paginate(
|
||||||
per_page=100, error_out=True
|
per_page=100, error_out=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Render the template with the appropriate data
|
||||||
return render_template(
|
return render_template(
|
||||||
"list_bills.html",
|
"list_bills.html",
|
||||||
bills=weighted_bills,
|
bills=weighted_bills,
|
||||||
|
@ -688,8 +691,8 @@ def list_bills():
|
||||||
csrf_form=csrf_form,
|
csrf_form=csrf_form,
|
||||||
add_bill=request.values.get("add_bill", False),
|
add_bill=request.values.get("add_bill", False),
|
||||||
current_view="list_bills",
|
current_view="list_bills",
|
||||||
start_date=start_date if request.method == "POST" else None,
|
start=start if request.method == "POST" else None,
|
||||||
end_date=end_date if request.method == "POST" else None,
|
end=end if request.method == "POST" else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@main.route("/<project_id>/members/add", methods=["GET", "POST"])
|
@main.route("/<project_id>/members/add", methods=["GET", "POST"])
|
||||||
|
|
Loading…
Reference in a new issue