Skip to content

Commit

Permalink
Add typing to utils (#139)
Browse files Browse the repository at this point in the history
* Add typing to utils

* Put aioftp under a type checking guard

* Add types-requests to tox mypy reqs

* Undo path var renaming

* Fix response typing

* Fix aioftp not defined

* Remove pathlike subscripting

* Remove type subscripting on Queue

* Add missing annotation

---------

Co-authored-by: Stuart Mumford <[email protected]>
  • Loading branch information
dstansby and Cadair authored Apr 4, 2024
1 parent bd1d93f commit d49ed03
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 22 deletions.
52 changes: 30 additions & 22 deletions parfive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hashlib
import pathlib
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union, TypeVar, Generator
from pathlib import Path
from itertools import count
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -11,6 +12,10 @@

import parfive

if TYPE_CHECKING:
import aioftp


__all__ = [
"cancel_task",
"run_in_thread",
Expand All @@ -22,7 +27,7 @@


# Copied out of CPython under PSF Licence 2
def _parseparam(s):
def _parseparam(s: str) -> Generator[str, None, None]:
while s[:1] == ";":
s = s[1:]
end = s.find(";")
Expand All @@ -35,7 +40,7 @@ def _parseparam(s):
s = s[end:]


def parse_header(line):
def parse_header(line: str) -> Tuple[str, Dict[str, str]]:
"""Parse a Content-type like header.
Return the main content-type and a dictionary of options.
"""
Expand Down Expand Up @@ -68,7 +73,7 @@ def default_name(path: os.PathLike, resp: aiohttp.ClientResponse, url: str) -> o
return pathlib.Path(path) / name


def run_task_in_thread(loop, coro):
def run_task_in_thread(loop: asyncio.BaseEventLoop, coro: asyncio.Task) -> Any:
"""
This function returns the asyncio Future after running the loop in a
thread.
Expand All @@ -84,7 +89,7 @@ def run_task_in_thread(loop, coro):
return future.result()


async def get_ftp_size(client, filepath):
async def get_ftp_size(client: "aioftp.Client", filepath: os.PathLike) -> int:
"""
Given an `aioftp.ClientSession` object get the expected size of the file,
return ``None`` if the size can not be determined.
Expand All @@ -99,22 +104,22 @@ async def get_ftp_size(client, filepath):
return int(size) if size else size


def get_http_size(resp):
def get_http_size(resp: aiohttp.ClientResponse) -> Union[int, str, None]:
size = resp.headers.get("content-length", None)
return int(size) if size else size


def replacement_filename(path):
def replacement_filename(path: os.PathLike) -> Path: # type: ignore[return]
"""
Given a path generate a unique filename.
"""
path = pathlib.Path(path)

if not path.exists:
if not path.exists():
return path

suffix = "".join(path.suffixes)
for c in count(1):
for c in count(start=1):
if suffix:
name, _ = path.name.split(suffix)
else:
Expand All @@ -125,7 +130,7 @@ def replacement_filename(path):
return new_path


def get_filepath(filepath, overwrite):
def get_filepath(filepath: os.PathLike, overwrite: bool) -> Tuple[Union[Path, str], bool]:
"""
Get the filepath to download to and ensure dir exists.
Expand All @@ -145,7 +150,7 @@ def get_filepath(filepath, overwrite):
return filepath, False


def sha256sum(filename):
def sha256sum(filename: str) -> str:
"""
https://stackoverflow.com/a/44873382
"""
Expand All @@ -159,47 +164,50 @@ def sha256sum(filename):


class MultiPartDownloadError(Exception):
def __init__(self, response):
def __init__(self, response: aiohttp.ClientResponse) -> None:
self.response = response


class FailedDownload(Exception):
def __init__(self, filepath_partial, url, exception):
def __init__(self, filepath_partial: Path, url: str, exception: BaseException) -> None:
self.filepath_partial = filepath_partial
self.url = url
self.exception = exception
super().__init__()

def __repr__(self):
def __repr__(self) -> str:
out = super().__repr__()
out += f"\n {self.url} {self.exception}"
return out

def __str__(self):
def __str__(self) -> str:
return f"Download Failed: {self.url} with error {str(self.exception)}"


class Token:
def __init__(self, n):
def __init__(self, n: int) -> None:
self.n = n

def __repr__(self):
def __repr__(self) -> str:
return super().__repr__() + f"n = {self.n}"

def __str__(self):
def __str__(self) -> str:
return f"Token {self.n}"


class _QueueList(list):
_T = TypeVar("_T")


class _QueueList(List[_T]):
"""
A list, with an extra method that empties the list and puts it into a
`asyncio.Queue`.
Creating the queue can only be done inside a running asyncio loop.
"""

def generate_queue(self, maxsize=0):
queue = asyncio.Queue(maxsize=maxsize)
def generate_queue(self, maxsize: int = 0) -> asyncio.Queue:
queue: asyncio.Queue = asyncio.Queue(maxsize=maxsize)
for item in self:
queue.put_nowait(item)
self.clear()
Expand All @@ -218,7 +226,7 @@ class ParfiveFutureWarning(FutureWarning):
"""


def remove_file(filepath):
def remove_file(filepath: os.PathLike) -> None:
"""
Remove the file from the disk, if it exists
"""
Expand All @@ -232,7 +240,7 @@ def remove_file(filepath):
)


async def cancel_task(task):
async def cancel_task(task: asyncio.Task) -> bool:
"""
Call cancel on a task and then wait for it to exit.
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ exclude_lines =
[mypy]
plugins = pydantic.mypy

[mypy-parfive.utils]
disallow_untyped_defs = True

# Ignore the autogenerated version file
[mypy-parfive._version]
ignore_missing_imports = True
Expand Down

0 comments on commit d49ed03

Please sign in to comment.