mirror of
https://github.com/spiral-project/ihatemoney.git
synced 2025-04-28 17:32:38 +02:00
Have filter by date running properly
This commit is contained in:
parent
417e144455
commit
dce7bd7815
5 changed files with 67 additions and 24 deletions
BIN
ihatemoney.db
BIN
ihatemoney.db
Binary file not shown.
|
@ -265,12 +265,13 @@ class Project(db.Model):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_by_date(query, start_date, end_date):
|
||||
if start_date and end_date:
|
||||
return query.filter(Bill.date >= start_date, Bill.date <= end_date)
|
||||
def filter_by_date(query, start, end):
|
||||
if start and end:
|
||||
return query.filter(Bill.date.between(start, end))
|
||||
else:
|
||||
return query
|
||||
|
||||
|
||||
def get_bill_weights(self):
|
||||
"""
|
||||
Return all bills for this project, along with the sum of weight for each bill.
|
||||
|
@ -292,9 +293,9 @@ class Project(db.Model):
|
|||
"""Ordered version of get_bill_weights"""
|
||||
return self.order_bills(self.get_bill_weights())
|
||||
|
||||
def get_filtered_date_bill_weights_ordered(self, start_date, end_date):
|
||||
def get_filtered_date_bill_weights_ordered(self, start, end):
|
||||
bill_weights_ordered = self.get_bill_weights_ordered()
|
||||
filtered_bill_weights = self.filter_by_date(bill_weights_ordered, start_date, end_date)
|
||||
filtered_bill_weights = self.filter_by_date(bill_weights_ordered, start,end )
|
||||
return filtered_bill_weights
|
||||
|
||||
def get_member_bills(self, member_id):
|
||||
|
|
|
@ -107,12 +107,10 @@
|
|||
{% endif %}
|
||||
<form action="{{ url_for(".list_bills") }}" method="post">
|
||||
{{ csrf_form.csrf_token }}
|
||||
<label for="start_date">Start Date:</label>
|
||||
<input type="date" id="start_date" name="start_date" value="{{ start_date if start_date else '' }}">
|
||||
|
||||
<label for="end_date">End Date:</label>
|
||||
<input type="date" id="end_date" name="end_date" value="{{ end_date if end_date else '' }}">
|
||||
|
||||
<label for="start">Start Date:</label>
|
||||
<input type="date" id="start" name="start" value="{{ start if start else '' }}">
|
||||
<label for="end">End Date:</label>
|
||||
<input type="date" id="end" name="end" value="{{ end if end else '' }}">
|
||||
<input type="submit" value="Enter">
|
||||
</form>
|
||||
|
||||
|
|
41
ihatemoney/tests/filterbydate_test.py
Normal file
41
ihatemoney/tests/filterbydate_test.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from ihatemoney.models import Project, Bill
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_filter_by_date(Project):
|
||||
# Prepare mock data
|
||||
mock_query = Mock()
|
||||
start_date = '2024-01-01'
|
||||
end_date = '2024-12-31'
|
||||
|
||||
# Mock the methods being called inside filter_by_date
|
||||
Project.query.filter.return_value = Mock() # Assuming you're using SQLAlchemy's Query object
|
||||
|
||||
# Call the method to test
|
||||
result = Project.filter_by_date(mock_query, start_date, end_date)
|
||||
|
||||
# Assertions
|
||||
assert result == Project.query.filter.return_value # Check if the method returns the expected result
|
||||
Project.query.filter.assert_called_once_with(Bill.date >= start_date,
|
||||
Bill.date <= end_date) # Check if filter was called with the correct arguments
|
||||
|
||||
|
||||
def test_get_filtered_date_bill_weights_ordered(Project):
|
||||
# Prepare mock data
|
||||
start_date = '2024-01-01'
|
||||
end_date = '2024-12-31'
|
||||
|
||||
|
||||
Project.get_bill_weights_ordered.return_value = Mock()
|
||||
Project.filter_by_date.return_value = Mock()
|
||||
|
||||
# Call the method to test
|
||||
result = Project.get_filtered_date_bill_weights_ordered(start_date, end_date)
|
||||
|
||||
# Assertions
|
||||
assert result == Project.filter_by_date.return_value # Check if the method returns the expected result
|
||||
Project.filter_by_date.assert_called_once_with(
|
||||
Project.get_bill_weights_ordered.return_value, start_date,
|
||||
end_date) # Check if filter_by_date was called with the correct arguments
|
|
@ -130,8 +130,8 @@ def set_show_admin_dashboard_link(endpoint, values):
|
|||
"""
|
||||
|
||||
g.show_admin_dashboard_link = (
|
||||
current_app.config["ACTIVATE_ADMIN_DASHBOARD"]
|
||||
and current_app.config["ADMIN_PASSWORD"]
|
||||
current_app.config["ACTIVATE_ADMIN_DASHBOARD"]
|
||||
and current_app.config["ADMIN_PASSWORD"]
|
||||
)
|
||||
g.logout_form = LogoutForm()
|
||||
|
||||
|
@ -199,7 +199,7 @@ def admin():
|
|||
if request.method == "POST" and form.validate():
|
||||
# Valid password
|
||||
if check_password_hash(
|
||||
current_app.config["ADMIN_PASSWORD"], form.admin_password.data
|
||||
current_app.config["ADMIN_PASSWORD"], form.admin_password.data
|
||||
):
|
||||
session["is_admin"] = True
|
||||
session.update()
|
||||
|
@ -658,28 +658,31 @@ def list_bills():
|
|||
if "last_selected_payer" in session:
|
||||
bill_form.payer.data = session["last_selected_payer"]
|
||||
if (
|
||||
"last_selected_payed_for" in session
|
||||
and g.project.id in session["last_selected_payed_for"]
|
||||
"last_selected_payed_for" in session
|
||||
and g.project.id in session["last_selected_payed_for"]
|
||||
):
|
||||
bill_form.payed_for.data = session["last_selected_payed_for"][g.project.id]
|
||||
|
||||
# Each item will be a (weight_sum, Bill) tuple.
|
||||
# TODO: improve this awkward result using column_property:
|
||||
# https://docs.sqlalchemy.org/en/14/orm/mapped_sql_expr.html.
|
||||
|
||||
if request.method == "GET":
|
||||
# Retrieve ordered bill weights for the project
|
||||
weighted_bills = g.project.get_bill_weights_ordered().paginate(
|
||||
per_page=100, error_out=True
|
||||
)
|
||||
elif request.method == "POST":
|
||||
# Retrieve start_date and end_date from form data
|
||||
start_date = request.form.get('start_date')
|
||||
end_date = request.form.get('end_date')
|
||||
start = request.form.get('start')
|
||||
end = request.form.get('end')
|
||||
|
||||
|
||||
weighted_bills = g.project.get_filtered_date_bill_weights_ordered(start_date, end_date).paginate(
|
||||
# Retrieve filtered bill weights by date
|
||||
weighted_bills = g.project.get_filtered_date_bill_weights_ordered(start, end).paginate(
|
||||
per_page=100, error_out=True
|
||||
)
|
||||
|
||||
# Render the template with the appropriate data
|
||||
return render_template(
|
||||
"list_bills.html",
|
||||
bills=weighted_bills,
|
||||
|
@ -688,8 +691,8 @@ def list_bills():
|
|||
csrf_form=csrf_form,
|
||||
add_bill=request.values.get("add_bill", False),
|
||||
current_view="list_bills",
|
||||
start_date=start_date if request.method == "POST" else None,
|
||||
end_date=end_date if request.method == "POST" else None,
|
||||
start=start if request.method == "POST" else None,
|
||||
end=end if request.method == "POST" else None,
|
||||
)
|
||||
|
||||
@main.route("/<project_id>/members/add", methods=["GET", "POST"])
|
||||
|
@ -985,8 +988,8 @@ def feed(token):
|
|||
return "", 304
|
||||
|
||||
if (
|
||||
request.if_modified_since
|
||||
and request.if_modified_since.replace(tzinfo=None) >= last_modified
|
||||
request.if_modified_since
|
||||
and request.if_modified_since.replace(tzinfo=None) >= last_modified
|
||||
):
|
||||
return "", 304
|
||||
|
||||
|
|
Loading…
Reference in a new issue