Commit

wip_breaking_maby

ihabunek committed Sep 16, 2024
1 parent f8e58ca commit 9bc2dde
Showing 2 changed files with 81 additions and 59 deletions.
133 changes: 81 additions & 52 deletions twitchdl/http.py
@@ -3,6 +3,7 @@
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, NamedTuple, Optional, Tuple

@@ -28,6 +29,34 @@
"""


class Task(NamedTuple):
    task_id: int
    url: str
    target: Path


@dataclass
class TaskResult(ABC):
    task_id: int
    url: str
    target: Path

    @property
    def ok(self) -> bool:
        return isinstance(self, TaskSuccess)


@dataclass
class TaskSuccess(TaskResult):
    size: int
    existing: bool


@dataclass
class TaskFailure(TaskResult):
    exception: Exception

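For illustration, a caller can use the `ok` property to narrow a result to one of the two concrete types. A hypothetical helper (not part of this module) might look like:

def describe(result: TaskResult) -> str:
    # `ok` is True exactly when the result is a TaskSuccess
    if result.ok:
        assert isinstance(result, TaskSuccess)
        note = " (already existed)" if result.existing else ""
        return f"OK {result.target} ({result.size} bytes){note}"
    assert isinstance(result, TaskFailure)
    return f"FAILED {result.url}: {result.exception}"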

class TokenBucket(ABC):
    @abstractmethod
    def advance(self, size: int):
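The concrete bucket implementations are collapsed in this view. As a rough sketch of the idea only, assuming `rate` is measured in bytes per second (the actual classes in this module may differ):

class LimitingTokenBucketSketch(TokenBucket):
    """Sleep as needed to keep throughput near `rate` bytes per second."""

    def __init__(self, rate: int):
        self.rate = rate
        self.allowance = float(rate)
        self.last_check = time.monotonic()

    def advance(self, size: int):
        # Refill tokens for the elapsed time, capped at one second's worth
        now = time.monotonic()
        self.allowance = min(self.allowance + (now - self.last_check) * self.rate, float(self.rate))
        self.last_check = now
        self.allowance -= size
        if self.allowance < 0:
            # A blocking sleep stalls the event loop; shown only to
            # illustrate the token-bucket mechanism.
            time.sleep(-self.allowance / self.rate)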
@@ -76,22 +105,27 @@ async def download(
    target: Path,
    progress: Progress,
    token_bucket: TokenBucket,
):
) -> int:
    # Download to a temp file first, then rename to the target when done, to
    # avoid leaving partially saved chunks behind if the download is canceled
    # or --keep is used
    tmp_target = f"{target}.tmp"
    downloaded_size = 0

    with open(tmp_target, "wb") as f:
        async with client.stream("GET", source) as response:
            response.raise_for_status()
            size = int(response.headers.get("content-length"))
            progress.start(task_id, size, source, target)
            content_length = int(response.headers["content-length"])
            progress.start(task_id, content_length, source, target)
            async for chunk in response.aiter_bytes(chunk_size=CHUNK_SIZE):
                f.write(chunk)
                size = len(chunk)
                token_bucket.advance(size)
                progress.advance(task_id, size)
                chunk_size = len(chunk)
                downloaded_size += chunk_size
                token_bucket.advance(chunk_size)
                progress.advance(task_id, chunk_size)
            progress.end(task_id)

    os.rename(tmp_target, target)
    return downloaded_size


async def download_with_retries(
@@ -102,31 +136,26 @@ async def download_with_retries(
    progress: Progress,
    token_bucket: TokenBucket,
    skip_existing: bool,
):
) -> TaskResult:
    if skip_existing and target.exists():
        size = os.path.getsize(target)
        progress.already_downloaded(task_id, size)
        return
        return TaskSuccess(task_id, source, target, size, existing=True)

    for n in range(RETRY_COUNT):
        try:
            return await download(client, task_id, source, target, progress, token_bucket)
            size = await download(client, task_id, source, target, progress, token_bucket)
            return TaskSuccess(task_id, source, target, size, existing=False)
        except httpx.HTTPError as ex:
            if n + 1 >= RETRY_COUNT:
                progress.failed(task_id, ex)
                raise
                return TaskFailure(task_id, source, target, ex)
            else:
                progress.restart(task_id, ex)

    raise Exception("Should not happen")
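The loop above converts the final failure into a `TaskFailure` instead of raising, so a single bad download cannot crash the other workers on its own. The same retry shape in a generic standalone form (`fn` and `attempts` are illustrative names, not part of this module):

from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")

async def with_retries(fn: Callable[[], Awaitable[T]], attempts: int) -> T:
    for n in range(attempts):
        try:
            return await fn()
        except Exception:
            if n + 1 >= attempts:
                raise  # out of attempts, propagate the last error
    raise Exception("Should not happen")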


class QueueItem(NamedTuple):
    task_id: int
    url: str
    target: Path


class DownloadAllResult(NamedTuple):
    ok: bool
    exceptions: Optional[List[Exception]] = None
@@ -140,71 +169,71 @@ async def download_all(
    skip_existing: bool = True,
    rate_limit: Optional[int] = None,
    progress: Optional[Progress] = None,
) -> DownloadAllResult:
) -> List[TaskResult]:
    """Download files concurrently."""

    progress = progress or Progress()
    token_bucket = LimitingTokenBucket(rate_limit) if rate_limit else EndlessTokenBucket()
    queue: asyncio.Queue[QueueItem] = asyncio.Queue()
    queue: asyncio.Queue[Task] = asyncio.Queue()
    task_results: List[TaskResult] = []

    async def producer():
        """Add all tasks to the queue, then wait for the queue to be depleted."""
        for index, (source, target) in enumerate(source_targets):
            await queue.put(QueueItem(index, source, target))
            await queue.put(Task(index, source, target))
        await queue.join()

    async def worker(client: httpx.AsyncClient, worker_id: int):
        # print(f"starting worker {worker_id}")
        while True:
            item = await queue.get()
            # print(f"worker {worker_id} starting item {item.task_id}")
            try:
                await download_with_retries(
                    client,
                    item.task_id,
                    item.url,
                    item.target,
                    progress,
                    token_bucket,
                    skip_existing,
                )
                # print(f"worker {worker_id} finished item {item.task_id}")
            except Exception:
                # print(f"worker {worker_id} {ex=}")
                if not allow_failures:
                    # print("raising because allow_failures is False")
                    raise
            finally:
                # print(f"worker {worker_id} task done {item.task_id}")
                queue.task_done()
            result = await download_with_retries(
                client,
                item.task_id,
                item.url,
                item.target,
                progress,
                token_bucket,
                skip_existing,
            )
            task_results.append(result)
            queue.task_done()
            if not result.ok and not allow_failures:
                # A worker exiting with an error trips the asyncio.wait call
                # below, which then cancels the remaining workers.
                raise ValueError(f"Download failed: {item.url}")
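The producer/worker handshake hinges on `queue.join()` unblocking only after `task_done()` has been called once per queued item. A minimal standalone sketch of that pattern (names are illustrative):

import asyncio
import contextlib

async def demo():
    queue = asyncio.Queue()

    async def work():
        while True:
            n = await queue.get()
            try:
                print(f"processed {n}")
            finally:
                queue.task_done()  # lets queue.join() complete

    worker_task = asyncio.create_task(work())
    for n in range(3):
        await queue.put(n)
    await queue.join()  # returns once every item is marked done
    worker_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await worker_task

asyncio.run(demo())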

    async with httpx.AsyncClient(timeout=TIMEOUT) as client:
        # Task to fill the queue and then wait until it is depleted
        # Task which fills the queue and then waits until it is depleted
        producer_task = asyncio.create_task(producer(), name="Producer")

        # Worker tasks to process the download queue
        worker_tasks = [
            asyncio.create_task(worker(client, worker_id)) for worker_id in range(worker_count)
            asyncio.create_task(worker(client, worker_id), name=f"Downloader {worker_id}")
            for worker_id in range(worker_count)
        ]

        # Wait for the queue to deplete or one of the worker tasks to finish,
        # whichever comes first. If allow_failures is False, workers will
        # raise an error on failure, stopping any pending downloads.
        await asyncio.wait([producer_task, *worker_tasks], return_when=asyncio.FIRST_COMPLETED)
        # whichever comes first. Worker tasks only finish if allow_failures is
        # False and a download fails; otherwise they run forever and the
        # producer task finishes first.
        await asyncio.wait(
            [producer_task, *worker_tasks], return_when=asyncio.FIRST_COMPLETED
        )

        success = producer_task.done()

        # Cancel tasks and wait until they are cancelled
        for task in worker_tasks:
            task.cancel()

        results = await asyncio.gather(*worker_tasks, return_exceptions=True)
        exceptions = [r for r in results if isinstance(r, Exception)]
        return DownloadAllResult(success, exceptions)
        await asyncio.gather(*worker_tasks, return_exceptions=True)
        return sorted(task_results, key=lambda x: x.task_id)
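As a usage sketch, assuming the collapsed parameters are `source_targets` and `worker_count` as the function body suggests (the URLs and paths below are made up):

import asyncio
from pathlib import Path

pairs = [
    ("https://example.com/chunks/0.ts", Path("0.ts")),
    ("https://example.com/chunks/1.ts", Path("1.ts")),
]
results = asyncio.run(download_all(pairs, 4, rate_limit=1_000_000))
for result in results:
    if not result.ok:
        print(f"failed: {result.url}")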


def download_file(url: str, target: Path, retries: int = RETRY_COUNT) -> None:
7 changes: 0 additions & 7 deletions twitchdl/progress.py
@@ -96,13 +96,6 @@ def start(self, task_id: int, size: int, source: str, target: Path):
        self.tasks[task_id] = Task(task_id, size)
        self.print()

    # def set_size(self, task_id: int, size: int):
    #     if task_id not in self.tasks:
    #         raise ValueError(f"Task {task_id}: cannot set size, not started")

    #     self.tasks[task_id] = Task(task_id, size)
    #     self.print()

    def advance(self, task_id: int, size: int):
        if task_id not in self.tasks:
            raise ValueError(f"Task {task_id}: cannot advance, not started")