Have filter by date running properly

This commit is contained in:
Sylvieox 2024-05-07 10:34:58 -04:00
parent 417e144455
commit dce7bd7815
5 changed files with 67 additions and 24 deletions

Binary file not shown.

View file

@ -265,12 +265,13 @@ class Project(db.Model):
)
@staticmethod
def filter_by_date(query, start_date, end_date):
if start_date and end_date:
return query.filter(Bill.date >= start_date, Bill.date <= end_date)
def filter_by_date(query, start, end):
if start and end:
return query.filter(Bill.date.between(start, end))
else:
return query
def get_bill_weights(self):
"""
Return all bills for this project, along with the sum of weight for each bill.
@ -292,9 +293,9 @@ class Project(db.Model):
"""Ordered version of get_bill_weights"""
return self.order_bills(self.get_bill_weights())
def get_filtered_date_bill_weights_ordered(self, start_date, end_date):
def get_filtered_date_bill_weights_ordered(self, start, end):
bill_weights_ordered = self.get_bill_weights_ordered()
filtered_bill_weights = self.filter_by_date(bill_weights_ordered, start_date, end_date)
filtered_bill_weights = self.filter_by_date(bill_weights_ordered, start,end )
return filtered_bill_weights
def get_member_bills(self, member_id):

View file

@ -107,12 +107,10 @@
{% endif %}
<form action="{{ url_for(".list_bills") }}" method="post">
{{ csrf_form.csrf_token }}
<label for="start_date">Start Date:</label>
<input type="date" id="start_date" name="start_date" value="{{ start_date if start_date else '' }}">
<label for="end_date">End Date:</label>
<input type="date" id="end_date" name="end_date" value="{{ end_date if end_date else '' }}">
<label for="start">Start Date:</label>
<input type="date" id="start" name="start" value="{{ start if start else '' }}">
<label for="end">End Date:</label>
<input type="date" id="end" name="end" value="{{ end if end else '' }}">
<input type="submit" value="Enter">
</form>

View file

@ -0,0 +1,41 @@
import pytest
from unittest.mock import Mock
from ihatemoney.models import Project, Bill
@pytest.fixture
def test_filter_by_date(Project):
# Prepare mock data
mock_query = Mock()
start_date = '2024-01-01'
end_date = '2024-12-31'
# Mock the methods being called inside filter_by_date
Project.query.filter.return_value = Mock() # Assuming you're using SQLAlchemy's Query object
# Call the method to test
result = Project.filter_by_date(mock_query, start_date, end_date)
# Assertions
assert result == Project.query.filter.return_value # Check if the method returns the expected result
Project.query.filter.assert_called_once_with(Bill.date >= start_date,
Bill.date <= end_date) # Check if filter was called with the correct arguments
def test_get_filtered_date_bill_weights_ordered(Project):
# Prepare mock data
start_date = '2024-01-01'
end_date = '2024-12-31'
Project.get_bill_weights_ordered.return_value = Mock()
Project.filter_by_date.return_value = Mock()
# Call the method to test
result = Project.get_filtered_date_bill_weights_ordered(start_date, end_date)
# Assertions
assert result == Project.filter_by_date.return_value # Check if the method returns the expected result
Project.filter_by_date.assert_called_once_with(
Project.get_bill_weights_ordered.return_value, start_date,
end_date) # Check if filter_by_date was called with the correct arguments

View file

@ -130,8 +130,8 @@ def set_show_admin_dashboard_link(endpoint, values):
"""
g.show_admin_dashboard_link = (
current_app.config["ACTIVATE_ADMIN_DASHBOARD"]
and current_app.config["ADMIN_PASSWORD"]
current_app.config["ACTIVATE_ADMIN_DASHBOARD"]
and current_app.config["ADMIN_PASSWORD"]
)
g.logout_form = LogoutForm()
@ -199,7 +199,7 @@ def admin():
if request.method == "POST" and form.validate():
# Valid password
if check_password_hash(
current_app.config["ADMIN_PASSWORD"], form.admin_password.data
current_app.config["ADMIN_PASSWORD"], form.admin_password.data
):
session["is_admin"] = True
session.update()
@ -658,28 +658,31 @@ def list_bills():
if "last_selected_payer" in session:
bill_form.payer.data = session["last_selected_payer"]
if (
"last_selected_payed_for" in session
and g.project.id in session["last_selected_payed_for"]
"last_selected_payed_for" in session
and g.project.id in session["last_selected_payed_for"]
):
bill_form.payed_for.data = session["last_selected_payed_for"][g.project.id]
# Each item will be a (weight_sum, Bill) tuple.
# TODO: improve this awkward result using column_property:
# https://docs.sqlalchemy.org/en/14/orm/mapped_sql_expr.html.
if request.method == "GET":
# Retrieve ordered bill weights for the project
weighted_bills = g.project.get_bill_weights_ordered().paginate(
per_page=100, error_out=True
)
elif request.method == "POST":
# Retrieve start_date and end_date from form data
start_date = request.form.get('start_date')
end_date = request.form.get('end_date')
start = request.form.get('start')
end = request.form.get('end')
weighted_bills = g.project.get_filtered_date_bill_weights_ordered(start_date, end_date).paginate(
# Retrieve filtered bill weights by date
weighted_bills = g.project.get_filtered_date_bill_weights_ordered(start, end).paginate(
per_page=100, error_out=True
)
# Render the template with the appropriate data
return render_template(
"list_bills.html",
bills=weighted_bills,
@ -688,8 +691,8 @@ def list_bills():
csrf_form=csrf_form,
add_bill=request.values.get("add_bill", False),
current_view="list_bills",
start_date=start_date if request.method == "POST" else None,
end_date=end_date if request.method == "POST" else None,
start=start if request.method == "POST" else None,
end=end if request.method == "POST" else None,
)
@main.route("/<project_id>/members/add", methods=["GET", "POST"])
@ -985,8 +988,8 @@ def feed(token):
return "", 304
if (
request.if_modified_since
and request.if_modified_since.replace(tzinfo=None) >= last_modified
request.if_modified_since
and request.if_modified_since.replace(tzinfo=None) >= last_modified
):
return "", 304