diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ac2e9099e..920a7205c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,6 @@ exclude: | (?x)^( \.git| __pycache__| - .*snap_test_.*\.py| .+\/.+\/migrations\/.*| legacy| \.venv diff --git a/api/drf_views.py b/api/drf_views.py index f75c0e274..bc288087c 100644 --- a/api/drf_views.py +++ b/api/drf_views.py @@ -59,6 +59,7 @@ from deployments.models import Personnel from main.enums import GlobalEnumSerializer, get_enum_values from main.filters import NullsLastOrderingFilter +from main.permissions import DenyGuestUserMutationPermission from main.utils import is_tableau from per.models import Overview from per.serializers import CountryLatestOverviewSerializer @@ -870,7 +871,7 @@ def get_serializer_class(self): class ProfileViewset(viewsets.ModelViewSet): serializer_class = ProfileSerializer authentication_classes = (TokenAuthentication,) - permission_classes = (IsAuthenticated,) + permission_classes = (IsAuthenticated, DenyGuestUserMutationPermission) def get_queryset(self): return Profile.objects.filter(user=self.request.user) @@ -879,16 +880,12 @@ def get_queryset(self): class UserViewset(viewsets.ModelViewSet): serializer_class = UserSerializer authentication_classes = (TokenAuthentication,) - permission_classes = (IsAuthenticated,) + permission_classes = [IsAuthenticated, DenyGuestUserMutationPermission] def get_queryset(self): return User.objects.filter(pk=self.request.user.pk) - @action( - detail=False, - url_path="me", - serializer_class=UserMeSerializer, - ) + @action(detail=False, url_path="me", serializer_class=UserMeSerializer, permission_classes=(IsAuthenticated,)) def get_authenticated_user_info(self, request, *args, **kwargs): return Response(self.get_serializer_class()(request.user).data) @@ -915,7 +912,7 @@ class FieldReportViewset(ReadOnlyVisibilityViewsetMixin, viewsets.ModelViewSet): ) # for /docs ordering_fields = ("summary", "event", "dtype", "created_at", "updated_at") filterset_class = FieldReportFilter - authentication_class = [IsAuthenticated] + permission_classes = [IsAuthenticated, DenyGuestUserMutationPermission] queryset = FieldReport.objects.select_related("dtype", "event").prefetch_related( "actions_taken", "actions_taken__actions", "countries", "districts", "regions" ) @@ -1308,7 +1305,7 @@ class UsersViewset(viewsets.ReadOnlyModelViewSet): """ serializer_class = UserSerializer - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthenticated, DenyGuestUserMutationPermission] filterset_class = UserFilterSet def get_queryset(self): @@ -1346,7 +1343,7 @@ def get(self, _): class ExportViewSet(viewsets.ModelViewSet): serializer_class = ExportSerializer - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthenticated, DenyGuestUserMutationPermission] def get_queryset(self): user = self.request.user diff --git a/api/migrations/0212_profile_limit_access_to_guest.py b/api/migrations/0212_profile_limit_access_to_guest.py new file mode 100644 index 000000000..03a5b11f7 --- /dev/null +++ b/api/migrations/0212_profile_limit_access_to_guest.py @@ -0,0 +1,22 @@ +# Generated by Django 4.2.13 on 2024-07-30 07:53 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("api", "0211_alter_countrydirectory_unique_together_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="profile", + name="limit_access_to_guest", + field=models.BooleanField( + default=True, + help_text="If this value is set to true, the user is treated as a guest user regardless of any other permissions they may have, thereby depriving them of all non-guest user permissions.", + verbose_name="limit access to guest user permissions", + ), + ), + ] diff --git a/api/models.py b/api/models.py index 576266cc8..60b802a59 100644 --- a/api/models.py +++ b/api/models.py @@ -1850,6 +1850,14 @@ class OrgTypes(models.TextChoices): phone_number = models.CharField(verbose_name=_("phone number"), blank=True, null=True, max_length=100) last_frontend_login = models.DateTimeField(verbose_name=_("last frontend login"), null=True, blank=True) accepted_montandon_license_terms = models.BooleanField(verbose_name=_("has accepted montandon license terms?"), default=False) + limit_access_to_guest = models.BooleanField( + help_text=( + "If this value is set to true, the user is treated as a guest user regardless of any other permissions" + " they may have, thereby depriving them of all non-guest user permissions." + ), + verbose_name=_("limit access to guest user permissions"), + default=True, + ) class Meta: verbose_name = _("user profile") diff --git a/api/serializers.py b/api/serializers.py index e0f9a1b05..58fa6e67e 100644 --- a/api/serializers.py +++ b/api/serializers.py @@ -1703,6 +1703,7 @@ class UserMeSerializer(UserSerializer): is_per_admin_for_regions = serializers.SerializerMethodField() is_per_admin_for_countries = serializers.SerializerMethodField() user_countries_regions = serializers.SerializerMethodField() + limit_access_to_guest = serializers.BooleanField(read_only=True, source="profile.limit_access_to_guest") class Meta: model = User @@ -1714,6 +1715,7 @@ class Meta: "is_per_admin_for_regions", "is_per_admin_for_countries", "user_countries_regions", + "limit_access_to_guest", ) @staticmethod diff --git a/api/snapshots/snap_test_views.py b/api/snapshots/snap_test_views.py index 4f1e4fc52..bc263648a 100644 --- a/api/snapshots/snap_test_views.py +++ b/api/snapshots/snap_test_views.py @@ -14,7 +14,7 @@ "countries": [], "countries_for_preview": [], "created_at": "2008-01-01T00:00:00.123456Z", - "disaster_start_date": "2015-04-21T17:45:23.476445Z", + "disaster_start_date": "2021-09-20T13:28:12.297843Z", "districts": [], "dtype": 1, "emergency_response_contact_email": None, @@ -56,12 +56,12 @@ }, ], "field_reports": [], - "glide": "xJKxDZJiNfetzTUEHA", + "glide": "bxJKxDZJiNfetzTUEH", "hide_attached_field_reports": True, "hide_field_report_map": True, "id": 2, - "ifrc_severity_level": 0, - "ifrc_severity_level_display": "Yellow", + "ifrc_severity_level": 1, + "ifrc_severity_level_display": "Orange", "is_featured": False, "is_featured_region": True, "key_figures": [], @@ -71,7 +71,7 @@ "parent_event": 1, "response_activity_count": 0, "slug": "ygwwmqzcudihyfjsonxkmtecqoxsfogyrdoxkxwnqrsrpemoki", - "summary": "NMGyDLJYVcCZKPmuMEGjdCgZvTfGPlcpTCCHHNkxxsyAXvRMdYOPvevgJRysqUQMjvfLQjwtPSQziMTftJyPYviQSVRHfPQBGxbxtlnvXFmoijesYgGXIVHcQvXNiMyjklSXNZkUCcAxRUpCNsWVYCoIptZYEmxRKCDXsXyGHAkmZMiqdPExJgTHhsfWkrCGjBfoCwbAdzGxpyfxobugTPvYjicsESiWTECNafbqnjJUMHBhXspthdpAOYNDehFMIbOGKpTjsBaNwpKAlQQfHxeHIGYGJbyEcOyxqVbwYewpUQOgXLVWvicwIvPlXRDSEOlZieTXDcsmcYmcutGzIEqcWPmswXdPvrhZxBzVCyvlFSFxZHrZfUBfBMlIsugfuQstCMTBkSCwCcUwNBrOYdeQOzxGZVRkbjMRYCciepXPxxyKcMjRCxxCWeKiHxzuPrphbVlFHyJhqXqTCnNsSFmhieClTCfZRuQwTeJIstkTTSOlYxGo", + "summary": "fNMGyDLJYVcCZKPmuMEGjdCgZvTfGPlcpTCCHHNkxxsyAXvRMdYOPvevgJRysqUQMjvfLQjwtPSQziMTftJyPYviQSVRHfPQBGxbxtlnvXFmoijesYgGXIVHcQvXNiMyjklSXNZkUCcAxRUpCNsWVYCoIptZYEmxRKCDXsXyGHAkmZMiqdPExJgTHhsfWkrCGjBfoCwbAdzGxpyfxobugTPvYjicsESiWTECNafbqnjJUMHBhXspthdpAOYNDehFMIbOGKpTjsBaNwpKAlQQfHxeHIGYGJbyEcOyxqVbwYewpUQOgXLVWvicwIvPlXRDSEOlZieTXDcsmcYmcutGzIEqcWPmswXdPvrhZxBzVCyvlFSFxZHrZfUBfBMlIsugfuQstCMTBkSCwCcUwNBrOYdeQOzxGZVRkbjMRYCciepXPxxyKcMjRCxxCWeKiHxzuPrphbVlFHyJhqXqTCnNsSFmhieClTCfZRuQwTeJIstkTTSOlYxG", "tab_one_title": "cPXKqPnXKANObFOIsPtEpZZRztDeSdkCAEDnvMjuTuUwziWxGJ", "tab_three_title": "gBiqUxWzxczdKJmxJseyGCWJrNRNhigzxYvJxWjmMGzGccciTv", "tab_two_title": "gupDhrCpjgdsyNApkuKUumWkFGDFtFbfzGDpnLwddsFMPREsIa", @@ -88,18 +88,18 @@ "countries": [], "countries_for_preview": [], "created_at": "2008-01-01T00:00:00.123456Z", - "disaster_start_date": "2015-04-21T17:45:23.476445Z", + "disaster_start_date": "2021-09-20T13:28:12.297843Z", "districts": [], "dtype": 1, "emergency_response_contact_email": None, "featured_documents": [], "field_reports": [], - "glide": "xJKxDZJiNfetzTUEHA", + "glide": "bxJKxDZJiNfetzTUEH", "hide_attached_field_reports": True, "hide_field_report_map": True, "id": 2, - "ifrc_severity_level": 0, - "ifrc_severity_level_display": "Yellow", + "ifrc_severity_level": 1, + "ifrc_severity_level_display": "Orange", "is_featured": False, "is_featured_region": True, "key_figures": [], @@ -145,7 +145,7 @@ "parent_event": 1, "response_activity_count": 0, "slug": "ygwwmqzcudihyfjsonxkmtecqoxsfogyrdoxkxwnqrsrpemoki", - "summary": "NMGyDLJYVcCZKPmuMEGjdCgZvTfGPlcpTCCHHNkxxsyAXvRMdYOPvevgJRysqUQMjvfLQjwtPSQziMTftJyPYviQSVRHfPQBGxbxtlnvXFmoijesYgGXIVHcQvXNiMyjklSXNZkUCcAxRUpCNsWVYCoIptZYEmxRKCDXsXyGHAkmZMiqdPExJgTHhsfWkrCGjBfoCwbAdzGxpyfxobugTPvYjicsESiWTECNafbqnjJUMHBhXspthdpAOYNDehFMIbOGKpTjsBaNwpKAlQQfHxeHIGYGJbyEcOyxqVbwYewpUQOgXLVWvicwIvPlXRDSEOlZieTXDcsmcYmcutGzIEqcWPmswXdPvrhZxBzVCyvlFSFxZHrZfUBfBMlIsugfuQstCMTBkSCwCcUwNBrOYdeQOzxGZVRkbjMRYCciepXPxxyKcMjRCxxCWeKiHxzuPrphbVlFHyJhqXqTCnNsSFmhieClTCfZRuQwTeJIstkTTSOlYxGo", + "summary": "fNMGyDLJYVcCZKPmuMEGjdCgZvTfGPlcpTCCHHNkxxsyAXvRMdYOPvevgJRysqUQMjvfLQjwtPSQziMTftJyPYviQSVRHfPQBGxbxtlnvXFmoijesYgGXIVHcQvXNiMyjklSXNZkUCcAxRUpCNsWVYCoIptZYEmxRKCDXsXyGHAkmZMiqdPExJgTHhsfWkrCGjBfoCwbAdzGxpyfxobugTPvYjicsESiWTECNafbqnjJUMHBhXspthdpAOYNDehFMIbOGKpTjsBaNwpKAlQQfHxeHIGYGJbyEcOyxqVbwYewpUQOgXLVWvicwIvPlXRDSEOlZieTXDcsmcYmcutGzIEqcWPmswXdPvrhZxBzVCyvlFSFxZHrZfUBfBMlIsugfuQstCMTBkSCwCcUwNBrOYdeQOzxGZVRkbjMRYCciepXPxxyKcMjRCxxCWeKiHxzuPrphbVlFHyJhqXqTCnNsSFmhieClTCfZRuQwTeJIstkTTSOlYxG", "tab_one_title": "cPXKqPnXKANObFOIsPtEpZZRztDeSdkCAEDnvMjuTuUwziWxGJ", "tab_three_title": "gBiqUxWzxczdKJmxJseyGCWJrNRNhigzxYvJxWjmMGzGccciTv", "tab_two_title": "gupDhrCpjgdsyNApkuKUumWkFGDFtFbfzGDpnLwddsFMPREsIa", diff --git a/api/test_views.py b/api/test_views.py index a4bf987ee..0c74bd07b 100644 --- a/api/test_views.py +++ b/api/test_views.py @@ -10,9 +10,162 @@ EventFeaturedDocumentFactory, EventLinkFactory, ) +from api.models import Profile +from deployments.factories.user import UserFactory from main.test_case import APITestCase, SnapshotTestCase +class GuestUserPermissionTest(APITestCase): + def setUp(self): + # Create guest user + self.guest_user = User.objects.create(username="guest") + guest_profile = Profile.objects.get(user=self.guest_user) + guest_profile.limit_access_to_guest = True + guest_profile.save() + + # Create go user + self.go_user = User.objects.create(username="go-user") + go_user_profile = Profile.objects.get(user=self.go_user) + go_user_profile.limit_access_to_guest = False + go_user_profile.save() + + def test_guest_user_permission(self): + body = {} + guest_apis = [ + "/api/v2/add_subscription/", + "/api/v2/del_subscription/", + "/api/v2/external-token/", + "/api/v2/user/me/", + ] + id = 1 # NOTE: id is used just to test api that requires id, it doesnot indicate real id. It can be any number. + go_apis = [ + "/api/v2/dref/", + "/api/v2/dref-final-report/", + f"/api/v2/dref-final-report/{id}/publish/", + "/api/v2/dref-op-update/", + f"/api/v2/dref-op-update/{id}/publish/", + "/api/v2/dref-share/", + f"/api/v2/dref/{id}/publish/", + "/api/v2/flash-update/", + "/api/v2/flash-update-file/multiple/", + "/api/v2/local-units/", + f"/api/v2/local-units/{id}/validate/", + "/api/v2/pdf-export/", + "/api/v2/per-assessment/", + "/api/v2/per-document-upload/", + "/api/v2/per-file/multiple/", + "/api/v2/per-prioritization/", + "/api/v2/per-work-plan/", + "/api/v2/project/", + "/api/v2/dref-files/", + "/api/v2/dref-files/multiple/", + "/api/v2/field-report/", + "/api/v2/flash-update-file/", + "/api/v2/per-file/", + "/api/v2/share-flash-update/", + "/api/v2/add_cronjob_log/", + "/api/v2/profile/", + "/api/v2/subscription/", + "/api/v2/user/", + ] + + get_apis = [ + "/api/v2/dref/", + "/api/v2/dref-files/", + "/api/v2/dref-final-report/", + f"/api/v2/dref-final-report/{id}/", + "/api/v2/dref-op-update/", + f"/api/v2/dref/{id}/", + "/api/v2/field-report/", + f"/api/v2/field-report/{id}/", + "/api/v2/flash-update/", + "/api/v2/flash-update-file/", + f"/api/v2/flash-update/{id}/", + "/api/v2/language/", + f"/api/v2/language/{id}/", + "/api/v2/local-units/", + f"/api/v2/local-units/{id}/", + "/api/v2/ops-learning/", + f"/api/v2/ops-learning/{id}/", + f"/api/v2/pdf-export/{id}/", + "/api/v2/per-assessment/", + f"/api/v2/per-assessment/{id}/", + "/api/v2/per-document-upload/", + f"/api/v2/per-document-upload/{id}/", + "/api/v2/per-file/", + "/api/v2/per-overview/", + f"/api/v2/per-overview/{id}/", + "/api/v2/per-prioritization/", + f"/api/v2/per-prioritization/{id}/", + "/api/v2/per-work-plan/", + f"/api/v2/per-work-plan/{id}/", + "/api/v2/profile/", + f"/api/v2/profile/{id}/", + f"/api/v2/share-flash-update/{id}/", + "/api/v2/subscription/", + f"/api/v2/subscription/{id}/", + "/api/v2/users/", + f"/api/v2/users/{id}/", + # Exports + f"/api/v2/export-flash-update/{1}/", + ] + + # NOTE: With custom Content Negotiation: Look for main.utils.SpreadSheetContentNegotiation + get_custom_negotiation_apis = [ + f"/api/v2/export-per/{1}/", + ] + + go_apis_req_additional_perm = [ + "/api/v2/ops-learning/", + "/api/v2/per-overview/", + f"/api/v2/user/{id}/accepted_license_terms/", + f"/api/v2/language/{id}/bulk-action/", + ] + + self.authenticate(user=self.guest_user) + + def _success_check(response): # NOTE: Only handles json responses + self.assertNotIn(response.status_code, [401, 403], response.content) + self.assertNotIn(response.json().get("error_code"), [401, 403], response.content) + + def _failure_check(response, is_json=True): + self.assertIn(response.status_code, [401, 403], response.content) + if is_json: + self.assertIn(response.json()["error_code"], [401, 403], response.content) + + for api_url in get_custom_negotiation_apis: + headers = { + "Accept": "text/html", + } + response = self.client.get(api_url, headers=headers, stream=True) + _failure_check(response, is_json=False) + + # Guest user should not be able to access get apis that requires IsAuthenticated permission + for api_url in get_apis: + response = self.client.get(api_url) + _failure_check(response) + + # Guest user should not be able to hit post apis. + for api_url in go_apis + go_apis_req_additional_perm: + response = self.client.post(api_url, json=body) + _failure_check(response) + + # Guest user should be able to access guest apis + for api_url in guest_apis: + response = self.client.post(api_url, json=body) + _success_check(response) + + # Go user should be able to access go_apis + self.authenticate(user=self.go_user) + for api_url in go_apis: + response = self.client.post(api_url, json=body) + _success_check(response) + + for api_url in get_apis: + response = self.client.get(api_url) + _success_check(response) + + class AuthTokenTest(APITestCase): def setUp(self): user = User.objects.create(username="jo") @@ -78,7 +231,7 @@ class FieldReportTest(APITestCase): fixtures = ["DisasterTypes", "Actions"] def test_create_and_update(self): - user = User.objects.create(username="jo") + user = UserFactory(username="jo") region = models.Region.objects.create(name=1) country1 = models.Country.objects.create(name="abc", region=region) country2 = models.Country.objects.create(name="xyz") @@ -204,21 +357,24 @@ def test_country_snippet_visibility(self): self.assertEqual(response["count"], 0) # perform the request with an authenticated user - user = User.objects.create(username="foo") + user = UserFactory(username="foo") self.client.force_authenticate(user=user) response = self.client.get("/api/v2/country_snippet/").json() # one snippets available to anonymous user self.assertEqual(response["count"], 1) # perform the request with an ifrc user - user2 = User.objects.create(username="bar") + user2 = UserFactory(username="bar") user2.user_permissions.add(self.ifrc_permission) self.client.force_authenticate(user=user2) response = self.client.get("/api/v2/country_snippet/").json() self.assertEqual(response["count"], 2) # perform the request with a superuser - super_user = User.objects.create_superuser(username="baz", email="foo@baz.com", password="12345678") + super_user = UserFactory(username="baz", email="foo@baz.com", password="12345678") + super_user.is_superuser = True + super_user.save() + self.client.force_authenticate(user=super_user) response = self.client.get("/api/v2/country_snippet/").json() self.assertEqual(response["count"], 2) diff --git a/api/views.py b/api/views.py index 78d7e6db0..5fb1d214d 100644 --- a/api/views.py +++ b/api/views.py @@ -43,6 +43,7 @@ Statuses, ) from flash_update.models import FlashUpdate +from main.permissions import DenyGuestUserMutationPermission from notifications.models import Subscription, SurgeAlert from notifications.notification import send_notification from registrations.models import Pending, Recovery @@ -976,7 +977,7 @@ def post(self, request): class AddCronJobLog(APIView): authentication_classes = (authentication.TokenAuthentication,) - permissions_classes = (permissions.IsAuthenticated,) + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] def post(self, request): errors, created = CronJob.sync_cron(request.data) diff --git a/api/visibility_class.py b/api/visibility_class.py index 839c1df8a..31125cb89 100644 --- a/api/visibility_class.py +++ b/api/visibility_class.py @@ -17,7 +17,7 @@ def get_visibility_queryset(self, queryset): if queryset.model == Project: choices = VisibilityCharChoices - if self.request.user.is_authenticated: + if self.request.user.is_authenticated and not self.request.user.profile.limit_access_to_guest: if is_user_ifrc(self.request.user): return queryset else: @@ -37,7 +37,7 @@ class ReadOnlyVisibilityViewset(viewsets.ReadOnlyModelViewSet): def get_queryset(self): # FIXME: utils.py:43 # filter_visibility_by_auth(user=self.request.user, visibility_model_class=self.visibility_model_class) - if self.request.user.is_authenticated: + if self.request.user.is_authenticated and not self.request.user.profile.limit_access_to_guest: if is_user_ifrc(self.request.user): return self.visibility_model_class.objects.all() else: diff --git a/deployments/drf_views.py b/deployments/drf_views.py index f04268e49..c4abd3563 100644 --- a/deployments/drf_views.py +++ b/deployments/drf_views.py @@ -23,6 +23,7 @@ from api.models import Country, Region from api.view_filters import ListFilter from api.visibility_class import ReadOnlyVisibilityViewsetMixin +from main.permissions import DenyGuestUserMutationPermission from main.serializers import CsvListMixin from main.utils import is_tableau @@ -455,7 +456,7 @@ def get_permissions(self): if self.action in ["list", "retrieve"]: permission_classes = [] else: - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthenticated, DenyGuestUserMutationPermission] return [permission() for permission in permission_classes] diff --git a/deployments/factories/user.py b/deployments/factories/user.py index 8ccecf722..9da30434c 100644 --- a/deployments/factories/user.py +++ b/deployments/factories/user.py @@ -1,6 +1,8 @@ import factory from django.contrib.auth import get_user_model +from api.models import Profile + class UserFactory(factory.django.DjangoModelFactory): class Meta: @@ -8,3 +10,12 @@ class Meta: username = factory.Sequence(lambda n: "user_%d" % n) email = factory.Sequence(lambda n: "user_%d@ifrc.org" % n) + + @factory.post_generation + def create_profile(obj, create, extracted, **kwargs): + if create: + profile = Profile.objects.get(user=obj) + profile.limit_access_to_guest = False + profile.save(update_fields=["limit_access_to_guest"]) + # Set new profile to the user object + obj.profile = profile diff --git a/dref/views.py b/dref/views.py index 56bd5888c..cf04e7834 100644 --- a/dref/views.py +++ b/dref/views.py @@ -35,6 +35,7 @@ DrefShareUserSerializer, MiniDrefSerializer, ) +from main.permissions import DenyGuestUserMutationPermission def filter_dref_queryset_by_user_access(user, queryset): @@ -58,7 +59,7 @@ def filter_dref_queryset_by_user_access(user, queryset): class DrefViewSet(RevisionMixin, viewsets.ModelViewSet): serializer_class = DrefSerializer - permission_classes = [permissions.IsAuthenticated] + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] filterset_class = DrefFilter def get_queryset(self): @@ -75,7 +76,7 @@ def get_queryset(self): url_path="publish", methods=["post"], serializer_class=DrefSerializer, - permission_classes=[permissions.IsAuthenticated, PublishDrefPermission], + permission_classes=[permissions.IsAuthenticated, PublishDrefPermission, DenyGuestUserMutationPermission], ) def get_published(self, request, pk=None, version=None): dref = self.get_object() @@ -88,7 +89,7 @@ def get_published(self, request, pk=None, version=None): class DrefOperationalUpdateViewSet(RevisionMixin, viewsets.ModelViewSet): serializer_class = DrefOperationalUpdateSerializer - permission_classes = [permissions.IsAuthenticated] + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] filterset_class = DrefOperationalUpdateFilter def get_queryset(self): @@ -122,7 +123,7 @@ def get_queryset(self): url_path="publish", methods=["post"], serializer_class=DrefOperationalUpdateSerializer, - permission_classes=[permissions.IsAuthenticated, PublishDrefPermission], + permission_classes=[permissions.IsAuthenticated, PublishDrefPermission, DenyGuestUserMutationPermission], ) def get_published(self, request, pk=None, version=None): operational_update = self.get_object() @@ -135,7 +136,7 @@ def get_published(self, request, pk=None, version=None): class DrefFinalReportViewSet(RevisionMixin, viewsets.ModelViewSet): serializer_class = DrefFinalReportSerializer - permission_classes = [permissions.IsAuthenticated] + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] def get_queryset(self): user = self.request.user @@ -154,7 +155,7 @@ def get_queryset(self): url_path="publish", methods=["post"], serializer_class=DrefFinalReportSerializer, - permission_classes=[permissions.IsAuthenticated, PublishDrefPermission], + permission_classes=[permissions.IsAuthenticated, PublishDrefPermission, DenyGuestUserMutationPermission], ) def get_published(self, request, pk=None, version=None): field_report = self.get_object() @@ -171,7 +172,7 @@ def get_published(self, request, pk=None, version=None): class DrefFileViewSet(mixins.ListModelMixin, mixins.CreateModelMixin, viewsets.GenericViewSet): - permission_class = [permissions.IsAuthenticated] + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] serializer_class = DrefFileSerializer def get_queryset(self): @@ -184,7 +185,7 @@ def get_queryset(self): detail=False, url_path="multiple", methods=["POST"], - permission_classes=[permissions.IsAuthenticated], + permission_classes=[permissions.IsAuthenticated, DenyGuestUserMutationPermission], ) def multiple_file(self, request, pk=None, version=None): # converts querydict to original dict @@ -199,7 +200,9 @@ def multiple_file(self, request, pk=None, version=None): class CompletedDrefOperationsViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = CompletedDrefOperationsSerializer - permission_classes = [permissions.IsAuthenticated] + permission_classes = [ + permissions.IsAuthenticated, + ] filterset_class = CompletedDrefOperationsFilterSet queryset = DrefFinalReport.objects.filter(is_published=True).order_by("-created_at").distinct() @@ -210,7 +213,9 @@ def get_queryset(self): class ActiveDrefOperationsViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = MiniDrefSerializer - permission_classes = [permissions.IsAuthenticated] + permission_classes = [ + permissions.IsAuthenticated, + ] filterset_class = ActiveDrefFilterSet queryset = ( Dref.objects.prefetch_related("planned_interventions", "needs_identified", "national_society_actions", "users") @@ -225,7 +230,7 @@ def get_queryset(self): class DrefShareView(views.APIView): - permission_classes = [permissions.IsAuthenticated] + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] @extend_schema(request=AddDrefUserSerializer, responses=None) def post(self, request): @@ -238,7 +243,9 @@ def post(self, request): class DrefShareUserViewSet(viewsets.ReadOnlyModelViewSet): - permissions_classes = [permissions.IsAuthenticated] + permission_classes = [ + permissions.IsAuthenticated, + ] serializer_class = DrefShareUserSerializer filterset_class = DrefShareUserFilterSet diff --git a/flash_update/test_views.py b/flash_update/test_views.py index 15e914fd4..2e641aecc 100644 --- a/flash_update/test_views.py +++ b/flash_update/test_views.py @@ -2,9 +2,9 @@ from unittest import mock from django.conf import settings -from django.contrib.auth.models import User import api.models as models +from deployments.factories.user import UserFactory from flash_update.factories import ( DonorFactory, DonorGroupFactory, @@ -21,7 +21,7 @@ class FlashUpdateTest(APITestCase): def setUp(self): - self.user = User.objects.create(username="jo") + self.user = UserFactory.create(username="jo") self.country1 = models.Country.objects.create(name="abc") self.country2 = models.Country.objects.create(name="xyz") self.district1 = models.District.objects.create(name="test district1", country=self.country1) @@ -127,7 +127,7 @@ def test_create_and_update(self, send_flash_update_email): self.assertEqual(updated.actions_taken_flash.count(), 2) def test_patch(self): - user = User.objects.create(username="test_abc") + user = UserFactory(username="test_abc") self.client.force_authenticate(user=user) with self.capture_on_commit_callbacks(execute=True): response1 = self.client.post("/api/v2/flash-update/", self.body, format="json").json() @@ -141,7 +141,7 @@ def test_patch(self): self.assertEqual(flash_id.share_with, FlashUpdate.FlashShareWith.IFRC_SECRETARIAT) def test_get_flash_update(self): - user1 = User.objects.create(username="abc") + user1 = UserFactory.create(username="abc") flash_update1, flash_update2, flash_update3 = FlashUpdateFactory.create_batch(3, created_by=user1) self.client.force_authenticate(user=user1) response1 = self.client.get("/api/v2/flash-update/").json() @@ -157,7 +157,7 @@ def test_get_flash_update(self): self.assertEqual(response["id"], flash_update1.id) # try with another user - user2 = User.objects.create(username="xyz") + user2 = UserFactory.create(username="xyz") self.client.force_authenticate(user=user2) flash_update4, flash_update5 = FlashUpdateFactory.create_batch(2, created_by=user2) response2 = self.client.get("/api/v2/flash-update/").json() @@ -167,13 +167,13 @@ def test_get_flash_update(self): self.assertNotIn([data["id"] for data in response2["results"]], [data["id"] for data in response1["results"]]) # try with users who has no any flash update created - user3 = User.objects.create(username="ram") + user3 = UserFactory.create(username="ram") self.client.force_authenticate(user=user3) response3 = self.client.get("/api/v2/flash-update/").json() self.assertEqual(response3["count"], 5) def test_filter(self): - user = User.objects.create(username="xyz") + user = UserFactory.create(username="xyz") self.client.force_authenticate(user=user) hazard_type1 = models.DisasterType.objects.create(name="disaster_type1") hazard_type2 = models.DisasterType.objects.create(name="disaster_type2") @@ -203,7 +203,7 @@ def test_validate_country_district(self): self.assert_400(response) def test_upload_file(self): - user = User.objects.create(username="flash_user") + user = UserFactory(username="flash_user") url = "/api/v2/flash-update-file/" data = {"file": open(self.file, "rb"), "caption": "test file"} self.client.force_authenticate(user=user) diff --git a/flash_update/views.py b/flash_update/views.py index cfad401fa..fb89a5cd6 100644 --- a/flash_update/views.py +++ b/flash_update/views.py @@ -14,6 +14,7 @@ from rest_framework.response import Response from api.serializers import ActionSerializer +from main.permissions import DenyGuestUserMutationPermission from .filter_set import FlashUpdateFilter from .models import ( @@ -38,7 +39,7 @@ class FlashUpdateViewSet(viewsets.ModelViewSet): serializer_class = FlashUpdateSerializer - permission_classes = [permissions.IsAuthenticated] + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] filterset_class = FlashUpdateFilter def get_queryset(self): @@ -68,7 +69,7 @@ def get_queryset(self): class FlashUpdateFileViewSet(mixins.ListModelMixin, mixins.CreateModelMixin, viewsets.GenericViewSet): - permission_class = [permissions.IsAuthenticated] + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] serializer_class = FlashGraphicMapSerializer def get_queryset(self): @@ -79,7 +80,7 @@ def get_queryset(self): detail=False, url_path="multiple", methods=["POST"], - permission_classes=[permissions.IsAuthenticated], + permission_classes=[permissions.IsAuthenticated, DenyGuestUserMutationPermission], ) def multiple_file(self, request, pk=None, version=None): files = [files[0] for files in dict((request.data).lists()).values()] @@ -112,11 +113,14 @@ class DonorsViewSet(viewsets.ReadOnlyModelViewSet): class ShareFlashUpdateViewSet(mixins.CreateModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet): queryset = FlashUpdateShare.objects.all() serializer_class = ShareFlashUpdateSerializer - permission_class = [permissions.IsAuthenticated] + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] class ExportFlashUpdateView(views.APIView): - permission_classes = [permissions.IsAuthenticated] + permission_classes = [ + permissions.IsAuthenticated, + DenyGuestUserMutationPermission, + ] @extend_schema(request=None, responses=ExportFlashUpdateViewSerializer) def get(self, request, pk, format=None): diff --git a/lang/tests.py b/lang/tests.py index d62b8dff0..03561f14f 100644 --- a/lang/tests.py +++ b/lang/tests.py @@ -2,10 +2,11 @@ from unittest import mock from django.conf import settings -from django.contrib.auth.models import Permission, User +from django.contrib.auth.models import Permission from django.core import management from django.test import override_settings +from deployments.factories.user import UserFactory from lang.translation import IfrcTranslator from main.test_case import APITestCase @@ -130,7 +131,7 @@ def test_bulk_action(self): self.assertEqual(first_string_key, string_1["key"]) def test_user_me(self): - user = User.objects.create_user( + user = UserFactory.create( username="user@test.com", first_name="User", last_name="Toot", @@ -185,7 +186,7 @@ def test_user_me(self): ) def test_lang_api_permissions(self): - user = User.objects.create_user( + user = UserFactory( username="user@test.com", first_name="User", last_name="Toot", diff --git a/lang/views.py b/lang/views.py index 03e26f8fd..e34e7c801 100644 --- a/lang/views.py +++ b/lang/views.py @@ -9,6 +9,8 @@ from rest_framework.authentication import TokenAuthentication from rest_framework.decorators import action as djaction +from main.permissions import DenyGuestUserMutationPermission + from .models import String from .permissions import LangStringPermission from .serializers import ( @@ -24,7 +26,7 @@ class LanguageViewSet(viewsets.ViewSet): # TODO: Cache retrive response to file authentication_classes = (TokenAuthentication,) - permission_classes = (LangStringPermission,) + permission_classes = (LangStringPermission, DenyGuestUserMutationPermission) lookup_url_kwarg = "pk" @extend_schema(request=None, responses=LanguageListSerializer) diff --git a/local_units/views.py b/local_units/views.py index 7affa78a4..3a6dc0805 100644 --- a/local_units/views.py +++ b/local_units/views.py @@ -33,6 +33,7 @@ PrivateLocalUnitDetailSerializer, PrivateLocalUnitSerializer, ) +from main.permissions import DenyGuestUserMutationPermission class PrivateLocalUnitViewSet(viewsets.ModelViewSet): @@ -47,7 +48,7 @@ class PrivateLocalUnitViewSet(viewsets.ModelViewSet): "local_branch_name", "english_branch_name", ) - permission_classes = [permissions.IsAuthenticated, IsAuthenticatedForLocalUnit] + permission_classes = [permissions.IsAuthenticated, IsAuthenticatedForLocalUnit, DenyGuestUserMutationPermission] def get_serializer_class(self): if self.action == "list": @@ -63,7 +64,7 @@ def destroy(self, request, *args, **kwargs): url_path="validate", methods=["post"], serializer_class=PrivateLocalUnitSerializer, - permission_classes=[permissions.IsAuthenticated, ValidateLocalUnitPermission], + permission_classes=[permissions.IsAuthenticated, ValidateLocalUnitPermission, DenyGuestUserMutationPermission], ) def get_validate(self, request, pk=None, version=None): local_unit = self.get_object() diff --git a/main/permissions.py b/main/permissions.py index f0ae7c110..d58662e96 100644 --- a/main/permissions.py +++ b/main/permissions.py @@ -10,3 +10,26 @@ def has_permission(self, request, view): def has_object_permission(self, request, view, obj): return self.has_permission(request, view) + + +class DenyGuestUserMutationPermission(permissions.BasePermission): + """ + Custom permission to deny mutation and query actions for logged-in guest users. + + This permission class restricts all (read, write, update, delete) operations if the user is a guest. + """ + + def _has_permission(self, request, view): + # For mutation methods (POST, PUT, DELETE, etc.): + # Check if the user is authenticated. + if not bool(request.user and request.user.is_authenticated): + # Deny access if the user is not authenticated. + return False + + return not request.user.profile.limit_access_to_guest + + def has_permission(self, request, view): + return self._has_permission(request, view) + + def has_object_permission(self, request, view, obj): + return self._has_permission(request, view) diff --git a/main/utils.py b/main/utils.py index 155f86819..c3f9c3856 100644 --- a/main/utils.py +++ b/main/utils.py @@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile, _TemporaryFileWrapper import requests +from django.conf import settings from django.contrib.contenttypes.models import ContentType from django.db import models, router from django.utils.dateparse import parse_date, parse_datetime @@ -168,4 +169,6 @@ def select_renderer(self, request, renderers, format_suffix): accepts = self.get_accept_list(request) if not set(self.MEDIA_TYPES).intersection(set(accepts)): raise exceptions.NotAcceptable(available_renderers=renderers) + if settings.TESTING: # NOTE: Quick hack to test permission of the views + return super().select_renderer(request, renderers, format_suffix) return (None, self.MEDIA_TYPES[0]) diff --git a/notifications/drf_views.py b/notifications/drf_views.py index 2447338cb..662468068 100644 --- a/notifications/drf_views.py +++ b/notifications/drf_views.py @@ -8,6 +8,7 @@ from deployments.models import MolnixTag from main.filters import CharInFilter +from main.permissions import DenyGuestUserMutationPermission from .models import Subscription, SurgeAlert from .serializers import ( # UnauthenticatedSurgeAlertSerializer, @@ -87,7 +88,7 @@ def get_serializer_class(self): class SubscriptionViewset(viewsets.ModelViewSet): serializer_class = SubscriptionSerializer authentication_classes = (TokenAuthentication,) - permission_classes = (IsAuthenticated,) + permission_classes = (IsAuthenticated, DenyGuestUserMutationPermission) search_fields = ("user__username", "rtype") # for /docs def get_queryset(self): diff --git a/per/drf_views.py b/per/drf_views.py index 94199914f..1caa75897 100644 --- a/per/drf_views.py +++ b/per/drf_views.py @@ -19,6 +19,7 @@ from api.models import Country from deployments.models import SectorTag +from main.permissions import DenyGuestUserMutationPermission from main.utils import SpreadSheetContentNegotiation from per.filter_set import ( PerDocumentFilter, @@ -234,7 +235,7 @@ def get_queryset(self): class PerOverviewViewSet(viewsets.ModelViewSet): serializer_class = PerOverviewSerializer - permission_classes = [IsAuthenticated, PerPermission] + permission_classes = [IsAuthenticated, PerPermission, DenyGuestUserMutationPermission] filterset_class = PerOverviewFilter ordering_fields = "__all__" get_request_user_regions = RegionRestrictedAdmin.get_request_user_regions @@ -246,7 +247,7 @@ def get_queryset(self): class ExportPerView(views.APIView): - permission_classes = [permissions.IsAuthenticated] + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] content_negotiation_class = SpreadSheetContentNegotiation @@ -506,7 +507,7 @@ def get(self, request, pk, format=None): class NewPerWorkPlanViewSet(viewsets.ModelViewSet): - permission_classes = (IsAuthenticated, PerGeneralPermission) + permission_classes = (IsAuthenticated, PerGeneralPermission, DenyGuestUserMutationPermission) queryset = PerWorkPlan.objects.all() serializer_class = PerWorkPlanSerializer filterset_class = PerWorkPlanFilter @@ -523,7 +524,7 @@ class FormPrioritizationViewSet(viewsets.ModelViewSet): serializer_class = FormPrioritizationSerializer queryset = FormPrioritization.objects.all() filterset_class = PerPrioritizationFilter - permission_classes = (IsAuthenticated, PerGeneralPermission) + permission_classes = (IsAuthenticated, PerGeneralPermission, DenyGuestUserMutationPermission) ordering_fields = "__all__" @@ -574,7 +575,7 @@ def get_queryset(self): class FormAssessmentViewSet(viewsets.ModelViewSet): serializer_class = PerAssessmentSerializer - permission_classes = [permissions.IsAuthenticated, PerGeneralPermission] + permission_classes = [permissions.IsAuthenticated, PerGeneralPermission, DenyGuestUserMutationPermission] ordering_fields = "__all__" def get_queryset(self): @@ -590,7 +591,7 @@ def get_queryset(self): class PerFileViewSet(mixins.ListModelMixin, mixins.CreateModelMixin, viewsets.GenericViewSet): - permission_class = [permissions.IsAuthenticated] + permission_classes = [permissions.IsAuthenticated, DenyGuestUserMutationPermission] serializer_class = PerFileSerializer def get_queryset(self): @@ -603,7 +604,7 @@ def get_queryset(self): detail=False, url_path="multiple", methods=["POST"], - permission_classes=[permissions.IsAuthenticated], + permission_classes=[permissions.IsAuthenticated, DenyGuestUserMutationPermission], ) def multiple_file(self, request, pk=None, version=None): # converts querydict to original dict @@ -707,7 +708,7 @@ class OpsLearningViewset(viewsets.ModelViewSet): """ queryset = OpsLearning.objects.all() - permission_classes = [OpsLearningPermission] + permission_classes = [DenyGuestUserMutationPermission, OpsLearningPermission] filterset_class = OpsLearningFilter search_fields = ( "learning", @@ -809,7 +810,7 @@ class PerDocumentUploadViewSet(viewsets.ModelViewSet): queryset = PerDocumentUpload.objects.all() serializer_class = PerDocumentUploadSerializer filterset_class = PerDocumentFilter - permission_classes = [permissions.IsAuthenticated, PerDocumentUploadPermission] + permission_classes = [permissions.IsAuthenticated, PerDocumentUploadPermission, DenyGuestUserMutationPermission] def get_queryset(self): queryset = super().get_queryset() diff --git a/pyproject.toml b/pyproject.toml index 57c489da8..0bacb6164 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,7 +130,7 @@ exclude = ''' )/ ''' # NOTE: Update in .pre-commit-config.yaml as well -extend-exclude = "^.*\\b(migrations)\\b.*$ (__pycache__|.*snap_test_.*\\.py|.+/+.+/+migrations/+.*)" +extend-exclude = "^.*\\b(migrations)\\b.*$ (__pycache__|.+/+.+/+migrations/+.*)" [tool.isort] profile = "black" @@ -138,7 +138,6 @@ multi_line_output = 3 # NOTE: Update in .pre-commit-config.yaml as well skip = [ "**/__pycache__", - "**/snap_test_*.py", ".venv/", "legacy/", "**/migrations/*.py", diff --git a/registrations/views.py b/registrations/views.py index 751aff0e0..0a7586145 100644 --- a/registrations/views.py +++ b/registrations/views.py @@ -147,7 +147,9 @@ def get(self, request): class UserExternalTokenViewset(viewsets.ModelViewSet): serializer_class = UserExternalTokenSerializer - permission_classes = [permissions.IsAuthenticated] + permission_classes = [ + permissions.IsAuthenticated, + ] def get_queryset(self): return UserExternalToken.objects.filter(user=self.request.user)