Skip to content

Commit

Permalink
Allow the user to override the client session with a callback (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadair authored Jun 23, 2022
1 parent d3ee1b8 commit b8d4144
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 34 deletions.
68 changes: 39 additions & 29 deletions parfive/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import os
import platform
import warnings
from typing import Dict, Union, Optional
from typing import Dict, Union, Callable, Optional

try:
from typing import Literal # Added in Python 3.8
except ImportError:
from typing_extensions import Literal # type: ignore

from dataclasses import field, dataclass
from dataclasses import InitVar, field, dataclass

import aiohttp

Expand All @@ -26,6 +26,18 @@ def _default_headers():
}


def _default_aiohttp_session(config: "SessionConfig") -> aiohttp.ClientSession:
"""
The aiohttp session with the kwargs stored by this config.
Notes
-----
`aiohttp.ClientSession` expects to be instantiated in a asyncio context
where it can get a running loop.
"""
return aiohttp.ClientSession(headers=config.headers)


@dataclass
class EnvConfig:
"""
Expand Down Expand Up @@ -117,12 +129,14 @@ class SessionConfig:
overridden by the ``PARFIVE_TOTAL_TIMEOUT`` and
``PARFIVE_SOCK_READ_TIMEOUT`` environment variables.
"""
aiohttp_session_kwargs: Dict = field(default_factory=dict)
aiohttp_session_generator: Callable[
["SessionConfig"], aiohttp.ClientSession
] = _default_aiohttp_session
"""
Any extra keyword arguments to be passed to `aiohttp.ClientSession`.
Note that the `headers` keyword argument is handled separately, so should
not be included in this dict.
An optional function to generate the `aiohttp.ClientSession` class.
Due to the fact that this session needs to be executed inside the asyncio context it is a callable.
It takes one argument which is the instance of this ``SessionConfig`` class.
It is expected that you pass ``.headers`` through to this session or the headers will not be sent.
"""
env: EnvConfig = field(default_factory=EnvConfig)

Expand Down Expand Up @@ -179,40 +193,36 @@ class DownloaderConfig:
# are that these two arguments default to None.
# When it is removed after the deprecation period, the defaults here
# should be moved to SessionConifg
headers: Optional[Dict[str, str]] = field(default_factory=_default_headers)
headers: InitVar[Optional[Dict[str, str]]] = None
config: Optional[SessionConfig] = field(default_factory=SessionConfig)
env: EnvConfig = field(default_factory=EnvConfig)

def __post_init__(self):
def __post_init__(self, headers):
if self.config is None:
self.config = SessionConfig()

self.max_conn = 1 if self.env.serial_mode else self.max_conn
self.max_splits = 1 if self.env.serial_mode or self.env.disable_range else self.max_splits
self.progress = False if self.env.hide_progress else self.progress

# Default headers if None
if self.headers is None:
self.headers = self.__dataclass_fields__["headers"].default_factory()
if self.config.headers is None:
if headers is None:
self.config.headers = _default_headers()
elif headers is not None:
self.config.headers = headers

if self.progress is False:
self.file_progress = False

if self.config is None:
self.config = SessionConfig()

# Remove this after deprecation period
if self.config.headers is not None:
self.headers = self.config.headers

def __getattr__(self, __name: str):
return getattr(self.config, __name)

@property
def aiohttp_session(self) -> aiohttp.ClientSession:
"""
The aiohttp session with the kwargs stored by this class.
Notes
-----
`aiohttp.ClientSession` expects to be instantiated in a asyncio context
where it can get a running loop.
"""
return aiohttp.ClientSession(headers=self.headers, **self.aiohttp_session_kwargs)
# Always delegate headers to config even though we have that attribute
def __getattribute__(self, __name):
if __name == "headers":
return self.config.headers
return super().__getattribute__(__name)

def aiohttp_client_session(self):
return self.config.aiohttp_session_generator(self.config)
2 changes: 1 addition & 1 deletion parfive/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def _get_main_pb(self, total):
return contextlib.contextmanager(lambda: iter([None]))()

async def _run_http_download(self, main_pb):
async with self.config.aiohttp_session as session:
async with self.config.aiohttp_client_session() as session:
self._generate_tokens()
futures = await self._run_from_queue(
self.http_queue.generate_queue(),
Expand Down
7 changes: 3 additions & 4 deletions parfive/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

def test_session_config_defaults():
c = SessionConfig()
assert isinstance(c.aiohttp_session_kwargs, dict)
assert not c.aiohttp_session_kwargs
assert callable(c.aiohttp_session_generator)
assert isinstance(c.timeouts, aiohttp.ClientTimeout)
assert c.timeouts.total == 0
assert c.timeouts.sock_read == 90
Expand Down Expand Up @@ -74,6 +73,6 @@ def test_deprecated_downloader_arguments():
def test_ssl_context():
# Assert that the unpickalable SSL context object doesn't anger the
# dataclass gods
ssl_ctx = ssl.create_default_context()
c = SessionConfig(aiohttp_session_kwargs={"context": ssl_ctx})
gen = lambda config: aiohttp.ClientSession(context=ssl.create_default_context())
c = SessionConfig(aiohttp_session_generator=gen)
d = Downloader(config=c)

0 comments on commit b8d4144

Please sign in to comment.