Skip to content

Commit

Permalink
Add type annotations for kr8s._async_utils (#417)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtomlinson authored Jun 28, 2024
1 parent 5f14c0b commit 0182ca6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
4 changes: 3 additions & 1 deletion kr8s/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,16 @@ def api(
>>> api = kr8s.api() # Uses the default kubeconfig
>>> print(api.version()) # Get the Kubernetes version
"""
return _run_sync(_api)(
ret = _run_sync(_api)(
url=url,
kubeconfig=kubeconfig,
serviceaccount=serviceaccount,
namespace=namespace,
context=context,
_asyncio=False,
)
assert isinstance(ret, (Api, _AsyncApi))
return ret


def whoami():
Expand Down
34 changes: 23 additions & 11 deletions kr8s/_async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Generator,
Tuple,
TypeVar,
Union,
)

if sys.version_info >= (3, 10):
Expand Down Expand Up @@ -90,7 +91,9 @@ def call(self, func: Callable[P, Awaitable[T]], *args, **kwargs) -> T:
return self._portal.call(func, *args, **kwargs)


def run_sync(coro: Callable[P, Awaitable[T]]) -> Callable[P, T]:
def run_sync(
coro: Callable[P, Union[AsyncGenerator, Awaitable[T]]]
) -> Callable[P, Union[Generator, T]]:
"""Wraps a coroutine in a function that blocks until it has executed.
Args:
Expand All @@ -100,17 +103,26 @@ def run_sync(coro: Callable[P, Awaitable[T]]) -> Callable[P, T]:
Callable: A sync function that executes the coroutine via the :class`Portal`.
"""

@wraps(coro)
def run_sync_inner(*args: P.args, **kwargs: P.kwargs) -> T:
wrapped = partial(coro, *args, **kwargs)
if inspect.isasyncgenfunction(coro):
return iter_over_async(wrapped)
if inspect.iscoroutinefunction(coro):
if inspect.isasyncgenfunction(coro):

@wraps(coro)
def run_gen_inner(*args: P.args, **kwargs: P.kwargs) -> Generator:
wrapped = partial(coro, *args, **kwargs)
return iter_over_async(wrapped())

return run_gen_inner

if inspect.iscoroutinefunction(coro):

@wraps(coro)
def run_sync_inner(*args: P.args, **kwargs: P.kwargs) -> T:
wrapped = partial(coro, *args, **kwargs)
portal = Portal()
return portal.call(wrapped)
raise TypeError(f"Expected coroutine function, got {coro.__class__.__name__}")

return run_sync_inner
return run_sync_inner

raise TypeError(f"Expected coroutine function, got {coro.__class__.__name__}")


def iter_over_async(agen: AsyncGenerator) -> Generator:
Expand All @@ -122,7 +134,7 @@ def iter_over_async(agen: AsyncGenerator) -> Generator:
Yields:
Any: object from async generator
"""
ait = agen().__aiter__()
ait = agen.__aiter__()

async def get_next() -> Tuple[bool, Any]:
try:
Expand Down Expand Up @@ -223,7 +235,7 @@ async def NamedTemporaryFile(
"""Create a temporary file that is deleted when the context exits."""
kwargs.update(delete=False)

def f() -> tempfile.NamedTemporaryFile:
def f():
return tempfile.NamedTemporaryFile(*args, **kwargs)

tmp = await anyio.to_thread.run_sync(f)
Expand Down

0 comments on commit 0182ca6

Please sign in to comment.