Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configurable ABI validation for contract initialization #3538

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/core/contracts/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ def invoke_contract(
function_signature,
*func_args,
abi_codec=contract.w3.codec,
abi_validation=contract.w3.abi_validation,
**func_kwargs,
)
function = contract.functions[abi_to_signature(fn_abi)]
Expand Down Expand Up @@ -810,6 +811,7 @@ async def async_invoke_contract(
function_signature,
*func_args,
abi_codec=contract.w3.codec,
abi_validation=contract.w3.abi_validation,
**func_kwargs,
)
function = contract.functions[abi_to_signature(fn_abi)]
Expand Down
28 changes: 28 additions & 0 deletions tests/core/utilities/test_abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,21 @@ def test_get_abi_element_info_raises_mismatched_abi(contract_abi: ABI) -> None:
get_abi_element_info(contract_abi, "foo", *args, **{})


def test_get_abi_element_info_configurable_abi_validation() -> None:
assert (
get_abi_element_info(CONTRACT_ABI, "myFunction")["abi"]
== FUNCTION_ABI_NO_INPUTS
)
assert (
get_abi_element_info(CONTRACT_ABI, "myFunction", abi_validation=True)["abi"]
== FUNCTION_ABI_NO_INPUTS
)
assert (
get_abi_element_info(CONTRACT_ABI, "myFunction", abi_validation=False)["abi"]
== FUNCTION_ABI_NO_INPUTS
)


@pytest.mark.parametrize(
"abi,abi_element_identifier,args,kwargs,expected_abi",
(
Expand Down Expand Up @@ -711,6 +726,19 @@ def test_get_abi_element_raises_with_invalid_parameters(
get_abi_element(abi, abi_element_identifier, *args, **kwargs)


def test_get_abi_element_configurable_abi_validation() -> None:
assert get_abi_element(CONTRACT_ABI, "logTwoEvents", *[1]) == LOG_TWO_EVENTS_ABI
assert (
get_abi_element(CONTRACT_ABI, "logTwoEvents", *[1], abi_validation=True)
== LOG_TWO_EVENTS_ABI
)

assert (
get_abi_element(CONTRACT_ABI, "logTwoEvents", *[1], abi_validation=False)
== LOG_TWO_EVENTS_ABI
)


def test_get_abi_element_codec_override(contract_abi: ABI) -> None:
codec = ABICodec(default_registry)
args: Sequence[Any] = [1]
Expand Down
2 changes: 2 additions & 0 deletions web3/_utils/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def prepare_transaction(
abi_element_identifier,
*fn_args,
abi_codec=w3.codec,
abi_validation=w3.abi_validation,
**fn_kwargs,
),
)
Expand Down Expand Up @@ -255,6 +256,7 @@ def encode_transaction_data(
abi_element_identifier,
*args,
abi_codec=w3.codec,
abi_validation=w3.abi_validation,
**kwargs,
)
info_abi = fn_info["abi"]
Expand Down
9 changes: 8 additions & 1 deletion web3/contract/async_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "AsyncContractEvent":
self.name,
*args,
abi_codec=self.w3.codec,
abi_validation=self.w3.abi_validation,
**kwargs,
)
argument_types = get_abi_input_types(event_abi)
Expand Down Expand Up @@ -311,7 +312,12 @@ def __getattr__(self, event_name: str) -> "AsyncContractEvent":
"Are you sure you provided the correct contract abi?",
)
else:
event_abi = get_abi_element(self._events, event_name)
event_abi = get_abi_element(
self._events,
event_name,
abi_codec=self.w3.codec,
abi_validation=self.w3.abi_validation,
)
argument_types = get_abi_input_types(event_abi)
event_signature = str(get_abi_element_signature(event_name, argument_types))
return super().__getattribute__(event_signature)
Expand Down Expand Up @@ -339,6 +345,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "AsyncContractFunction":
element_name,
*args,
abi_codec=self.w3.codec,
abi_validation=self.w3.abi_validation,
**kwargs,
)

Expand Down
9 changes: 9 additions & 0 deletions web3/contract/base_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def _get_event_abi(cls) -> ABIEvent:
filter_abi_by_type("event", cls.contract_abi),
cls.abi_element_identifier,
abi_codec=cls.w3.codec,
abi_validation=cls.w3.abi_validation,
),
)

Expand Down Expand Up @@ -517,6 +518,8 @@ def __init__(self, abi: Optional[ABIFunction] = None) -> None:
self.contract_abi,
),
self.abi_element_identifier,
abi_codec=self.w3.codec,
abi_validation=self.w3.abi_validation,
),
)
self.name = abi_to_signature(self.abi)
Expand All @@ -531,6 +534,7 @@ def _get_abi(cls) -> ABIFunction:
cls.contract_abi,
get_abi_element_signature(cls.abi_element_identifier),
abi_codec=cls.w3.codec,
abi_validation=cls.w3.abi_validation,
),
)

Expand All @@ -541,6 +545,7 @@ def _get_abi(cls) -> ABIFunction:
get_name_from_abi_element_identifier(cls.abi_element_identifier),
*cls.args,
abi_codec=cls.w3.codec,
abi_validation=cls.w3.abi_validation,
**cls.kwargs,
),
)
Expand Down Expand Up @@ -819,6 +824,7 @@ def encode_abi(
abi_element_identifier,
*args,
abi_codec=cls.w3.codec,
abi_validation=cls.w3.abi_validation,
**kwargs,
)

Expand Down Expand Up @@ -1166,6 +1172,7 @@ def _find_matching_fn_abi(
fn_identifier,
*args,
abi_codec=cls.w3.codec,
abi_validation=cls.w3.abi_validation,
**kwargs,
)

Expand All @@ -1181,6 +1188,8 @@ def _get_event_abi(
abi=cls.abi,
abi_element_identifier=event_name,
argument_names=argument_names,
abi_codec=cls.w3.codec,
abi_validation=cls.w3.abi_validation,
),
)

Expand Down
9 changes: 8 additions & 1 deletion web3/contract/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "ContractEvent":
self.name,
*args,
abi_codec=self.w3.codec,
abi_validation=self.w3.abi_validation,
**kwargs,
)
argument_types = get_abi_input_types(event_abi)
Expand Down Expand Up @@ -304,7 +305,12 @@ def __getattr__(self, event_name: str) -> "ContractEvent":
"Are you sure you provided the correct contract abi?",
)
else:
event_abi = get_abi_element(self._events, event_name)
event_abi = get_abi_element(
self._events,
event_name,
abi_codec=self.w3.codec,
abi_validation=self.w3.abi_validation,
)
argument_types = get_abi_input_types(event_abi)
event_signature = str(get_abi_element_signature(event_name, argument_types))
return super().__getattribute__(event_signature)
Expand Down Expand Up @@ -336,6 +342,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "ContractFunction":
element_name,
*args,
abi_codec=self.w3.codec,
abi_validation=self.w3.abi_validation,
**kwargs,
)

Expand Down
2 changes: 2 additions & 0 deletions web3/contract/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def call_contract_function(
abi_element_identifier,
*args,
abi_codec=w3.codec,
abi_validation=w3.abi_validation,
**kwargs,
),
)
Expand Down Expand Up @@ -462,6 +463,7 @@ async def async_call_contract_function(
abi_element_identifier,
*args,
abi_codec=async_w3.codec,
abi_validation=async_w3.abi_validation,
**kwargs,
),
)
Expand Down
4 changes: 4 additions & 0 deletions web3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,11 +394,13 @@ def __init__(
Dict[str, Union[Type[Module], Sequence[Any]]]
] = None,
ens: Union[ENS, "Empty"] = empty,
abi_validation: Optional[bool] = True,
) -> None:
_validate_provider(self, provider)

self.manager = self.RequestManager(self, provider, middleware)
self.codec = ABICodec(build_strict_registry())
self.abi_validation = abi_validation

if modules is None:
modules = get_default_modules()
Expand Down Expand Up @@ -464,11 +466,13 @@ def __init__(
Dict[str, Union[Type[Module], Sequence[Any]]]
] = None,
ens: Union[AsyncENS, "Empty"] = empty,
abi_validation: Optional[bool] = True,
) -> None:
_validate_provider(self, provider)

self.manager = self.RequestManager(self, provider, middleware)
self.codec = ABICodec(build_strict_registry())
self.abi_validation = abi_validation

self._modules = get_async_default_modules() if modules is None else modules
self._external_modules = None if external_modules is None else external_modules
Expand Down
22 changes: 19 additions & 3 deletions web3/utils/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def get_abi_element_info(
abi_element_identifier: ABIElementIdentifier,
*args: Optional[Sequence[Any]],
abi_codec: Optional[Any] = None,
abi_validation: Optional[bool] = True,
**kwargs: Optional[Dict[str, Any]],
) -> ABIElementInfo:
"""
Expand All @@ -490,6 +491,10 @@ def get_abi_element_info(
:param abi_codec: Codec used for encoding and decoding. Default with \
`strict_bytes_type_checking` enabled.
:type abi_codec: `Optional[Any]`
:param abi_validation: Enforce ABI validation of elements. Without validation, \
elements may contain invalid types or values and may be unusable. If multiple \
elements are found, only the first will be returned. Defaults to `True`.
:type abi_validation: `Optional[bool]`
:param kwargs: Find an element ABI with matching kwargs.
:type kwargs: `Optional[Dict[str, Any]]`
:return: Element information including the ABI, selector and args.
Expand Down Expand Up @@ -524,7 +529,12 @@ def get_abi_element_info(
(7, 3)
"""
fn_abi = get_abi_element(
abi, abi_element_identifier, *args, abi_codec=abi_codec, **kwargs
abi,
abi_element_identifier,
*args,
abi_codec=abi_codec,
abi_validation=abi_validation,
**kwargs,
)
fn_selector = encode_hex(function_abi_to_4byte_selector(fn_abi))
fn_inputs: Tuple[Any, ...] = tuple()
Expand All @@ -545,6 +555,7 @@ def get_abi_element(
abi_element_identifier: ABIElementIdentifier,
*args: Optional[Any],
abi_codec: Optional[Any] = None,
abi_validation: Optional[bool] = True,
**kwargs: Optional[Any],
) -> ABIElement:
"""
Expand Down Expand Up @@ -574,6 +585,10 @@ def get_abi_element(
:param abi_codec: Codec used for encoding and decoding. Default with \
`strict_bytes_type_checking` enabled.
:type abi_codec: `Optional[Any]`
:param abi_validation: Enforce ABI validation of elements. Without validation, \
elements may contain invalid types or values and may be unusable. If multiple \
elements are found, only the first will be returned. Defaults to `True`.
:type abi_validation: `Optional[bool]`
:param kwargs: Find an element ABI with matching kwargs.
:type kwargs: `Optional[Dict[str, Any]]`
:return: ABI element for the specific ABI element.
Expand Down Expand Up @@ -602,7 +617,8 @@ def get_abi_element(
'type': 'uint256'}], 'payable': False, 'stateMutability': 'nonpayable', \
'type': 'function'}
"""
validate_abi(abi)
if abi_validation:
validate_abi(abi)

if abi_codec is None:
abi_codec = ABICodec(default_registry)
Expand All @@ -620,7 +636,7 @@ def get_abi_element(
num_matches = len(abi_element_matches)

# Raise MismatchedABI when more than one found
if num_matches != 1:
if abi_validation and num_matches != 1:
error_diagnosis = _mismatched_abi_error_diagnosis(
abi_element_identifier,
abi,
Expand Down