diff --git a/ihatemoney.db b/ihatemoney.db index b82834f2..00c32ee6 100644 Binary files a/ihatemoney.db and b/ihatemoney.db differ diff --git a/ihatemoney/models.py b/ihatemoney/models.py index fb0db673..4513b3e7 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -265,12 +265,13 @@ class Project(db.Model): ) @staticmethod - def filter_by_date(query, start_date, end_date): - if start_date and end_date: - return query.filter(Bill.date >= start_date, Bill.date <= end_date) + def filter_by_date(query, start, end): + if start and end: + return query.filter(Bill.date.between(start, end)) else: return query + def get_bill_weights(self): """ Return all bills for this project, along with the sum of weight for each bill. @@ -292,9 +293,9 @@ class Project(db.Model): """Ordered version of get_bill_weights""" return self.order_bills(self.get_bill_weights()) - def get_filtered_date_bill_weights_ordered(self, start_date, end_date): + def get_filtered_date_bill_weights_ordered(self, start, end): bill_weights_ordered = self.get_bill_weights_ordered() - filtered_bill_weights = self.filter_by_date(bill_weights_ordered, start_date, end_date) + filtered_bill_weights = self.filter_by_date(bill_weights_ordered, start,end ) return filtered_bill_weights def get_member_bills(self, member_id): diff --git a/ihatemoney/templates/list_bills.html b/ihatemoney/templates/list_bills.html index 0ae8b8af..545e1148 100644 --- a/ihatemoney/templates/list_bills.html +++ b/ihatemoney/templates/list_bills.html @@ -107,12 +107,10 @@ {% endif %}
{{ csrf_form.csrf_token }} - - - - - - + + + +
diff --git a/ihatemoney/tests/filterbydate_test.py b/ihatemoney/tests/filterbydate_test.py new file mode 100644 index 00000000..91732657 --- /dev/null +++ b/ihatemoney/tests/filterbydate_test.py @@ -0,0 +1,41 @@ +import pytest +from unittest.mock import Mock +from ihatemoney.models import Project, Bill + + +@pytest.fixture +def test_filter_by_date(Project): + # Prepare mock data + mock_query = Mock() + start_date = '2024-01-01' + end_date = '2024-12-31' + + # Mock the methods being called inside filter_by_date + Project.query.filter.return_value = Mock() # Assuming you're using SQLAlchemy's Query object + + # Call the method to test + result = Project.filter_by_date(mock_query, start_date, end_date) + + # Assertions + assert result == Project.query.filter.return_value # Check if the method returns the expected result + Project.query.filter.assert_called_once_with(Bill.date >= start_date, + Bill.date <= end_date) # Check if filter was called with the correct arguments + + +def test_get_filtered_date_bill_weights_ordered(Project): + # Prepare mock data + start_date = '2024-01-01' + end_date = '2024-12-31' + + + Project.get_bill_weights_ordered.return_value = Mock() + Project.filter_by_date.return_value = Mock() + + # Call the method to test + result = Project.get_filtered_date_bill_weights_ordered(start_date, end_date) + + # Assertions + assert result == Project.filter_by_date.return_value # Check if the method returns the expected result + Project.filter_by_date.assert_called_once_with( + Project.get_bill_weights_ordered.return_value, start_date, + end_date) # Check if filter_by_date was called with the correct arguments diff --git a/ihatemoney/web.py b/ihatemoney/web.py index 7263793b..4ee71385 100644 --- a/ihatemoney/web.py +++ b/ihatemoney/web.py @@ -130,8 +130,8 @@ def set_show_admin_dashboard_link(endpoint, values): """ g.show_admin_dashboard_link = ( - current_app.config["ACTIVATE_ADMIN_DASHBOARD"] - and current_app.config["ADMIN_PASSWORD"] + current_app.config["ACTIVATE_ADMIN_DASHBOARD"] + and current_app.config["ADMIN_PASSWORD"] ) g.logout_form = LogoutForm() @@ -199,7 +199,7 @@ def admin(): if request.method == "POST" and form.validate(): # Valid password if check_password_hash( - current_app.config["ADMIN_PASSWORD"], form.admin_password.data + current_app.config["ADMIN_PASSWORD"], form.admin_password.data ): session["is_admin"] = True session.update() @@ -658,28 +658,31 @@ def list_bills(): if "last_selected_payer" in session: bill_form.payer.data = session["last_selected_payer"] if ( - "last_selected_payed_for" in session - and g.project.id in session["last_selected_payed_for"] + "last_selected_payed_for" in session + and g.project.id in session["last_selected_payed_for"] ): bill_form.payed_for.data = session["last_selected_payed_for"][g.project.id] # Each item will be a (weight_sum, Bill) tuple. # TODO: improve this awkward result using column_property: # https://docs.sqlalchemy.org/en/14/orm/mapped_sql_expr.html. + if request.method == "GET": + # Retrieve ordered bill weights for the project weighted_bills = g.project.get_bill_weights_ordered().paginate( per_page=100, error_out=True ) elif request.method == "POST": # Retrieve start_date and end_date from form data - start_date = request.form.get('start_date') - end_date = request.form.get('end_date') + start = request.form.get('start') + end = request.form.get('end') - - weighted_bills = g.project.get_filtered_date_bill_weights_ordered(start_date, end_date).paginate( + # Retrieve filtered bill weights by date + weighted_bills = g.project.get_filtered_date_bill_weights_ordered(start, end).paginate( per_page=100, error_out=True ) + # Render the template with the appropriate data return render_template( "list_bills.html", bills=weighted_bills, @@ -688,8 +691,8 @@ def list_bills(): csrf_form=csrf_form, add_bill=request.values.get("add_bill", False), current_view="list_bills", - start_date=start_date if request.method == "POST" else None, - end_date=end_date if request.method == "POST" else None, + start=start if request.method == "POST" else None, + end=end if request.method == "POST" else None, ) @main.route("//members/add", methods=["GET", "POST"]) @@ -985,8 +988,8 @@ def feed(token): return "", 304 if ( - request.if_modified_since - and request.if_modified_since.replace(tzinfo=None) >= last_modified + request.if_modified_since + and request.if_modified_since.replace(tzinfo=None) >= last_modified ): return "", 304