From b8d4144fc8c21eb520841403b566daa39e6a247c Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 23 Jun 2022 13:48:24 +0200 Subject: [PATCH] Allow the user to override the client session with a callback (#99) --- parfive/config.py | 68 +++++++++++++++++++++--------------- parfive/downloader.py | 2 +- parfive/tests/test_config.py | 7 ++-- 3 files changed, 43 insertions(+), 34 deletions(-) diff --git a/parfive/config.py b/parfive/config.py index 24de159..457dc6c 100644 --- a/parfive/config.py +++ b/parfive/config.py @@ -1,14 +1,14 @@ import os import platform import warnings -from typing import Dict, Union, Optional +from typing import Dict, Union, Callable, Optional try: from typing import Literal # Added in Python 3.8 except ImportError: from typing_extensions import Literal # type: ignore -from dataclasses import field, dataclass +from dataclasses import InitVar, field, dataclass import aiohttp @@ -26,6 +26,18 @@ def _default_headers(): } +def _default_aiohttp_session(config: "SessionConfig") -> aiohttp.ClientSession: + """ + The aiohttp session with the kwargs stored by this config. + + Notes + ----- + `aiohttp.ClientSession` expects to be instantiated in a asyncio context + where it can get a running loop. + """ + return aiohttp.ClientSession(headers=config.headers) + + @dataclass class EnvConfig: """ @@ -117,12 +129,14 @@ class SessionConfig: overridden by the ``PARFIVE_TOTAL_TIMEOUT`` and ``PARFIVE_SOCK_READ_TIMEOUT`` environment variables. """ - aiohttp_session_kwargs: Dict = field(default_factory=dict) + aiohttp_session_generator: Callable[ + ["SessionConfig"], aiohttp.ClientSession + ] = _default_aiohttp_session """ - Any extra keyword arguments to be passed to `aiohttp.ClientSession`. - - Note that the `headers` keyword argument is handled separately, so should - not be included in this dict. + An optional function to generate the `aiohttp.ClientSession` class. + Due to the fact that this session needs to be executed inside the asyncio context it is a callable. + It takes one argument which is the instance of this ``SessionConfig`` class. + It is expected that you pass ``.headers`` through to this session or the headers will not be sent. """ env: EnvConfig = field(default_factory=EnvConfig) @@ -179,40 +193,36 @@ class DownloaderConfig: # are that these two arguments default to None. # When it is removed after the deprecation period, the defaults here # should be moved to SessionConifg - headers: Optional[Dict[str, str]] = field(default_factory=_default_headers) + headers: InitVar[Optional[Dict[str, str]]] = None config: Optional[SessionConfig] = field(default_factory=SessionConfig) env: EnvConfig = field(default_factory=EnvConfig) - def __post_init__(self): + def __post_init__(self, headers): + if self.config is None: + self.config = SessionConfig() + self.max_conn = 1 if self.env.serial_mode else self.max_conn self.max_splits = 1 if self.env.serial_mode or self.env.disable_range else self.max_splits self.progress = False if self.env.hide_progress else self.progress # Default headers if None - if self.headers is None: - self.headers = self.__dataclass_fields__["headers"].default_factory() + if self.config.headers is None: + if headers is None: + self.config.headers = _default_headers() + elif headers is not None: + self.config.headers = headers if self.progress is False: self.file_progress = False - if self.config is None: - self.config = SessionConfig() - - # Remove this after deprecation period - if self.config.headers is not None: - self.headers = self.config.headers - def __getattr__(self, __name: str): return getattr(self.config, __name) - @property - def aiohttp_session(self) -> aiohttp.ClientSession: - """ - The aiohttp session with the kwargs stored by this class. - - Notes - ----- - `aiohttp.ClientSession` expects to be instantiated in a asyncio context - where it can get a running loop. - """ - return aiohttp.ClientSession(headers=self.headers, **self.aiohttp_session_kwargs) + # Always delegate headers to config even though we have that attribute + def __getattribute__(self, __name): + if __name == "headers": + return self.config.headers + return super().__getattribute__(__name) + + def aiohttp_client_session(self): + return self.config.aiohttp_session_generator(self.config) diff --git a/parfive/downloader.py b/parfive/downloader.py index 44011ee..4a8b5a3 100644 --- a/parfive/downloader.py +++ b/parfive/downloader.py @@ -380,7 +380,7 @@ def _get_main_pb(self, total): return contextlib.contextmanager(lambda: iter([None]))() async def _run_http_download(self, main_pb): - async with self.config.aiohttp_session as session: + async with self.config.aiohttp_client_session() as session: self._generate_tokens() futures = await self._run_from_queue( self.http_queue.generate_queue(), diff --git a/parfive/tests/test_config.py b/parfive/tests/test_config.py index 047af6c..b953001 100644 --- a/parfive/tests/test_config.py +++ b/parfive/tests/test_config.py @@ -10,8 +10,7 @@ def test_session_config_defaults(): c = SessionConfig() - assert isinstance(c.aiohttp_session_kwargs, dict) - assert not c.aiohttp_session_kwargs + assert callable(c.aiohttp_session_generator) assert isinstance(c.timeouts, aiohttp.ClientTimeout) assert c.timeouts.total == 0 assert c.timeouts.sock_read == 90 @@ -74,6 +73,6 @@ def test_deprecated_downloader_arguments(): def test_ssl_context(): # Assert that the unpickalable SSL context object doesn't anger the # dataclass gods - ssl_ctx = ssl.create_default_context() - c = SessionConfig(aiohttp_session_kwargs={"context": ssl_ctx}) + gen = lambda config: aiohttp.ClientSession(context=ssl.create_default_context()) + c = SessionConfig(aiohttp_session_generator=gen) d = Downloader(config=c)