Skip to content

Commit

Permalink
Do not use pydantic for settings validation (#98)
Browse files Browse the repository at this point in the history
* Revert "Use pydantic for config validation"

This reverts commit b0bbcd7.

* Drop use_aiofiles as a keyword arg

* More env again

* Don't delete the env test

* Comments

* Unused imports
  • Loading branch information
Cadair authored Jun 23, 2022
1 parent f3dac92 commit d3ee1b8
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 52 deletions.
96 changes: 59 additions & 37 deletions parfive/config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import os
import platform
import warnings
from typing import Any, Dict, Union, Optional
from typing import Dict, Union, Optional

try:
from typing import Literal # Added in Python 3.8
except ImportError:
from typing_extensions import Literal # type: ignore

from dataclasses import field, dataclass

import aiohttp
from pydantic import BaseSettings, Field

import parfive
from parfive.utils import ParfiveUserWarning
Expand All @@ -25,32 +26,43 @@ def _default_headers():
}


class EnvConfig(BaseSettings):
@dataclass
class EnvConfig:
"""
Configuration read from environment variables.
"""

# Session scoped env vars
serial_mode: bool = Field(False, env="PARFIVE_SINGLE_DOWNLOAD")
disable_range: bool = Field(False, env="PARFIVE_DISABLE_RANGE")
hide_progress: bool = Field(False, env="PARFIVE_HIDE_PROGRESS")
debug_logging: bool = Field(False, env="PARFIVE_DEBUG")
timeout_total: float = Field(0, env="PARFIVE_TOTAL_TIMEOUT")
timeout_sock_read: float = Field(90, env="PARFIVE_SOCK_READ_TIMEOUT")
override_use_aiofiles: bool = Field(False, env="PARFIVE_OVERWRITE_ENABLE_AIOFILES")


class SessionConfig(BaseSettings):
serial_mode: bool = field(default=False, init=False)
disable_range: bool = field(default=False, init=False)
hide_progress: bool = field(default=False, init=False)
debug_logging: bool = field(default=False, init=False)
timeout_total: float = field(default=0, init=False)
timeout_sock_read: float = field(default=90, init=False)
override_use_aiofiles: bool = field(default=False, init=False)

def __post_init__(self):
self.serial_mode = "PARFIVE_SINGLE_DOWNLOAD" in os.environ
self.disable_range = "PARFIVE_DISABLE_RANGE" in os.environ
self.hide_progress = "PARFIVE_HIDE_PROGRESS" in os.environ
self.debug_logging = "PARFIVE_DEBUG" in os.environ
self.timeout_total = float(os.environ.get("PARFIVE_TOTAL_TIMEOUT", 0))
self.timeout_sock_read = float(os.environ.get("PARFIVE_SOCK_READ_TIMEOUT", 90))
self.override_use_aiofiles = "PARFIVE_OVERWRITE_ENABLE_AIOFILES" in os.environ


@dataclass
class SessionConfig:
"""
Configuration options for `parfive.Downloader`.
"""

http_proxy: Optional[str] = Field(None, env="HTTP_PROXY")
http_proxy: Optional[str] = None
"""
The URL of a proxy to use for HTTP requests. Will default to the value of
the ``HTTP_PROXY`` env var.
"""
https_proxy: Optional[str] = Field(None, env="HTTPS_PROXY")
https_proxy: Optional[str] = None
"""
The URL of a proxy to use for HTTPS requests. Will default to the value of
the ``HTTPS_PROXY`` env var.
Expand All @@ -75,7 +87,7 @@ class SessionConfig(BaseSettings):
If `True` (the default) a progress bar will be shown for every file if any
progress bars are shown.
"""
notebook: Optional[bool] = None
notebook: Union[bool, None] = None
"""
If `None` `tqdm` will automatically detect if it can draw rich IPython
Notebook progress bars. If `False` or `True` notebook mode will be forced
Expand All @@ -85,7 +97,7 @@ class SessionConfig(BaseSettings):
"""
If not `None` configure the logger to log to stderr with this log level.
"""
use_aiofiles: bool = Field(False)
use_aiofiles: Optional[bool] = False
"""
Enables using `aiofiles` to write files to disk in their own thread pool.
Expand All @@ -105,14 +117,14 @@ class SessionConfig(BaseSettings):
overridden by the ``PARFIVE_TOTAL_TIMEOUT`` and
``PARFIVE_SOCK_READ_TIMEOUT`` environment variables.
"""
aiohttp_session_kwargs: Dict = Field(default_factory=dict)
aiohttp_session_kwargs: Dict = field(default_factory=dict)
"""
Any extra keyword arguments to be passed to `aiohttp.ClientSession`.
Note that the `headers` keyword argument is handled separately, so should
not be included in this dict.
"""
env: EnvConfig = Field(default_factory=EnvConfig)
env: EnvConfig = field(default_factory=EnvConfig)

@staticmethod
def _aiofiles_importable():
Expand All @@ -123,8 +135,7 @@ def _aiofiles_importable():
return True

def _compute_aiofiles(self, use_aiofiles):
if self.env.override_use_aiofiles:
use_aiofiles = True
use_aiofiles = use_aiofiles or self.env.override_use_aiofiles
if use_aiofiles and not self._aiofiles_importable():
warnings.warn(
"Can not use aiofiles even though use_aiofiles is set to True as aiofiles can not be imported.",
Expand All @@ -133,14 +144,17 @@ def _compute_aiofiles(self, use_aiofiles):
use_aiofiles = False
return use_aiofiles

def __init__(self, **data):
super().__init__(**data)
def __post_init__(self):
if self.timeouts is None:
timeouts = {
"total": self.env.timeout_total,
"sock_read": self.env.timeout_sock_read,
}
self.timeouts = aiohttp.ClientTimeout(**timeouts)
if self.http_proxy is None:
self.http_proxy = os.environ.get("HTTP_PROXY", None)
if self.https_proxy is None:
self.https_proxy = os.environ.get("HTTPS_PROXY", None)

if self.use_aiofiles is not None:
self.use_aiofiles = self._compute_aiofiles(self.use_aiofiles)
Expand All @@ -149,7 +163,8 @@ def __init__(self, **data):
self.log_level = "DEBUG"


class DownloaderConfig(BaseSettings):
@dataclass
class DownloaderConfig:
"""
Hold all downloader session state.
"""
Expand All @@ -158,29 +173,36 @@ class DownloaderConfig(BaseSettings):
max_splits: int = 5
progress: bool = True
overwrite: Union[bool, Literal["unique"]] = False
headers: Optional[Dict[str, str]] = Field(default_factory=_default_headers)
config: Optional[SessionConfig] = Field(default_factory=SessionConfig)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if self.config is None:
self.config = SessionConfig()

# headers is deprecated here.
# The arguments passed to SessionConfig take precedence.
# To make this priority work, the defaults on SessionConfig
# are that these two arguments default to None.
# When it is removed after the deprecation period, the defaults here
# should be moved to SessionConifg
headers: Optional[Dict[str, str]] = field(default_factory=_default_headers)
config: Optional[SessionConfig] = field(default_factory=SessionConfig)
env: EnvConfig = field(default_factory=EnvConfig)

def __post_init__(self):
self.max_conn = 1 if self.env.serial_mode else self.max_conn
self.max_splits = 1 if self.env.serial_mode or self.env.disable_range else self.max_splits
self.progress = False if self.env.hide_progress else self.progress

# Default headers if None
if self.headers is None:
self.headers = self.__dataclass_fields__["headers"].default_factory()

if self.progress is False:
self.config.file_progress = False
self.file_progress = False

if self.config is None:
self.config = SessionConfig()

# Remove this after deprecation period
if self.config.headers is not None:
self.headers = self.config.headers
if self.headers is None:
self.headers = _default_headers()

def __getattr__(self, __name: str) -> Any:
def __getattr__(self, __name: str):
return getattr(self.config, __name)

@property
Expand Down
2 changes: 1 addition & 1 deletion parfive/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
progress: bool = True,
overwrite: Union[bool, Literal["unique"]] = False,
headers: Optional[Dict[str, str]] = None,
config: Optional[SessionConfig] = None,
config: SessionConfig = None,
):

msg = (
Expand Down
2 changes: 1 addition & 1 deletion parfive/tests/test_aiofiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_enable_aiofiles_constructor(use_aiofiles):
), f"expected={use_aiofiles}, got={dl.config.use_aiofiles}"


@patch.dict(os.environ, {"PARFIVE_OVERWRITE_ENABLE_AIOFILES": "True"})
@patch.dict(os.environ, {"PARFIVE_OVERWRITE_ENABLE_AIOFILES": "some_value_to_enable_it"})
@pytest.mark.parametrize("use_aiofiles", [True, False])
def test_enable_aiofiles_env_overwrite_always_enabled(use_aiofiles):
dl = Downloader(config=parfive.SessionConfig(use_aiofiles=use_aiofiles))
Expand Down
23 changes: 12 additions & 11 deletions parfive/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import ssl

import aiohttp
import pytest

Expand Down Expand Up @@ -33,15 +35,6 @@ def test_session_config_env_defaults():
assert c.env.timeout_sock_read == 90


def test_use_aiofiles():
c = DownloaderConfig()
assert c.use_aiofiles is False
c = DownloaderConfig(config=SessionConfig(use_aiofiles=True))
assert c.use_aiofiles is True
c = DownloaderConfig(config=SessionConfig(use_aiofiles=False))
assert c.use_aiofiles is False


def test_headers_deprecated():
c = DownloaderConfig()
assert isinstance(c.headers, dict)
Expand Down Expand Up @@ -74,5 +67,13 @@ def test_headers_deprecated():

def test_deprecated_downloader_arguments():
with pytest.warns(ParfiveFutureWarning, match="headers keyword"):
d = Downloader(headers={"ni": "shrubbery"})
assert d.config.headers == {"ni": "shrubbery"}
d = Downloader(headers="ni")
assert d.config.headers == "ni"


def test_ssl_context():
# Assert that the unpickalable SSL context object doesn't anger the
# dataclass gods
ssl_ctx = ssl.create_default_context()
c = SessionConfig(aiohttp_session_kwargs={"context": ssl_ctx})
d = Downloader(config=c)
2 changes: 1 addition & 1 deletion parfive/tests/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def test_proxy_passed_as_kwargs_to_get(tmpdir, url, proxy):
("GET", url),
{
"allow_redirects": True,
"timeout": ClientTimeout(total=0.0, connect=None, sock_read=90.0, sock_connect=None),
"timeout": ClientTimeout(total=0, connect=None, sock_read=90, sock_connect=None),
"proxy": proxy,
},
]
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ python_requires = >=3.7
install_requires =
tqdm >= 4.27.0
aiohttp
pydantic
typing_extensions;python_version<'3.8'
setup_requires =
setuptools_scm
Expand Down

0 comments on commit d3ee1b8

Please sign in to comment.