Skip to content

Commit

Permalink
Add callbacks for ftp downloads (#150)
Browse files Browse the repository at this point in the history
* Add callbacks for ftp downloads

* Add ftp callback tests and refactor http callback tests.

---------

Co-authored-by: Stuart Mumford <[email protected]>
  • Loading branch information
samaloney and Cadair authored Apr 8, 2024
1 parent fc9a4e5 commit 7c70d4c
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 10 deletions.
9 changes: 9 additions & 0 deletions parfive/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,10 @@ async def _get_ftp(

await asyncio.gather(*download_workers)
await downloaded_chunks_queue.join()

for callback in self.config.done_callbacks:
callback(filepath, url, None)

return str(filepath)

except (Exception, asyncio.CancelledError) as e:
Expand All @@ -851,6 +855,11 @@ async def _get_ftp(
# computed the filepath, so we have no file to cleanup
if filepath is not None:
remove_file(filepath)
filepath = None

for callback in self.config.done_callbacks:
callback(filepath, url, e)

raise FailedDownload(filepath_partial, url, e)

finally:
Expand Down
77 changes: 67 additions & 10 deletions parfive/tests/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from pathlib import Path
from tempfile import gettempdir
from unittest import mock
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import aiohttp
import pytest
from aiohttp import ClientTimeout
from aiohttp import ClientConnectorError, ClientTimeout
from pytest_socket import SocketConnectBlockedError

import parfive
from parfive.config import SessionConfig
Expand Down Expand Up @@ -495,24 +496,80 @@ def test_proxy_passed_as_kwargs_to_get(tmpdir, url, proxy):
]


def test_done_callback(httpserver, tmpdir):
tmpdir = str(tmpdir)
def test_http_callback_success(httpserver, tmpdir):
# Test callback on successful download
httpserver.serve_content(
"SIMPLE = T", headers={"Content-Disposition": "attachment; filename=testfile.fits"}
)

def done_callback(filepath, url, error):
(Path(gettempdir()) / "callback.done").touch()
cb = MagicMock()
dl = Downloader(config=SessionConfig(done_callbacks=[cb]))
dl.enqueue_file(httpserver.url, path=tmpdir, max_splits=None)

dl = Downloader(config=SessionConfig(done_callbacks=[done_callback]))
dl.enqueue_file(httpserver.url, path=Path(tmpdir), max_splits=None)
assert dl.queued_downloads == 1

dl.download()

assert cb.call_count == 1
cb_path, cb_url, cb_status = cb.call_args[0]
assert cb_path == tmpdir / "testfile.fits"
assert httpserver.url == cb_url
assert cb_status is None


def test_http_callback_fail(httpserver, tmpdir):
# Test callback on failed download
cb = MagicMock()
dl = Downloader(config=SessionConfig(done_callbacks=[cb]))
url = "http://test.com/myfile.txt"
dl.enqueue_file(url, path=tmpdir, max_splits=None)

assert dl.queued_downloads == 1

dl.download()

assert cb.call_count == 1
cb_path, cb_url, cb_status = cb.call_args[0]
assert cb_path is None
assert url == cb_url
assert isinstance(cb_status, (SocketConnectBlockedError, ClientConnectorError))


@pytest.mark.allow_hosts(True)
def test_ftp_callback_success(tmpdir):
cb = MagicMock()
dl = Downloader(config=SessionConfig(done_callbacks=[cb]))
url = "ftp://ftp.swpc.noaa.gov/pub/warehouse/2011/2011_SRS.tar.gz"
dl.enqueue_file(url, path=str(tmpdir))

assert dl.queued_downloads == 1

dl.download()

assert cb.call_count == 1
cb_path, cb_url, cb_status = cb.call_args[0]
assert cb_path == tmpdir / "2011_SRS.tar.gz"
assert url == cb_url
assert cb_status is None


@mock.patch("aioftp.Client.context", side_effect=ConnectionRefusedError())
def test_ftp_callback_error(tmpdir):
# Download should fail as not marked with allowed hosts
cb = MagicMock()
dl = Downloader(config=SessionConfig(done_callbacks=[cb]))
url = "ftp://127.0.0.1/nosuchfile.txt"
dl.enqueue_file(url, path=str(tmpdir))

assert dl.queued_downloads == 1

dl.download()

assert (Path(gettempdir()) / "callback.done").exists()
(Path(gettempdir()) / "callback.done").unlink()
assert cb.call_count == 1
cb_path, cb_url, cb_status = cb.call_args[0]
assert cb_path is None
assert cb_url == url
assert isinstance(cb_status, ConnectionRefusedError)


class CustomThread(threading.Thread):
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,6 @@ ignore_missing_imports = True

[mypy-pytest.*]
ignore_missing_imports = True

[mypy-pytest_socket.*]
ignore_missing_imports = True

0 comments on commit 7c70d4c

Please sign in to comment.