diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b42533e0..60ea7baf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,9 +25,9 @@ repos: - --use-current-year - --no-extra-eol - --detect-license-in-X-top-lines=5 - - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.10.1' - hooks: - - id: mypy - exclude: "examples|tests|venv|ci|docs|conftest.py" - additional_dependencies: [types-pyyaml>=6.0] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v1.10.1' + hooks: + - id: mypy + exclude: "examples|tests|venv|ci|docs|conftest.py" + additional_dependencies: [types-pyyaml>=6.0] diff --git a/kr8s/__init__.py b/kr8s/__init__.py index 64de4ee6..2504033d 100644 --- a/kr8s/__init__.py +++ b/kr8s/__init__.py @@ -7,7 +7,7 @@ Both APIs are functionally identical with the same objects, method signatures and return values. """ from functools import partial, update_wrapper -from typing import Dict, List, Optional, Union +from typing import Dict, Optional, Type, Union from . import asyncio, objects, portforward from ._api import ALL @@ -55,11 +55,11 @@ class Api(_AsyncApi): def get( kind: str, - *names: List[str], + *names: str, namespace: Optional[str] = None, label_selector: Optional[Union[str, Dict]] = None, field_selector: Optional[Union[str, Dict]] = None, - as_object: Optional[object] = None, + as_object: Optional[Type] = None, allow_unknown_type: bool = True, api=None, **kwargs, diff --git a/kr8s/_api.py b/kr8s/_api.py index 42fdb567..f9b12131 100644 --- a/kr8s/_api.py +++ b/kr8s/_api.py @@ -192,7 +192,9 @@ async def call_api( ) from e elif e.response.status_code >= 500: raise ServerError( - str(e), status=e.response.status_code, response=e.response + str(e), + status=str(e.response.status_code), + response=e.response, ) from e raise except ssl.SSLCertVerificationError: @@ -369,6 +371,7 @@ async def async_get_kind( if isinstance(kind, type): obj_cls = kind else: + namespaced: Optional[bool] = None try: kind, namespaced = await self.async_lookup_kind(kind) except ServerError as e: @@ -378,9 +381,12 @@ async def async_get_kind( obj_cls = get_class(kind, _asyncio=self._asyncio) except KeyError as e: if allow_unknown_type: - obj_cls = new_class( - kind, namespaced=namespaced, asyncio=self._asyncio - ) + if namespaced is not None: + obj_cls = new_class( + kind, namespaced=namespaced, asyncio=self._asyncio + ) + else: + obj_cls = new_class(kind, asyncio=self._asyncio) else: raise e params = params or None @@ -499,7 +505,7 @@ async def async_watch( field_selector: Optional[Union[str, Dict]] = None, since: Optional[str] = None, allow_unknown_type: bool = True, - ) -> AsyncGenerator[Tuple[str, object], None]: + ) -> AsyncGenerator[Tuple[str, APIObject], None]: """Watch a Kubernetes resource.""" async with self.async_get_kind( kind, @@ -515,7 +521,7 @@ async def async_watch( event = json.loads(line) yield event["type"], obj_cls(event["object"], api=self) - async def api_resources(self) -> dict: + async def api_resources(self) -> List[Dict]: """Get the Kubernetes API resources.""" return await self.async_api_resources() diff --git a/kr8s/_async_utils.py b/kr8s/_async_utils.py index 64e1a0dd..3967bb8e 100644 --- a/kr8s/_async_utils.py +++ b/kr8s/_async_utils.py @@ -44,6 +44,8 @@ from typing_extensions import ParamSpec import anyio +import anyio.from_thread +import anyio.to_thread T = TypeVar("T") C = TypeVar("C") @@ -65,11 +67,12 @@ class Portal: """ - _instance = None - _portal = None + _instance: Portal + _portal: anyio.from_thread.BlockingPortal + thread: Thread def __new__(cls): - if cls._instance is None: + if not hasattr(cls, "_instance"): cls._instance = super(Portal, cls).__new__(cls) cls._instance.thread = Thread( target=anyio.run, args=[cls._instance._run], name="Kr8sSyncRunnerThread" @@ -86,7 +89,7 @@ async def _run(self): def call(self, func: Callable[P, Awaitable[T]], *args, **kwargs) -> T: """Call a coroutine in the runner loop and return the result.""" # On first call the thread has to start the loop, so we need to wait for it - while not self._portal: + while not hasattr(self, "_portal"): pass return self._portal.call(func, *args, **kwargs) diff --git a/kr8s/_auth.py b/kr8s/_auth.py index 6747f974..f073d8ea 100644 --- a/kr8s/_auth.py +++ b/kr8s/_auth.py @@ -90,7 +90,7 @@ async def ssl_context(self): # If no cert information is provided, fall back to default verification return True sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - if self.client_key_file: + if self.client_key_file and self.client_cert_file: sslcontext.load_cert_chain( certfile=self.client_cert_file, keyfile=self.client_key_file, diff --git a/kr8s/_exec.py b/kr8s/_exec.py index 3ed7ff80..c0e97ff4 100644 --- a/kr8s/_exec.py +++ b/kr8s/_exec.py @@ -53,6 +53,7 @@ def __init__( async def run( self, ) -> AsyncGenerator[Exec, CompletedExec]: + assert self._resource.api async with self._resource.api.open_websocket( version=self._resource.version, url=f"{self._resource.endpoint}/{self._resource.name}/exec", @@ -70,7 +71,7 @@ async def run( if ws.subprotocol != "v5.channel.k8s.io": raise ExecError( "Stdin is not supported with protocol " - f"{ws.protocol}, only with v5.channel.k8s.io" + f"{ws.subprotocol}, only with v5.channel.k8s.io" ) if isinstance(self._stdin, str): await ws.send_bytes(STDIN_CHANNEL.to_bytes() + self._stdin.encode()) # type: ignore @@ -132,7 +133,7 @@ class CompletedExec: Similar to subprocess.CompletedProcess. """ - args: Union[str | List[str]] + args: Union[str, List[str]] stdout: bytes stderr: bytes returncode: int diff --git a/kr8s/_objects.py b/kr8s/_objects.py index bb018451..4014b532 100644 --- a/kr8s/_objects.py +++ b/kr8s/_objects.py @@ -18,6 +18,7 @@ Tuple, Type, Union, + cast, ) import anyio @@ -76,7 +77,7 @@ def __init__( "resource must be a dict, string, have an obj attribute or a to_dict method" ) if namespace is not None: - self.raw["metadata"]["namespace"] = namespace + self.raw["metadata"]["namespace"] = namespace # type: ignore self._api = api if self._api is None and not self._asyncio: self._api = kr8s.api() @@ -152,6 +153,7 @@ def name(self, value: str) -> None: def namespace(self) -> Optional[str]: """Namespace of the Kubernetes resource.""" if self.namespaced: + assert self.api return self.raw.get("metadata", {}).get("namespace", self.api.namespace) return None @@ -293,6 +295,7 @@ async def exists(self, ensure=False) -> bool: async def async_exists(self, ensure=False) -> bool: """Check if this object exists in Kubernetes.""" try: + assert self.api async with self.api.call_api( "GET", version=self.version, @@ -311,6 +314,7 @@ async def async_exists(self, ensure=False) -> bool: async def create(self) -> None: """Create this object in Kubernetes.""" + assert self.api async with self.api.call_api( "POST", version=self.version, @@ -326,6 +330,7 @@ async def delete(self, propagation_policy: Optional[str] = None) -> None: if propagation_policy: data["propagationPolicy"] = propagation_policy try: + assert self.api async with self.api.call_api( "DELETE", version=self.version, @@ -346,6 +351,7 @@ async def refresh(self) -> None: async def async_refresh(self) -> None: """Refresh this object from Kubernetes.""" try: + assert self.api async with self.api.call_api( "GET", version=self.version, @@ -372,6 +378,7 @@ async def async_patch(self, patch: Dict, *, subresource=None, type=None) -> None if subresource: url = f"{url}/{subresource}" try: + assert self.api async with self.api.call_api( "PATCH", version=self.version, @@ -401,6 +408,7 @@ async def scale(self, replicas: Optional[int] = None) -> None: async def async_watch(self): """Watch this object in Kubernetes.""" since = self.metadata.get("resourceVersion") + assert self.api async for event, obj in self.api.async_watch( self.endpoint, namespace=self.namespace, @@ -932,6 +940,7 @@ async def logs( params["limitBytes"] = int(limit_bytes) with contextlib.suppress(httpx.ReadTimeout): + assert self.api async with self.api.call_api( "GET", version=self.version, @@ -990,7 +999,7 @@ async def async_exec( command: List[str], *, container: Optional[str] = None, - stdin: Optional[Union[str | BinaryIO]] = None, + stdin: Optional[Union[str, BinaryIO]] = None, stdout: Optional[BinaryIO] = None, stderr: Optional[BinaryIO] = None, check: bool = True, @@ -1018,7 +1027,7 @@ async def exec( command: List[str], *, container: Optional[str] = None, - stdin: Optional[Union[str | BinaryIO]] = None, + stdin: Optional[Union[str, BinaryIO]] = None, stdout: Optional[BinaryIO] = None, stderr: Optional[BinaryIO] = None, check: bool = True, @@ -1266,6 +1275,7 @@ async def async_proxy_http_request( ) -> httpx.Response: if port is None: port = self.raw["spec"]["ports"][0]["port"] + assert self.api async with self.api.call_api( method, version=self.version, @@ -1301,11 +1311,20 @@ async def ready_pods(self) -> List[Pod]: async def async_ready_pods(self) -> List[Pod]: """Return a list of ready Pods for this Service.""" + assert self.api pods = await self.api.async_get( "pods", label_selector=dict_to_selector(self.spec["selector"]), namespace=self.namespace, ) + if isinstance(pods, Pod): + pods = [pods] + elif isinstance(pods, List) and all(isinstance(pod, Pod) for pod in pods): + # The all(isinstance(...) for ...) check doesn't seem to narrow the type + # correctly in pyright so we need to explicitly use cast + pods = cast(List[Pod], pods) + else: + raise TypeError(f"Unexpected type {type(pods)} returned from API") return [pod for pod in pods if await pod.async_ready()] async def ready(self) -> bool: @@ -1397,12 +1416,19 @@ class Deployment(APIObject): async def pods(self) -> List[Pod]: """Return a list of Pods for this Deployment.""" + assert self.api pods = await self.api.async_get( "pods", label_selector=dict_to_selector(self.spec["selector"]["matchLabels"]), namespace=self.namespace, ) - return pods + if isinstance(pods, Pod): + return [pods] + if isinstance(pods, List) and all(isinstance(pod, Pod) for pod in pods): + # The all(isinstance(...) for ...) check doesn't seem to narrow the type + # correctly in pyright so we need to explicitly use cast + return cast(List[Pod], pods) + raise TypeError(f"Unexpected type {type(pods)} returned from API") async def ready(self): """Check if the deployment is ready.""" diff --git a/kr8s/_portforward.py b/kr8s/_portforward.py index ac509b47..368f53e1 100644 --- a/kr8s/_portforward.py +++ b/kr8s/_portforward.py @@ -15,6 +15,7 @@ import sniffio from ._exceptions import ConnectionClosedError +from ._types import APIObjectWithPods if TYPE_CHECKING: from ._objects import APIObject @@ -112,6 +113,7 @@ async def __aenter__(self, *args, **kwargs): return await self._run_task.__aenter__(*args, **kwargs) async def __aexit__(self, *args, **kwargs): + assert self._run_task return await self._run_task.__aexit__(*args, **kwargs) async def start(self) -> int: @@ -181,7 +183,7 @@ async def _select_pod(self) -> APIObject: if isinstance(self._resource, Pod): return self._resource - if hasattr(self._resource, "async_ready_pods"): + if isinstance(self._resource, APIObjectWithPods): try: return random.choice(await self._resource.async_ready_pods()) except IndexError: @@ -196,6 +198,7 @@ async def _connect_websocket(self): if not self.pod: self.pod = await self._select_pod() try: + assert self.pod.api async with self.pod.api.open_websocket( version=self.pod.version, url=f"{self.pod.endpoint}/{self.pod.name}/portforward", diff --git a/kr8s/_types.py b/kr8s/_types.py index 75428984..019d35f7 100644 --- a/kr8s/_types.py +++ b/kr8s/_types.py @@ -1,6 +1,16 @@ # SPDX-FileCopyrightText: Copyright (c) 2024, Kr8s Developers (See LICENSE for list) # SPDX-License-Identifier: BSD 3-Clause License from os import PathLike -from typing import Union +from typing import TYPE_CHECKING, List, Protocol, Union, runtime_checkable PathType = Union[str, PathLike[str]] + +if TYPE_CHECKING: + from ._objects import Pod + + +@runtime_checkable +class APIObjectWithPods(Protocol): + """An APIObject subclass that contains other Pod objects.""" + + async def async_ready_pods(self) -> List["Pod"]: ... diff --git a/kr8s/asyncio/_helpers.py b/kr8s/asyncio/_helpers.py index 28f416ad..46566541 100644 --- a/kr8s/asyncio/_helpers.py +++ b/kr8s/asyncio/_helpers.py @@ -1,19 +1,20 @@ # SPDX-FileCopyrightText: Copyright (c) 2023-2024, Kr8s Developers (See LICENSE for list) # SPDX-License-Identifier: BSD 3-Clause License -from typing import Dict, List, Optional, Union +from typing import Dict, Optional, Type, Union from kr8s._api import Api +from kr8s._objects import APIObject from ._api import api as _api async def get( kind: str, - *names: List[str], + *names: str, namespace: Optional[str] = None, label_selector: Optional[Union[str, Dict]] = None, field_selector: Optional[Union[str, Dict]] = None, - as_object: Optional[object] = None, + as_object: Optional[Type[APIObject]] = None, allow_unknown_type: bool = True, api=None, _asyncio=True, diff --git a/pyproject.toml b/pyproject.toml index e55bf8b3..f037d4cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,3 +144,13 @@ convention = "google" [tool.mypy] exclude = ["examples", "tests", "venv", "ci", "docs", "conftest.py"] + +[tool.pyright] +exclude = ["examples", "**/tests", "venv", "ci", "docs", "conftest.py"] + +# We often override corotuines with sync methods so this is not useful +reportIncompatibleMethodOverride = "none" + +# When run with pre-commit, we don't want to report missing imports +reportMissingImports = "none" +reportMissingModuleSource = "none"