Skip to content

Commit

Permalink
Merge pull request #986 from ImageMarkup/add-type-checking
Browse files Browse the repository at this point in the history
  • Loading branch information
danlamanna authored Oct 15, 2024
2 parents 6c903c0 + 49470d5 commit bd44e89
Show file tree
Hide file tree
Showing 60 changed files with 330 additions and 292 deletions.
4 changes: 3 additions & 1 deletion isic/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Callable

from ninja.security import HttpBearer, django_auth
from oauth2_provider.oauth2_backends import get_oauthlib_core

Expand Down Expand Up @@ -37,6 +39,6 @@ def authenticate(self, request, token):


# The lambda _: True is to handle the case where a user doesn't pass any authentication.
allow_any = [django_auth, OAuth2AuthBearer("any"), lambda _: True]
allow_any: list[Callable] = [django_auth, OAuth2AuthBearer("any"), lambda _: True]
is_authenticated = [django_auth, OAuth2AuthBearer("is_authenticated")]
is_staff = [SessionAuthStaffUser(), OAuth2AuthBearer("is_staff")]
6 changes: 3 additions & 3 deletions isic/core/api/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class CollectionLicenseBreakdown(Schema):
summary="Retrieve a breakdown of the licenses of the specified collection.",
include_in_schema=False,
)
def collection_license_breakdown(request, id: int) -> CollectionLicenseBreakdown:
def collection_license_breakdown(request, id: int) -> dict[str, int]:
qs = get_visible_objects(request.user, "core.view_collection")
collection = get_object_or_404(qs, id=id)
images = get_visible_objects(request.user, "core.view_image", collection.images.distinct())
Expand Down Expand Up @@ -179,7 +179,7 @@ def collection_populate_from_search(request, id: int, payload: SearchQueryIn):
if collection.locked:
return 409, {"error": "Collection is locked"}

if collection.public and payload.to_queryset(request.user).private().exists():
if collection.public and payload.to_queryset(request.user).private().exists(): # type: ignore[attr-defined]
return 409, {"error": "Collection is public and cannot contain private images."}

# Pass data instead of validated_data because the celery task is going to revalidate.
Expand All @@ -195,7 +195,7 @@ def collection_populate_from_search(request, id: int, payload: SearchQueryIn):


class IsicIdList(Schema):
isic_ids: conlist(constr(pattern=ISIC_ID_REGEX), max_length=500)
isic_ids: conlist(constr(pattern=ISIC_ID_REGEX), max_length=500) # type: ignore[valid-type]

model_config = {"extra": "forbid"}

Expand Down
4 changes: 2 additions & 2 deletions isic/core/api/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class Meta:
metadata: dict

@staticmethod
def resolve_files(image: Image) -> dict:
def resolve_files(image: Image) -> ImageFilesOut:
if settings.ISIC_PLACEHOLDER_IMAGES:
full_url = f"https://picsum.photos/seed/{image.id}/1000"
thumbnail_url = f"https://picsum.photos/seed/{image.id}/256"
Expand All @@ -71,7 +71,7 @@ def resolve_metadata(image: Image) -> dict:

for key, value in image.metadata.items():
try:
metadata[FIELD_REGISTRY[key].type][key] = value
metadata[FIELD_REGISTRY[key].type][key] = value # type: ignore[index]
except KeyError:
# it's probably a computed field
for computed_field in image.accession.computed_fields:
Expand Down
6 changes: 3 additions & 3 deletions isic/core/api/user.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from datetime import datetime

from django.contrib.auth.models import User
from django.http.request import HttpRequest
from django.utils import timezone
from ninja import Field, ModelSchema, Router

from isic.auth import is_authenticated
from isic.types import AuthenticatedHttpRequest

router = Router()

Expand Down Expand Up @@ -37,12 +37,12 @@ def resolve_full_name(obj: User):
include_in_schema=True,
auth=is_authenticated,
)
def user_me(request: HttpRequest):
def user_me(request: AuthenticatedHttpRequest):
return request.user


@router.put("/accept-terms/", include_in_schema=False, auth=is_authenticated)
def accept_terms_of_use(request: HttpRequest):
def accept_terms_of_use(request: AuthenticatedHttpRequest):
if not request.user.profile.accepted_terms:
request.user.profile.accepted_terms = timezone.now()
request.user.profile.save(update_fields=["accepted_terms"])
Expand Down
4 changes: 3 additions & 1 deletion isic/core/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def to_es(self, key: SearchTermKey) -> dict:
elif key.field_lookup == "image_type" and self.value == "overview":
self.value = "clinical: overview"

term: dict

if self.value == "*":
term = {"exists": {"field": key.field_lookup}}
elif self.value.startswith("*"):
Expand Down Expand Up @@ -296,7 +298,7 @@ def make_parser( # noqa: C901
) -> ParserElement:
def make_term_keyword(name):
term = Optional("-") + Keyword(name)
if term_converter:
if term_converter is not None:
term.add_parse_action(term_converter)
return term

Expand Down
29 changes: 0 additions & 29 deletions isic/core/management/commands/merge-collections.py

This file was deleted.

2 changes: 1 addition & 1 deletion isic/core/management/commands/set_isic_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def add_staff_group():
User,
ZipUpload,
]:
content_type = ContentType.objects.get_for_model(model)
content_type = ContentType.objects.get_for_model(model) # type: ignore[arg-type]
for permission in ["view", "change"]:
group.permissions.add(
Permission.objects.get(
Expand Down
8 changes: 4 additions & 4 deletions isic/core/migrations/0002_initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class Migration(migrations.Migration):
migrations.AddConstraint(
model_name="imageshare",
constraint=models.CheckConstraint(
check=models.Q(("grantor", models.F("grantee")), _negated=True),
condition=models.Q(("grantor", models.F("grantee")), _negated=True),
name="imageshare_grantor_grantee_diff_check",
),
),
Expand All @@ -176,7 +176,7 @@ class Migration(migrations.Migration):
migrations.AddConstraint(
model_name="girderimage",
constraint=models.CheckConstraint(
check=models.Q(
condition=models.Q(
("status", "unknown"),
("status", "non_image"),
("accession__isnull", False),
Expand All @@ -188,7 +188,7 @@ class Migration(migrations.Migration):
migrations.AddConstraint(
model_name="girderimage",
constraint=models.CheckConstraint(
check=models.Q(
condition=models.Q(
("status", "non_image"),
models.Q(("stripped_blob_dm", ""), _negated=True),
_connector="OR",
Expand All @@ -199,7 +199,7 @@ class Migration(migrations.Migration):
migrations.AddConstraint(
model_name="collectionshare",
constraint=models.CheckConstraint(
check=models.Q(("grantor", models.F("grantee")), _negated=True),
condition=models.Q(("grantor", models.F("grantee")), _negated=True),
name="collectionshare_grantor_grantee_diff_check",
),
),
Expand Down
6 changes: 3 additions & 3 deletions isic/core/models/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class Meta(TimeStampedModel.Meta):
shares = models.ManyToManyField(
User,
through="CollectionShare",
through_fields=["collection", "grantee"],
through_fields=("collection", "grantee"),
related_name="collection_shares",
)

Expand Down Expand Up @@ -128,7 +128,7 @@ def shared_with(self):
]

def full_clean(self, exclude=None, validate_unique=True): # noqa: FBT002
if self.pk and self.public and self.images.private().exists():
if self.pk and self.public and self.images.private().exists(): # type: ignore[attr-defined]
raise ValidationError("Can't make collection public, it contains private images.")

return super().full_clean(exclude=exclude, validate_unique=validate_unique)
Expand All @@ -139,7 +139,7 @@ class Meta(TimeStampedModel.Meta):
constraints = [
CheckConstraint(
name="collectionshare_grantor_grantee_diff_check",
check=~Q(grantor=F("grantee")),
condition=~Q(grantor=F("grantee")),
),
UniqueConstraint(
name="collectionshare_grantor_collection_grantee_unique",
Expand Down
4 changes: 2 additions & 2 deletions isic/core/models/girder_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ class Meta:
constraints = [
models.CheckConstraint(
name="non_unknown_have_accession",
check=Q(status=GirderImageStatus.UNKNOWN)
condition=Q(status=GirderImageStatus.UNKNOWN)
| Q(status=GirderImageStatus.NON_IMAGE)
| Q(accession__isnull=False),
),
models.CheckConstraint(
name="non_non_image_have_stripped_blob_dm",
check=Q(status=GirderImageStatus.NON_IMAGE) | ~Q(stripped_blob_dm=""),
condition=Q(status=GirderImageStatus.NON_IMAGE) | ~Q(stripped_blob_dm=""),
),
]

Expand Down
37 changes: 26 additions & 11 deletions isic/core/models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,23 @@
from .isic_id import IsicId


class ImageQuerySet(models.QuerySet):
class ImageQuerySet(models.QuerySet["Image"]):
def public(self):
return self.filter(public=True)

def private(self):
return self.filter(public=False)

def from_search_query(self, query: str):
if query == "":
return self
return self.filter(parse_query(django_parser, query) or Q())


class ImageManager(models.Manager["Image"]):
def get_queryset(self) -> ImageQuerySet:
return ImageQuerySet(self.model, using=self._db)

def with_elasticsearch_properties(self):
return self.select_related("accession__cohort").annotate(
coll_pks=ArrayAgg("collections", distinct=True, default=[]),
Expand Down Expand Up @@ -67,9 +78,9 @@ class Meta(CreationSortedTimeStampedModel.Meta):
# index is used because public is filtered in every permissions check
public = models.BooleanField(default=False, db_index=True)

shares = models.ManyToManyField(User, through="ImageShare", through_fields=["image", "grantee"])
shares = models.ManyToManyField(User, through="ImageShare", through_fields=("image", "grantee"))

objects = ImageQuerySet.as_manager()
objects = ImageManager()

def __str__(self):
return self.isic_id
Expand Down Expand Up @@ -99,18 +110,22 @@ def metadata(self) -> dict:
"""
image_metadata = deepcopy(self.accession.metadata)

for field in Accession.computed_fields:
if field.input_field_name in image_metadata:
computed_output_fields = field.transformer(image_metadata[field.input_field_name])
for computed_field in Accession.computed_fields:
if computed_field.input_field_name in image_metadata:
computed_output_fields = computed_field.transformer(
image_metadata[computed_field.input_field_name]
)

if computed_output_fields:
image_metadata.update(computed_output_fields)

del image_metadata[field.input_field_name]
del image_metadata[computed_field.input_field_name]

for field in Accession.remapped_internal_fields:
if getattr(self.accession, field.csv_field_name) is not None:
image_metadata[field.csv_field_name] = getattr(self.accession, field.csv_field_name)
for remapped_field in Accession.remapped_internal_fields:
if getattr(self.accession, remapped_field.csv_field_name) is not None:
image_metadata[remapped_field.csv_field_name] = getattr(
self.accession, remapped_field.csv_field_name
)

if "legacy_dx" in image_metadata:
image_metadata["diagnosis"] = image_metadata["legacy_dx"]
Expand Down Expand Up @@ -176,7 +191,7 @@ class Meta(TimeStampedModel.Meta):
constraints = [
CheckConstraint(
name="imageshare_grantor_grantee_diff_check",
check=~Q(grantor=F("grantee")),
condition=~Q(grantor=F("grantee")),
),
models.UniqueConstraint(
name="imageshare_grantor_image_grantee_unique",
Expand Down
6 changes: 3 additions & 3 deletions isic/core/permissions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import wraps
from urllib.parse import urlparse

import django.apps
Expand All @@ -12,7 +13,6 @@
from django.db.models.query import QuerySet
from django.http.request import HttpRequest
from django.shortcuts import get_object_or_404, resolve_url
from django.utils.functional import wraps
from ninja.security.session import SessionAuth


Expand All @@ -34,7 +34,7 @@ def view_staff(user_obj, _=None):
return user_obj.is_staff


User.perms_class = UserPermissions
User.perms_class = UserPermissions # type: ignore[attr-defined]


ISIC_PERMS_MAP = {}
Expand Down Expand Up @@ -99,7 +99,7 @@ def _wrapped_view(request, *args, **kwargs):
"If model should be looked up from "
"string it needs format: 'app_label.ModelClass'"
)
model = apps.get_model(*splitted)
model = apps.get_model(*splitted) # type: ignore[arg-type]
elif issubclass(model.__class__, Model | ModelBase | QuerySet):
pass
else:
Expand Down
17 changes: 10 additions & 7 deletions isic/core/search.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from copy import deepcopy
from functools import lru_cache, partial
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict

from cachalot.api import cachalot_disabled
from django.conf import settings
from django.contrib.auth.models import User
from django.contrib.auth.models import AnonymousUser, User
from django.db.models.query import QuerySet
from isic_metadata import FIELD_REGISTRY
from isic_metadata.fields import FitzpatrickSkinType, ImageTypeEnum
Expand Down Expand Up @@ -71,7 +71,7 @@ def get_elasticsearch_client() -> "OpenSearch":

# TODO: investigate using retryable requests with transport_class
RetryOnTimeoutTransport = partial(Transport, retry_on_timeout=True) # noqa: N806
return OpenSearch(settings.ISIC_ELASTICSEARCH_URI, transport_class=RetryOnTimeoutTransport)
return OpenSearch(settings.ISIC_ELASTICSEARCH_URI, transport_class=RetryOnTimeoutTransport) # type: ignore[arg-type]


def maybe_create_index() -> None:
Expand Down Expand Up @@ -157,7 +157,7 @@ def _prettify_facets(facets: dict[str, Any]) -> dict[str, Any]:
# sort the values of image_type buckets by the element in the key field
facets["image_type"]["buckets"] = sorted(
facets["image_type"]["buckets"],
key=lambda x: ImageTypeEnum(x["key"])._sort_order_,
key=lambda x: ImageTypeEnum(x["key"])._sort_order_, # type: ignore[attr-defined]
)

return facets
Expand All @@ -183,7 +183,10 @@ def facets(query: dict | None = None, collections: list[int] | None = None) -> d
body=counts_body,
)["aggregations"]

facets_body = {
FacetsBody = TypedDict( # noqa: UP013
"FacetsBody", {"size": int, "aggs": dict, "query": NotRequired[dict | None]}
)
facets_body: FacetsBody = {
"size": 0,
"aggs": deepcopy(DEFAULT_SEARCH_AGGREGATES),
}
Expand Down Expand Up @@ -215,7 +218,7 @@ def facets(query: dict | None = None, collections: list[int] | None = None) -> d


def build_elasticsearch_query(
query: dict, user: User, collection_pks: list[int] | None = None
query: dict, user: User | AnonymousUser, collection_pks: list[int] | None = None
) -> dict:
"""
Build an elasticsearch query from an elasticsearch query body, a user, and collection ids.
Expand All @@ -237,7 +240,7 @@ def build_elasticsearch_query(
else:
visible_collection_pks = None

query_dict = {"bool": {"filter": [query]}} if query else {"bool": {}}
query_dict: dict = {"bool": {"filter": [query]}} if query else {"bool": {}}

if visible_collection_pks is not None:
query_dict["bool"].setdefault("filter", [])
Expand Down
Loading

0 comments on commit bd44e89

Please sign in to comment.