Skip to content

Commit

Permalink
Ensure type is passed through _io.sync (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtomlinson authored May 21, 2024
1 parent a676fab commit a8b5e28
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions kr8s/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,31 @@

import inspect
import subprocess
import sys
import tempfile
from contextlib import asynccontextmanager
from functools import partial, wraps
from threading import Thread
from typing import Any, AsyncGenerator, Awaitable, Callable, Generator, Tuple, TypeVar
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Generator,
Tuple,
TypeVar,
)

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

import anyio

T = TypeVar("T")
C = TypeVar("C")
P = ParamSpec("P")


class Portal:
Expand All @@ -44,13 +60,13 @@ async def _run(self):
self._portal = portal
await portal.sleep_until_stopped()

def call(self, func: Callable[..., T], *args, **kwargs) -> T:
def call(self, func: Callable[P, T], *args, **kwargs) -> T:
while not self._portal:
pass
return self._portal.call(func, *args, **kwargs)


def run_sync(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]:
def run_sync(coro: Callable[P, Awaitable[T]]) -> Callable[P, T]:
"""Wraps coroutine in a function that blocks until it has executed.
Parameters
Expand All @@ -65,18 +81,17 @@ def run_sync(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]:
"""

@wraps(coro)
def wrapped(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
wrapped = partial(coro, *args, **kwargs)
wrapped.__doc__ = coro.__doc__
if inspect.isasyncgenfunction(coro):
return iter_over_async(wrapped)
portal = Portal()
if inspect.iscoroutinefunction(coro):
portal = Portal()
return portal.call(wrapped)
raise TypeError(f"Expected coroutine function, got {coro.__class__.__name__}")

wrapped.__doc__ = coro.__doc__
return wrapped
return wrapper


def iter_over_async(agen: AsyncGenerator) -> Generator:
Expand All @@ -97,7 +112,7 @@ async def get_next() -> Tuple[bool, Any]:
yield obj


def sync(source: object) -> object:
def sync(source: C) -> C:
"""Convert all public async methods/properties of an object to universal methods.
See :func:`run_sync` for more info
Expand Down

0 comments on commit a8b5e28

Please sign in to comment.