diff --git a/twitchdl/http.py b/twitchdl/http.py index cbcff1e..f8449c7 100644 --- a/twitchdl/http.py +++ b/twitchdl/http.py @@ -3,6 +3,7 @@ import os import time from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path from typing import Iterable, List, NamedTuple, Optional, Tuple @@ -28,6 +29,34 @@ """ +class Task(NamedTuple): + task_id: int + url: str + target: Path + + +@dataclass +class TaskResult(ABC): + task_id: int + url: str + target: Path + + @property + def ok(self) -> bool: + return isinstance(self, TaskSuccess) + + +@dataclass +class TaskSuccess(TaskResult): + size: int + existing: bool + + +@dataclass +class TaskFailure(TaskResult): + exception: Exception + + class TokenBucket(ABC): @abstractmethod def advance(self, size: int): @@ -76,22 +105,27 @@ async def download( target: Path, progress: Progress, token_bucket: TokenBucket, -): +) -> int: # Download to a temp file first, then copy to target when over to avoid # getting saving chunks which may persist if canceled or --keep is used tmp_target = f"{target}.tmp" + downloaded_size = 0 + with open(tmp_target, "wb") as f: async with client.stream("GET", source) as response: response.raise_for_status() - size = int(response.headers.get("content-length")) - progress.start(task_id, size, source, target) + content_length = int(response.headers.get("content-length")) + progress.start(task_id, content_length, source, target) async for chunk in response.aiter_bytes(chunk_size=CHUNK_SIZE): f.write(chunk) - size = len(chunk) - token_bucket.advance(size) - progress.advance(task_id, size) + chunk_size = len(chunk) + downloaded_size += chunk_size + token_bucket.advance(chunk_size) + progress.advance(task_id, chunk_size) progress.end(task_id) + os.rename(tmp_target, target) + return downloaded_size async def download_with_retries( @@ -102,31 +136,26 @@ async def download_with_retries( progress: Progress, token_bucket: TokenBucket, skip_existing: bool, -): +) -> TaskResult: if skip_existing and target.exists(): size = os.path.getsize(target) progress.already_downloaded(task_id, size) - return + return TaskSuccess(task_id, source, target, size, existing=True) for n in range(RETRY_COUNT): try: - return await download(client, task_id, source, target, progress, token_bucket) + size = await download(client, task_id, source, target, progress, token_bucket) + return TaskSuccess(task_id, source, target, size, existing=False) except httpx.HTTPError as ex: if n + 1 >= RETRY_COUNT: progress.failed(task_id, ex) - raise + return TaskFailure(task_id, source, target, ex) else: progress.restart(task_id, ex) raise Exception("Should not happen") -class QueueItem(NamedTuple): - task_id: int - url: str - target: Path - - class DownloadAllResult(NamedTuple): ok: bool exceptions: Optional[List[Exception]] = None @@ -140,71 +169,71 @@ async def download_all( skip_existing: bool = True, rate_limit: Optional[int] = None, progress: Optional[Progress] = None, -) -> DownloadAllResult: +) -> List[TaskResult]: """Download files concurently.""" progress = progress or Progress() token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket() - queue: asyncio.Queue[QueueItem] = asyncio.Queue() + queue: asyncio.Queue[Task] = asyncio.Queue() + task_results: List[TaskResult] = [] async def producer(): """Add all tasks to the queue then wait for queue to be depleted.""" for index, (source, target) in enumerate(source_targets): - await queue.put(QueueItem(index, source, target)) + await queue.put(Task(index, source, target)) await queue.join() async def worker(client: httpx.AsyncClient, worker_id: int): - # print(f"starting worker {worker_id}") while True: item = await queue.get() - # print(f"worker {worker_id} starting item {item.task_id}") - try: - await download_with_retries( - client, - item.task_id, - item.url, - item.target, - progress, - token_bucket, - skip_existing, - ) - # print(f"worker {worker_id} finished item {item.task_id}") - except Exception: - # print(f"worker {worker_id} {ex=}") - if not allow_failures: - # print("raising because allow_failures is False") - raise - finally: - # print(f"worker {worker_id} task done {item.task_id}") - queue.task_done() + result = await download_with_retries( + client, + item.task_id, + item.url, + item.target, + progress, + token_bucket, + skip_existing, + ) + task_results.append(result) + queue.task_done() + if not result.ok: + print("aborting?") + raise ValueError("Abort") async with httpx.AsyncClient(timeout=TIMEOUT) as client: - # Task to fill the queue and then wait until it is depleted + # Task which fills the queue and then wait until it is depleted producer_task = asyncio.create_task(producer(), name="Producer") # Worker tasks to process the download queue worker_tasks = [ - asyncio.create_task(worker(client, worker_id)) for worker_id in range(worker_count) + asyncio.create_task(worker(client, worker_id), name=f"Downloader {worker_id}") + for worker_id in range(worker_count) ] # Wait for queue to deplete or of the worker tasks to finish, whichever - # comes first. If allow_failures is False, workers will raise an error - # on failure thus stopping any pending downloads. - await asyncio.wait([producer_task, *worker_tasks], return_when=asyncio.FIRST_COMPLETED) + # comes first. Worker tasks will only finish if allow_failures is False + # and a download fails, otherwise they will run forever and the + # producer task will finish first. + done, pending = await asyncio.wait( + [producer_task, *worker_tasks], return_when=asyncio.FIRST_COMPLETED + ) + + for x in done: + print("done", x) + + for x in pending: + print("pending", x) success = producer_task.done() + print(f"{success=}") # Cancel tasks and wait until they are cancelled for task in worker_tasks: task.cancel() - results = await asyncio.gather(*worker_tasks, return_exceptions=True) - from pprint import pp - - print(f"{success=}") - pp(results) - exceptions = [r for r in results if isinstance(r, Exception)] - return DownloadAllResult(success, exceptions) + await asyncio.gather(*worker_tasks, return_exceptions=True) + return sorted(task_results, key=lambda x: x.task_id) def download_file(url: str, target: Path, retries: int = RETRY_COUNT) -> None: diff --git a/twitchdl/progress.py b/twitchdl/progress.py index 0a5ac8a..125a9d7 100644 --- a/twitchdl/progress.py +++ b/twitchdl/progress.py @@ -96,13 +96,6 @@ def start(self, task_id: int, size: int, source: str, target: Path): self.tasks[task_id] = Task(task_id, size) self.print() - # def set_size(self, task_id: int, size: int): - # if task_id not in self.tasks: - # raise ValueError(f"Task {task_id}: cannot set size, not started") - - # self.tasks[task_id] = Task(task_id, size) - # self.print() - def advance(self, task_id: int, size: int): if task_id not in self.tasks: raise ValueError(f"Task {task_id}: cannot advance, not started")