Skip to content

Commit

Permalink
Reimplement download_all using a queue
Browse files Browse the repository at this point in the history
  • Loading branch information
ihabunek committed Sep 21, 2024
1 parent 46c0314 commit a03d1d7
Show file tree
Hide file tree
Showing 9 changed files with 446 additions and 80 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dev = [

test = [
"pytest",
"pytest-httpserver",
"vermin",
]

Expand Down
158 changes: 158 additions & 0 deletions tests/test_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import asyncio
import hashlib
import os
import tempfile
from pathlib import Path
from typing import NamedTuple

import pytest
from pytest_httpserver import HTTPServer

from twitchdl.http import TaskError, TaskSuccess, download_all

MiB = 1024**2


class File(NamedTuple):
data: bytes
hash: str
path: str


def generate_test_file(size: int):
data = os.urandom(size)
hash = hashlib.sha256(data).hexdigest()
return File(data, hash, f"/{hash}")


def hash_file(path: Path):
hash = hashlib.sha256()
with open(path, "rb") as f:
while True:
chunk = f.read()
if not chunk:
break
hash.update(chunk)
return hash.hexdigest()


@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmp_dir:
yield Path(tmp_dir)


def test_success(httpserver: HTTPServer, temp_dir: Path):
count = 10
workers = 5
file_size = 1 * MiB

files = [generate_test_file(file_size) for _ in range(count)]
for f in files:
httpserver.expect_request(f.path).respond_with_data(f.data) # type: ignore

sources = [httpserver.url_for(f.path) for f in files]
targets = [temp_dir / f.hash for f in files]

result = asyncio.run(download_all(zip(sources, targets), workers))
assert result.ok
assert len(result.results) == count

for index, (file, source, target, result) in enumerate(
zip(files, sources, targets, result.results)
):
assert isinstance(result, TaskSuccess)
assert result.ok
assert not result.existing
assert result.task_id == index
assert result.size == file_size
assert result.url == source
assert result.target == target

assert target.exists()
assert os.path.getsize(target) == file_size
assert file.hash == hash_file(target)


def test_allow_failures(httpserver: HTTPServer, temp_dir: Path):
count = 10
workers = 5
file_size = 1 * MiB
failing_index = 5

files = [generate_test_file(file_size) for _ in range(count)]
for index, f in enumerate(files):
if index == failing_index:
httpserver.expect_request(f.path).respond_with_data("not found", status=404) # type: ignore
else:
httpserver.expect_request(f.path).respond_with_data(f.data) # type: ignore

sources = [httpserver.url_for(f.path) for f in files]
targets = [temp_dir / f.hash for f in files]

result = asyncio.run(download_all(zip(sources, targets), workers))
results = result.results
assert result.ok
assert len(results) == count

for index, (file, source, target, result) in enumerate(zip(files, sources, targets, results)):
if index == failing_index:
assert not target.exists()
assert isinstance(result, TaskError)
assert result.task_id == index
assert not result.ok
assert result.url == source
assert result.target == target
else:
assert target.exists()
assert os.path.getsize(target) == file_size
assert file.hash == hash_file(target)
assert isinstance(result, TaskSuccess)
assert result.task_id == index
assert result.size == file_size
assert not result.existing
assert result.ok
assert result.url == source
assert result.target == target


def test_dont_allow_failures(httpserver: HTTPServer, temp_dir: Path):
count = 10
workers = 5
file_size = 1 * MiB
failing_index = 5

files = [generate_test_file(file_size) for _ in range(count)]
for index, f in enumerate(files):
if index == failing_index:
httpserver.expect_request(f.path).respond_with_data("not found", status=404) # type: ignore
else:
httpserver.expect_request(f.path).respond_with_data(f.data) # type: ignore

sources = [httpserver.url_for(f.path) for f in files]
targets = [temp_dir / f.hash for f in files]

result = asyncio.run(download_all(zip(sources, targets), workers, allow_failures=False))
results = result.results
assert not result.ok
assert len(results) == count

for index, (file, source, target, result) in enumerate(zip(files, sources, targets, results)):
if index == failing_index:
assert not target.exists()
assert isinstance(result, TaskError)
assert result.task_id == index
assert not result.ok
assert result.url == source
assert result.target == target
else:
assert target.exists()
assert os.path.getsize(target) == file_size
assert file.hash == hash_file(target)
assert isinstance(result, TaskSuccess)
assert result.task_id == index
assert result.size == file_size
assert not result.existing
assert result.ok
assert result.url == source
assert result.target == target
34 changes: 18 additions & 16 deletions tests/test_progress.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from twitchdl.progress import Progress
from pathlib import Path

from twitchdl.progress import VideoDownloadProgress


def test_initial_values():
progress = Progress(10)
progress = VideoDownloadProgress(10)
assert progress.downloaded == 0
assert progress.estimated_total is None
assert progress.progress_perc == 0
Expand All @@ -13,10 +15,10 @@ def test_initial_values():


def test_downloaded():
progress = Progress(3)
progress.start(1, 300)
progress.start(2, 300)
progress.start(3, 300)
progress = VideoDownloadProgress(3)
progress.start(1, "foo1", Path("foo1"), 300)
progress.start(2, "foo2", Path("foo2"), 300)
progress.start(3, "foo3", Path("foo3"), 300)

assert progress.downloaded == 0
assert progress.progress_bytes == 0
Expand Down Expand Up @@ -46,13 +48,13 @@ def test_downloaded():
assert progress.progress_bytes == 500
assert progress.progress_perc == 55

progress.abort(2)
progress.abort(2, Exception())
progress._recalculate()
assert progress.downloaded == 500
assert progress.progress_bytes == 300
assert progress.progress_perc == 33

progress.start(2, 300)
progress.start(2, "foo2", Path("foo2"), 300)

progress.advance(1, 150)
progress.advance(2, 300)
Expand All @@ -73,28 +75,28 @@ def test_downloaded():


def test_estimated_total():
progress = Progress(3)
progress = VideoDownloadProgress(3)
assert progress.estimated_total is None

progress.start(1, 12000)
progress.start(1, "foo1", Path("foo1"), 12000)
progress._recalculate()
assert progress.estimated_total == 12000 * 3

progress.start(2, 11000)
progress.start(2, "foo2", Path("foo2"), 11000)
progress._recalculate()
assert progress.estimated_total == 11500 * 3

progress.start(3, 10000)
progress.start(3, "foo3", Path("foo3"), 10000)
progress._recalculate()
assert progress.estimated_total == 11000 * 3


def test_vod_downloaded_count():
progress = Progress(3)
progress = VideoDownloadProgress(3)

progress.start(1, 100)
progress.start(2, 100)
progress.start(3, 100)
progress.start(1, "foo1", Path("foo1"), 100)
progress.start(2, "foo2", Path("foo2"), 100)
progress.start(3, "foo3", Path("foo3"), 100)

assert progress.downloaded_count == 0

Expand Down
1 change: 0 additions & 1 deletion twitchdl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ def download(
max_workers=max_workers,
cache_dir=cache_dir,
)

download(list(ids), options)


Expand Down
24 changes: 21 additions & 3 deletions twitchdl/commands/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
from twitchdl.exceptions import ConsoleError
from twitchdl.http import download_all, download_file
from twitchdl.naming import clip_filename, video_filename, video_placeholders
from twitchdl.output import blue, bold, green, print_error, print_log, underlined, yellow
from twitchdl.output import (
blue,
bold,
green,
print_error,
print_exception,
print_log,
underlined,
yellow,
)
from twitchdl.playlists import (
Playlist,
enumerate_vods,
Expand All @@ -28,6 +37,7 @@
parse_playlists,
select_playlist,
)
from twitchdl.progress import VideoDownloadProgress
from twitchdl.twitch import Chapter, ClipAccessToken, Video


Expand Down Expand Up @@ -284,15 +294,23 @@ def _download_video(video: Video, args: DownloadOptions) -> None:
sources = [base_uri + vod.path for vod in vods]
targets = [cache_dir / f"{vod.index:05d}.ts" for vod in vods]

asyncio.run(
result = asyncio.run(
download_all(
zip(sources, targets),
args.max_workers,
skip_existing=True,
allow_failures=False,
rate_limit=args.rate_limit,
count=len(vods),
progress=VideoDownloadProgress(len(vods)),
)
)

if not result.ok:
for ex in result.exceptions:
print()
print_exception(ex)
raise ConsoleError("Download failed")

join_playlist = make_join_playlist(vods_m3u8, vods, targets)
join_playlist_path = cache_dir / "playlist_downloaded.m3u8"
join_playlist.dump(join_playlist_path) # type: ignore
Expand Down
3 changes: 3 additions & 0 deletions twitchdl/entities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from dataclasses import dataclass
from typing import Any, List, Literal, Mapping, Optional, TypedDict

TaskID = int
"""Identifier for a download task"""


@dataclass
class DownloadOptions:
Expand Down
Loading

0 comments on commit a03d1d7

Please sign in to comment.