Skip to content

Commit

Permalink
Fix CLI args type conversion
Browse files Browse the repository at this point in the history
Signed-off-by: Bernát Gábor <[email protected]>
  • Loading branch information
gaborbernat committed Oct 20, 2024
1 parent 301c163 commit 37d9e52
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 12 deletions.
28 changes: 20 additions & 8 deletions src/toml_fmt_common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
import os
import sys
from abc import ABC, abstractmethod
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError, Namespace
from argparse import (
ArgumentDefaultsHelpFormatter,
ArgumentParser,
ArgumentTypeError,
Namespace,
_ArgumentGroup, # noqa: PLC2701
)
from collections import deque
from copy import deepcopy
from dataclasses import dataclass
Expand All @@ -16,7 +22,7 @@
from typing import TYPE_CHECKING, Any, Generic, TypeVar

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence

if sys.version_info >= (3, 11): # pragma: >=3.11 cover
import tomllib
Expand Down Expand Up @@ -63,7 +69,7 @@ def filename(self) -> str:
raise NotImplementedError

@abstractmethod
def add_format_flags(self, parser: ArgumentParser) -> None:
def add_format_flags(self, parser: _ArgumentGroup) -> None:
"""
Add any additional flags to configure the formatter.
Expand Down Expand Up @@ -126,7 +132,7 @@ def _cli_args(info: TOMLFormatter[T], args: Sequence[str]) -> list[_Config[T]]:
:param args: CLI arguments
:return: the parsed options
"""
parser = _build_cli(info)
parser, type_conversion = _build_cli(info)
parser.parse_args(namespace=info.opt, args=args)
res = []
for pyproject_toml in info.opt.inputs:
Expand All @@ -144,7 +150,9 @@ def _cli_args(info: TOMLFormatter[T], args: Sequence[str]) -> list[_Config[T]]:
if isinstance(config, dict):
for key in set(vars(override_opt).keys()) - {"inputs", "stdout", "check", "no_print_diff"}:
if key in config:
setattr(override_opt, key, config[key])
raw = config[key]
converted = type_conversion[key](raw)
setattr(override_opt, key, converted)
res.append(
_Config(
toml_filename=pyproject_toml,
Expand All @@ -159,7 +167,7 @@ def _cli_args(info: TOMLFormatter[T], args: Sequence[str]) -> list[_Config[T]]:
return res


def _build_cli(of: TOMLFormatter[T]) -> ArgumentParser:
def _build_cli(of: TOMLFormatter[T]) -> tuple[ArgumentParser, Mapping[str, Callable[[Any], Any]]]:
parser = ArgumentParser(
formatter_class=ArgumentDefaultsHelpFormatter,
prog=of.prog,
Expand Down Expand Up @@ -200,15 +208,16 @@ def _build_cli(of: TOMLFormatter[T]) -> ArgumentParser:
help="number of spaces to use for indentation",
metavar="count",
)
of.add_format_flags(format_group) # type: ignore[arg-type]
of.add_format_flags(format_group)
type_conversion = {a.dest: a.type for a in format_group._actions if a.type and a.dest} # noqa: SLF001
msg = "pyproject.toml file(s) to format, use '-' to read from stdin"
parser.add_argument(
"inputs",
nargs="+",
type=partial(_toml_path_creator, of.filename),
help=msg,
)
return parser
return parser, type_conversion


def _toml_path_creator(filename: str, argument: str) -> Path | None:
Expand Down Expand Up @@ -288,7 +297,10 @@ def _color_diff(diff: Iterable[str]) -> Iterable[str]:
yield line


ArgumentGroup = _ArgumentGroup

__all__ = [
"ArgumentGroup",
"FmtNamespace",
"TOMLFormatter",
"run",
Expand Down
40 changes: 36 additions & 4 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@

import pytest

from toml_fmt_common import GREEN, RED, RESET, FmtNamespace, TOMLFormatter, run
from toml_fmt_common import GREEN, RED, RESET, ArgumentGroup, FmtNamespace, TOMLFormatter, run

if TYPE_CHECKING:
from argparse import ArgumentParser
from pathlib import Path

from pytest_mock import MockerFixture


class DumpNamespace(FmtNamespace):
extra: str
tuple_magic: tuple[str, ...]


class Dumb(TOMLFormatter[DumpNamespace]):
Expand All @@ -35,11 +35,18 @@ def filename(self) -> str:
def override_cli_from_section(self) -> tuple[str, ...]:
return "start", "sub"

def add_format_flags(self, parser: ArgumentParser) -> None: # noqa: PLR6301
def add_format_flags(self, parser: ArgumentGroup) -> None: # noqa: PLR6301
parser.add_argument("extra", help="this is something extra")
parser.add_argument("-t", "--tuple-magic", default=(), type=lambda t: tuple(t.split(".")))

def format(self, text: str, opt: DumpNamespace) -> str: # noqa: PLR6301
return text if os.environ.get("NO_FMT") else f"{text}\nextras = {opt.extra!r}"
if os.environ.get("NO_FMT"):
return text
return "\n".join([
text,
f"extras = {opt.extra!r}",
*([f"magic = {','.join(opt.tuple_magic)!r}"] if opt.tuple_magic else []),
])


def test_dumb_help(capsys: pytest.CaptureFixture[str]) -> None:
Expand Down Expand Up @@ -77,6 +84,31 @@ def test_dumb_format_with_override(capsys: pytest.CaptureFixture[str], tmp_path:
]


def test_dumb_format_with_override_custom_type(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
dumb = tmp_path / "dumb.toml"
dumb.write_text("[start.sub]\ntuple_magic = '1.2.3'")

exit_code = run(Dumb(), ["E", str(dumb)])
assert exit_code == 1

assert dumb.read_text() == "[start.sub]\ntuple_magic = '1.2.3'\nextras = 'E'\nmagic = '1,2,3'"

out, err = capsys.readouterr()
assert not err
assert out.splitlines() == [
f"{RED}--- {dumb}",
f"{RESET}",
f"{GREEN}+++ {dumb}",
f"{RESET}",
"@@ -1,2 +1,4 @@",
"",
" [start.sub]",
" tuple_magic = '1.2.3'",
f"{GREEN}+extras = 'E'{RESET}",
f"{GREEN}+magic = '1,2,3'{RESET}",
]


def test_dumb_format_no_print_diff(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
dumb = tmp_path / "dumb.toml"
dumb.write_text("[start.sub]\nextra = 'B'")
Expand Down

0 comments on commit 37d9e52

Please sign in to comment.