Skip to content

Commit

Permalink
Some some type errors found with pyright (#448)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtomlinson authored Jul 11, 2024
1 parent 36b9d02 commit 333d833
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 31 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ repos:
- --use-current-year
- --no-extra-eol
- --detect-license-in-X-top-lines=5
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.10.1'
hooks:
- id: mypy
exclude: "examples|tests|venv|ci|docs|conftest.py"
additional_dependencies: [types-pyyaml>=6.0]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.10.1'
hooks:
- id: mypy
exclude: "examples|tests|venv|ci|docs|conftest.py"
additional_dependencies: [types-pyyaml>=6.0]
6 changes: 3 additions & 3 deletions kr8s/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Both APIs are functionally identical with the same objects, method signatures and return values.
"""
from functools import partial, update_wrapper
from typing import Dict, List, Optional, Union
from typing import Dict, Optional, Type, Union

from . import asyncio, objects, portforward
from ._api import ALL
Expand Down Expand Up @@ -55,11 +55,11 @@ class Api(_AsyncApi):

def get(
kind: str,
*names: List[str],
*names: str,
namespace: Optional[str] = None,
label_selector: Optional[Union[str, Dict]] = None,
field_selector: Optional[Union[str, Dict]] = None,
as_object: Optional[object] = None,
as_object: Optional[Type] = None,
allow_unknown_type: bool = True,
api=None,
**kwargs,
Expand Down
18 changes: 12 additions & 6 deletions kr8s/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ async def call_api(
) from e
elif e.response.status_code >= 500:
raise ServerError(
str(e), status=e.response.status_code, response=e.response
str(e),
status=str(e.response.status_code),
response=e.response,
) from e
raise
except ssl.SSLCertVerificationError:
Expand Down Expand Up @@ -369,6 +371,7 @@ async def async_get_kind(
if isinstance(kind, type):
obj_cls = kind
else:
namespaced: Optional[bool] = None
try:
kind, namespaced = await self.async_lookup_kind(kind)
except ServerError as e:
Expand All @@ -378,9 +381,12 @@ async def async_get_kind(
obj_cls = get_class(kind, _asyncio=self._asyncio)
except KeyError as e:
if allow_unknown_type:
obj_cls = new_class(
kind, namespaced=namespaced, asyncio=self._asyncio
)
if namespaced is not None:
obj_cls = new_class(
kind, namespaced=namespaced, asyncio=self._asyncio
)
else:
obj_cls = new_class(kind, asyncio=self._asyncio)
else:
raise e
params = params or None
Expand Down Expand Up @@ -499,7 +505,7 @@ async def async_watch(
field_selector: Optional[Union[str, Dict]] = None,
since: Optional[str] = None,
allow_unknown_type: bool = True,
) -> AsyncGenerator[Tuple[str, object], None]:
) -> AsyncGenerator[Tuple[str, APIObject], None]:
"""Watch a Kubernetes resource."""
async with self.async_get_kind(
kind,
Expand All @@ -515,7 +521,7 @@ async def async_watch(
event = json.loads(line)
yield event["type"], obj_cls(event["object"], api=self)

async def api_resources(self) -> dict:
async def api_resources(self) -> List[Dict]:
"""Get the Kubernetes API resources."""
return await self.async_api_resources()

Expand Down
11 changes: 7 additions & 4 deletions kr8s/_async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from typing_extensions import ParamSpec

import anyio
import anyio.from_thread
import anyio.to_thread

T = TypeVar("T")
C = TypeVar("C")
Expand All @@ -65,11 +67,12 @@ class Portal:
"""

_instance = None
_portal = None
_instance: Portal
_portal: anyio.from_thread.BlockingPortal
thread: Thread

def __new__(cls):
if cls._instance is None:
if not hasattr(cls, "_instance"):
cls._instance = super(Portal, cls).__new__(cls)
cls._instance.thread = Thread(
target=anyio.run, args=[cls._instance._run], name="Kr8sSyncRunnerThread"
Expand All @@ -86,7 +89,7 @@ async def _run(self):
def call(self, func: Callable[P, Awaitable[T]], *args, **kwargs) -> T:
"""Call a coroutine in the runner loop and return the result."""
# On first call the thread has to start the loop, so we need to wait for it
while not self._portal:
while not hasattr(self, "_portal"):
pass
return self._portal.call(func, *args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion kr8s/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def ssl_context(self):
# If no cert information is provided, fall back to default verification
return True
sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
if self.client_key_file:
if self.client_key_file and self.client_cert_file:
sslcontext.load_cert_chain(
certfile=self.client_cert_file,
keyfile=self.client_key_file,
Expand Down
5 changes: 3 additions & 2 deletions kr8s/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
async def run(
self,
) -> AsyncGenerator[Exec, CompletedExec]:
assert self._resource.api
async with self._resource.api.open_websocket(
version=self._resource.version,
url=f"{self._resource.endpoint}/{self._resource.name}/exec",
Expand All @@ -70,7 +71,7 @@ async def run(
if ws.subprotocol != "v5.channel.k8s.io":
raise ExecError(
"Stdin is not supported with protocol "
f"{ws.protocol}, only with v5.channel.k8s.io"
f"{ws.subprotocol}, only with v5.channel.k8s.io"
)
if isinstance(self._stdin, str):
await ws.send_bytes(STDIN_CHANNEL.to_bytes() + self._stdin.encode()) # type: ignore
Expand Down Expand Up @@ -132,7 +133,7 @@ class CompletedExec:
Similar to subprocess.CompletedProcess.
"""

args: Union[str | List[str]]
args: Union[str, List[str]]
stdout: bytes
stderr: bytes
returncode: int
Expand Down
34 changes: 30 additions & 4 deletions kr8s/_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Tuple,
Type,
Union,
cast,
)

import anyio
Expand Down Expand Up @@ -76,7 +77,7 @@ def __init__(
"resource must be a dict, string, have an obj attribute or a to_dict method"
)
if namespace is not None:
self.raw["metadata"]["namespace"] = namespace
self.raw["metadata"]["namespace"] = namespace # type: ignore
self._api = api
if self._api is None and not self._asyncio:
self._api = kr8s.api()
Expand Down Expand Up @@ -152,6 +153,7 @@ def name(self, value: str) -> None:
def namespace(self) -> Optional[str]:
"""Namespace of the Kubernetes resource."""
if self.namespaced:
assert self.api
return self.raw.get("metadata", {}).get("namespace", self.api.namespace)
return None

Expand Down Expand Up @@ -293,6 +295,7 @@ async def exists(self, ensure=False) -> bool:
async def async_exists(self, ensure=False) -> bool:
"""Check if this object exists in Kubernetes."""
try:
assert self.api
async with self.api.call_api(
"GET",
version=self.version,
Expand All @@ -311,6 +314,7 @@ async def async_exists(self, ensure=False) -> bool:

async def create(self) -> None:
"""Create this object in Kubernetes."""
assert self.api
async with self.api.call_api(
"POST",
version=self.version,
Expand All @@ -326,6 +330,7 @@ async def delete(self, propagation_policy: Optional[str] = None) -> None:
if propagation_policy:
data["propagationPolicy"] = propagation_policy
try:
assert self.api
async with self.api.call_api(
"DELETE",
version=self.version,
Expand All @@ -346,6 +351,7 @@ async def refresh(self) -> None:
async def async_refresh(self) -> None:
"""Refresh this object from Kubernetes."""
try:
assert self.api
async with self.api.call_api(
"GET",
version=self.version,
Expand All @@ -372,6 +378,7 @@ async def async_patch(self, patch: Dict, *, subresource=None, type=None) -> None
if subresource:
url = f"{url}/{subresource}"
try:
assert self.api
async with self.api.call_api(
"PATCH",
version=self.version,
Expand Down Expand Up @@ -401,6 +408,7 @@ async def scale(self, replicas: Optional[int] = None) -> None:
async def async_watch(self):
"""Watch this object in Kubernetes."""
since = self.metadata.get("resourceVersion")
assert self.api
async for event, obj in self.api.async_watch(
self.endpoint,
namespace=self.namespace,
Expand Down Expand Up @@ -932,6 +940,7 @@ async def logs(
params["limitBytes"] = int(limit_bytes)

with contextlib.suppress(httpx.ReadTimeout):
assert self.api
async with self.api.call_api(
"GET",
version=self.version,
Expand Down Expand Up @@ -990,7 +999,7 @@ async def async_exec(
command: List[str],
*,
container: Optional[str] = None,
stdin: Optional[Union[str | BinaryIO]] = None,
stdin: Optional[Union[str, BinaryIO]] = None,
stdout: Optional[BinaryIO] = None,
stderr: Optional[BinaryIO] = None,
check: bool = True,
Expand Down Expand Up @@ -1018,7 +1027,7 @@ async def exec(
command: List[str],
*,
container: Optional[str] = None,
stdin: Optional[Union[str | BinaryIO]] = None,
stdin: Optional[Union[str, BinaryIO]] = None,
stdout: Optional[BinaryIO] = None,
stderr: Optional[BinaryIO] = None,
check: bool = True,
Expand Down Expand Up @@ -1266,6 +1275,7 @@ async def async_proxy_http_request(
) -> httpx.Response:
if port is None:
port = self.raw["spec"]["ports"][0]["port"]
assert self.api
async with self.api.call_api(
method,
version=self.version,
Expand Down Expand Up @@ -1301,11 +1311,20 @@ async def ready_pods(self) -> List[Pod]:

async def async_ready_pods(self) -> List[Pod]:
"""Return a list of ready Pods for this Service."""
assert self.api
pods = await self.api.async_get(
"pods",
label_selector=dict_to_selector(self.spec["selector"]),
namespace=self.namespace,
)
if isinstance(pods, Pod):
pods = [pods]
elif isinstance(pods, List) and all(isinstance(pod, Pod) for pod in pods):
# The all(isinstance(...) for ...) check doesn't seem to narrow the type
# correctly in pyright so we need to explicitly use cast
pods = cast(List[Pod], pods)
else:
raise TypeError(f"Unexpected type {type(pods)} returned from API")
return [pod for pod in pods if await pod.async_ready()]

async def ready(self) -> bool:
Expand Down Expand Up @@ -1397,12 +1416,19 @@ class Deployment(APIObject):

async def pods(self) -> List[Pod]:
"""Return a list of Pods for this Deployment."""
assert self.api
pods = await self.api.async_get(
"pods",
label_selector=dict_to_selector(self.spec["selector"]["matchLabels"]),
namespace=self.namespace,
)
return pods
if isinstance(pods, Pod):
return [pods]
if isinstance(pods, List) and all(isinstance(pod, Pod) for pod in pods):
# The all(isinstance(...) for ...) check doesn't seem to narrow the type
# correctly in pyright so we need to explicitly use cast
return cast(List[Pod], pods)
raise TypeError(f"Unexpected type {type(pods)} returned from API")

async def ready(self):
"""Check if the deployment is ready."""
Expand Down
5 changes: 4 additions & 1 deletion kr8s/_portforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sniffio

from ._exceptions import ConnectionClosedError
from ._types import APIObjectWithPods

if TYPE_CHECKING:
from ._objects import APIObject
Expand Down Expand Up @@ -112,6 +113,7 @@ async def __aenter__(self, *args, **kwargs):
return await self._run_task.__aenter__(*args, **kwargs)

async def __aexit__(self, *args, **kwargs):
assert self._run_task
return await self._run_task.__aexit__(*args, **kwargs)

async def start(self) -> int:
Expand Down Expand Up @@ -181,7 +183,7 @@ async def _select_pod(self) -> APIObject:
if isinstance(self._resource, Pod):
return self._resource

if hasattr(self._resource, "async_ready_pods"):
if isinstance(self._resource, APIObjectWithPods):
try:
return random.choice(await self._resource.async_ready_pods())
except IndexError:
Expand All @@ -196,6 +198,7 @@ async def _connect_websocket(self):
if not self.pod:
self.pod = await self._select_pod()
try:
assert self.pod.api
async with self.pod.api.open_websocket(
version=self.pod.version,
url=f"{self.pod.endpoint}/{self.pod.name}/portforward",
Expand Down
12 changes: 11 additions & 1 deletion kr8s/_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, Kr8s Developers (See LICENSE for list)
# SPDX-License-Identifier: BSD 3-Clause License
from os import PathLike
from typing import Union
from typing import TYPE_CHECKING, List, Protocol, Union, runtime_checkable

PathType = Union[str, PathLike[str]]

if TYPE_CHECKING:
from ._objects import Pod


@runtime_checkable
class APIObjectWithPods(Protocol):
"""An APIObject subclass that contains other Pod objects."""

async def async_ready_pods(self) -> List["Pod"]: ...
7 changes: 4 additions & 3 deletions kr8s/asyncio/_helpers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2024, Kr8s Developers (See LICENSE for list)
# SPDX-License-Identifier: BSD 3-Clause License
from typing import Dict, List, Optional, Union
from typing import Dict, Optional, Type, Union

from kr8s._api import Api
from kr8s._objects import APIObject

from ._api import api as _api


async def get(
kind: str,
*names: List[str],
*names: str,
namespace: Optional[str] = None,
label_selector: Optional[Union[str, Dict]] = None,
field_selector: Optional[Union[str, Dict]] = None,
as_object: Optional[object] = None,
as_object: Optional[Type[APIObject]] = None,
allow_unknown_type: bool = True,
api=None,
_asyncio=True,
Expand Down
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,13 @@ convention = "google"

[tool.mypy]
exclude = ["examples", "tests", "venv", "ci", "docs", "conftest.py"]

[tool.pyright]
exclude = ["examples", "**/tests", "venv", "ci", "docs", "conftest.py"]

# We often override corotuines with sync methods so this is not useful
reportIncompatibleMethodOverride = "none"

# When run with pre-commit, we don't want to report missing imports
reportMissingImports = "none"
reportMissingModuleSource = "none"

0 comments on commit 333d833

Please sign in to comment.