diff --git a/lacommunaute/forum_conversation/tests/tests_views.py b/lacommunaute/forum_conversation/tests/tests_views.py index 21c9e0bcf..5964953c1 100644 --- a/lacommunaute/forum_conversation/tests/tests_views.py +++ b/lacommunaute/forum_conversation/tests/tests_views.py @@ -27,6 +27,7 @@ from lacommunaute.forum_conversation.views import PostDeleteView, TopicCreateView from lacommunaute.forum_moderation.factories import BlockedDomainNameFactory, BlockedEmailFactory from lacommunaute.forum_upvote.factories import UpVoteFactory +from lacommunaute.notification.factories import NotificationFactory from lacommunaute.users.factories import UserFactory from lacommunaute.utils.testing import parse_response_to_soup @@ -713,6 +714,18 @@ def test_delete_link_visibility(self): status_code=200, ) + def test_get_marks_notifications_read(self): + self.client.force_login(self.poster) + + notification = NotificationFactory(recipient=self.poster.email, post=self.topic.first_post) + self.assertIsNone(notification.sent_at) + + response = self.client.get(self.url) + self.assertEqual(response.status_code, 200) + + notification.refresh_from_db() + self.assertEqual(str(notification.created), str(notification.sent_at)) + def test_numqueries(self): PostFactory.create_batch(10, topic=self.topic, poster=self.poster) UpVoteFactory(content_object=self.topic.last_post, voter=UserFactory()) @@ -720,7 +733,7 @@ def test_numqueries(self): self.client.force_login(self.poster) # note vincentporte : to be optimized - with self.assertNumQueries(39): + with self.assertNumQueries(40): response = self.client.get(self.url) self.assertEqual(response.status_code, 200) diff --git a/lacommunaute/forum_conversation/tests/tests_views_htmx.py b/lacommunaute/forum_conversation/tests/tests_views_htmx.py index 7680c3816..9103f562b 100644 --- a/lacommunaute/forum_conversation/tests/tests_views_htmx.py +++ b/lacommunaute/forum_conversation/tests/tests_views_htmx.py @@ -14,6 +14,7 @@ from lacommunaute.forum_moderation.factories import BlockedDomainNameFactory, BlockedEmailFactory from lacommunaute.forum_moderation.models import BlockedPost from lacommunaute.forum_upvote.factories import UpVoteFactory +from lacommunaute.notification.factories import NotificationFactory from lacommunaute.users.factories import UserFactory @@ -210,6 +211,18 @@ def test_certified_post_highlight(self): response = self.client.get(self.url) self.assertContains(response, "Certifié par la Plateforme de l'Inclusion", status_code=200) + def test_get_marks_notifications_read(self): + self.client.force_login(self.user) + + notification = NotificationFactory(recipient=self.user.email, post=self.topic.first_post) + self.assertIsNone(notification.sent_at) + + response = self.client.get(self.url) + self.assertEqual(response.status_code, 200) + + notification.refresh_from_db() + self.assertEqual(str(notification.created), str(notification.sent_at)) + class PostFeedCreateViewTest(TestCase): @classmethod diff --git a/lacommunaute/forum_conversation/views.py b/lacommunaute/forum_conversation/views.py index 4c6b86265..bf3fa9689 100644 --- a/lacommunaute/forum_conversation/views.py +++ b/lacommunaute/forum_conversation/views.py @@ -15,6 +15,7 @@ from lacommunaute.forum_conversation.models import Topic from lacommunaute.forum_conversation.shortcuts import can_certify_post, get_posts_of_a_topic_except_first_one from lacommunaute.forum_conversation.view_mixins import FilteredTopicsListViewMixin +from lacommunaute.notification.models import Notification logger = logging.getLogger(__name__) @@ -113,6 +114,11 @@ def get_context_data(self, **kwargs): def get_queryset(self): return get_posts_of_a_topic_except_first_one(self.topic, self.request.user) + def get(self, request, *args, **kwargs): + if request.user.is_authenticated: + Notification.objects.mark_topic_posts_read(self.get_topic(), request.user) + return super().get(request, *args, **kwargs) + class TopicListView(FilteredTopicsListViewMixin, ListView): context_object_name = "topics" diff --git a/lacommunaute/forum_conversation/views_htmx.py b/lacommunaute/forum_conversation/views_htmx.py index 87dda6ee4..d1fc22ebf 100644 --- a/lacommunaute/forum_conversation/views_htmx.py +++ b/lacommunaute/forum_conversation/views_htmx.py @@ -7,6 +7,7 @@ from lacommunaute.forum_conversation.forms import PostForm from lacommunaute.forum_conversation.models import CertifiedPost, Post, Topic from lacommunaute.forum_conversation.shortcuts import can_certify_post, get_posts_of_a_topic_except_first_one +from lacommunaute.notification.models import Notification logger = logging.getLogger(__name__) @@ -67,6 +68,8 @@ def get(self, request, **kwargs): topic = self.get_topic() track_handler.mark_topic_read(topic, request.user) + if request.user.is_authenticated: + Notification.objects.mark_topic_posts_read(topic, request.user) return render( request, diff --git a/lacommunaute/notification/models.py b/lacommunaute/notification/models.py index 34f3b1638..5b65d42ef 100644 --- a/lacommunaute/notification/models.py +++ b/lacommunaute/notification/models.py @@ -2,6 +2,7 @@ from operator import attrgetter from django.db import models +from django.db.models import F from django.utils.translation import gettext_lazy as _ from machina.models.abstract_models import DatedModel @@ -31,6 +32,17 @@ def group_by_recipient(self): for recipient, group in groupby(self.order_by("recipient", "kind"), key=attrgetter("recipient")) } + def mark_topic_posts_read(self, topic, user): + """ + Called when a topic's posts are read - to update the read status of associated Notification + """ + if not topic or (not user or user.is_anonymous): + raise ValueError() + + self.filter( + sent_at__isnull=True, recipient=user.email, post__in=topic.posts.values_list("id", flat=True) + ).update(sent_at=F("created")) + class Notification(DatedModel): recipient = models.EmailField(verbose_name=_("recipient"), null=False, blank=False) diff --git a/lacommunaute/notification/tests/tests_models.py b/lacommunaute/notification/tests/tests_models.py index b0fdf194d..c12a93d3b 100644 --- a/lacommunaute/notification/tests/tests_models.py +++ b/lacommunaute/notification/tests/tests_models.py @@ -1,8 +1,12 @@ +from django.contrib.auth.models import AnonymousUser +from django.db.models import F from django.test import TestCase +from lacommunaute.forum_conversation.factories import TopicFactory from lacommunaute.notification.enums import EmailSentTrackKind from lacommunaute.notification.factories import NotificationFactory from lacommunaute.notification.models import EmailSentTrack, Notification +from lacommunaute.users.factories import UserFactory class EmailSentTrackModelTest(TestCase): @@ -31,3 +35,57 @@ def test_notification_group_by_recipient(self): ) self.assertEqual(result[recipient_b], [notification_b]) + + def test_mark_topic_posts_read(self): + user = UserFactory() + topic = TopicFactory(with_post=True) + + NotificationFactory.create_batch(2, recipient=user.email, post=topic.first_post) + + Notification.objects.mark_topic_posts_read(topic, user) + + self.assertEqual( + Notification.objects.filter( + sent_at__isnull=False, post=topic.first_post, recipient=user.email, sent_at=F("created") + ).count(), + 2, + ) + + def test_mark_topic_posts_read_doesnt_impact_old_notifications(self): + user = UserFactory() + topic = TopicFactory(with_post=True) + + old_notification = NotificationFactory(recipient=user.email, post=topic.first_post, is_sent=True) + self.assertNotEqual(str(old_notification.sent_at), str(old_notification.created)) + + Notification.objects.mark_topic_posts_read(topic, user) + + old_notification.refresh_from_db() + self.assertNotEqual(str(old_notification.sent_at), str(old_notification.created)) + + def test_mark_topic_posts_read_doesnt_impact_other_notifications(self): + user = UserFactory() + topic = TopicFactory(with_post=True) + + other_notification = NotificationFactory(recipient="test@example.com", post=topic.first_post) + + Notification.objects.mark_topic_posts_read(topic, user) + + other_notification.refresh_from_db() + self.assertIsNone(other_notification.sent_at) + + def test_mark_topic_posts_read_anonymous_user(self): + topic = TopicFactory(with_post=True) + + with self.assertRaises(ValueError): + Notification.objects.mark_topic_posts_read(topic, AnonymousUser()) + + def test_mark_topic_posts_read_invalid_arguments(self): + user = UserFactory() + topic = TopicFactory(with_post=True) + + with self.assertRaises(ValueError): + Notification.objects.mark_topic_posts_read(None, user) + + with self.assertRaises(ValueError): + Notification.objects.mark_topic_posts_read(topic, None)