Skip to content

Commit

Permalink
perf: make contracts load faster (#2371)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Nov 4, 2024
1 parent f833e46 commit b71c810
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 38 deletions.
6 changes: 5 additions & 1 deletion src/ape/contracts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from .base import ContractContainer, ContractEvent, ContractInstance, ContractLog, ContractNamespace
def __getattr__(name: str):
import ape.contracts.base as module

return getattr(module, name)


__all__ = [
"ContractContainer",
Expand Down
73 changes: 38 additions & 35 deletions src/ape/contracts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,10 @@
from typing import TYPE_CHECKING, Any, Optional, Union

import click
import pandas as pd
from eth_pydantic_types import HexBytes
from eth_utils import to_hex
from ethpm_types.abi import EventABI, MethodABI
from ethpm_types.contract_type import ABI_W_SELECTOR_T, ContractType
from IPython.lib.pretty import for_type
from ethpm_types.abi import EventABI

from ape.api.accounts import AccountAPI
from ape.api.address import Address, BaseAddress
from ape.api.query import (
ContractCreation,
Expand All @@ -34,7 +30,6 @@
MissingDeploymentBytecodeError,
)
from ape.logging import get_rich_console, logger
from ape.types.address import AddressType
from ape.types.events import ContractLog, LogFilter, MockContractLog
from ape.utils.abi import StructParser, _enrich_natspec
from ape.utils.basemodel import (
Expand All @@ -49,9 +44,12 @@
from ape.utils.misc import log_instead_of_fail

if TYPE_CHECKING:
from ethpm_types.abi import ConstructorABI, ErrorABI
from ethpm_types.abi import ConstructorABI, ErrorABI, MethodABI
from ethpm_types.contract_type import ABI_W_SELECTOR_T, ContractType
from pandas import DataFrame

from ape.api.transactions import ReceiptAPI, TransactionAPI
from ape.types.address import AddressType


class ContractConstructor(ManagerAccessMixin):
Expand Down Expand Up @@ -90,7 +88,7 @@ def serialize_transaction(self, *args, **kwargs) -> "TransactionAPI":
def __call__(self, private: bool = False, *args, **kwargs) -> "ReceiptAPI":
txn = self.serialize_transaction(*args, **kwargs)

if "sender" in kwargs and isinstance(kwargs["sender"], AccountAPI):
if "sender" in kwargs and hasattr(kwargs["sender"], "call"):
sender = kwargs["sender"]
return sender.call(txn, **kwargs)
elif "sender" not in kwargs and self.account_manager.default_sender is not None:
Expand All @@ -104,7 +102,7 @@ def __call__(self, private: bool = False, *args, **kwargs) -> "ReceiptAPI":


class ContractCall(ManagerAccessMixin):
def __init__(self, abi: MethodABI, address: AddressType) -> None:
def __init__(self, abi: "MethodABI", address: "AddressType") -> None:
super().__init__()
self.abi = abi
self.address = address
Expand Down Expand Up @@ -140,9 +138,9 @@ def __call__(self, *args, **kwargs) -> Any:

class ContractMethodHandler(ManagerAccessMixin):
contract: "ContractInstance"
abis: list[MethodABI]
abis: list["MethodABI"]

def __init__(self, contract: "ContractInstance", abis: list[MethodABI]) -> None:
def __init__(self, contract: "ContractInstance", abis: list["MethodABI"]) -> None:
super().__init__()
self.contract = contract
self.abis = abis
Expand Down Expand Up @@ -320,7 +318,7 @@ def estimate_gas_cost(self, *args, **kwargs) -> int:
return self.transact.estimate_gas_cost(*arguments, **kwargs)


def _select_method_abi(abis: list[MethodABI], args: Union[tuple, list]) -> MethodABI:
def _select_method_abi(abis: list["MethodABI"], args: Union[tuple, list]) -> "MethodABI":
args = args or []
selected_abi = None
for abi in abis:
Expand All @@ -335,13 +333,10 @@ def _select_method_abi(abis: list[MethodABI], args: Union[tuple, list]) -> Metho


class ContractTransaction(ManagerAccessMixin):
abi: MethodABI
address: AddressType

def __init__(self, abi: MethodABI, address: AddressType) -> None:
def __init__(self, abi: "MethodABI", address: "AddressType") -> None:
super().__init__()
self.abi = abi
self.address = address
self.abi: "MethodABI" = abi
self.address: "AddressType" = address

@log_instead_of_fail(default="<ContractTransaction>")
def __repr__(self) -> str:
Expand All @@ -362,7 +357,7 @@ def __call__(self, *args, **kwargs) -> "ReceiptAPI":
txn = self.serialize_transaction(*args, **kwargs)
private = kwargs.get("private", False)

if "sender" in kwargs and isinstance(kwargs["sender"], AccountAPI):
if "sender" in kwargs and hasattr(kwargs["sender"], "call"):
return kwargs["sender"].call(txn, **kwargs)

txn = self.provider.prepare_transaction(txn)
Expand Down Expand Up @@ -441,6 +436,7 @@ def _as_transaction(self, *args) -> ContractTransaction:
)


# TODO: In Ape 0.9 - make not a BaseModel - no reason to.
class ContractEvent(BaseInterfaceModel):
"""
The types of events on a :class:`~ape.contracts.base.ContractInstance`.
Expand Down Expand Up @@ -616,7 +612,7 @@ def query(
stop_block: Optional[int] = None,
step: int = 1,
engine_to_use: Optional[str] = None,
) -> pd.DataFrame:
) -> "DataFrame":
"""
Iterate through blocks for log events
Expand All @@ -635,6 +631,8 @@ def query(
Returns:
pd.DataFrame
"""
# perf: pandas import is really slow. Avoid importing at module level.
import pandas as pd

if start_block < 0:
start_block = self.chain_manager.blocks.height + start_block
Expand Down Expand Up @@ -800,7 +798,7 @@ def poll_logs(


class ContractTypeWrapper(ManagerAccessMixin):
contract_type: ContractType
contract_type: "ContractType"
base_path: Optional[Path] = None

@property
Expand All @@ -812,7 +810,7 @@ def selector_identifiers(self) -> dict[str, str]:
return self.contract_type.selector_identifiers

@property
def identifier_lookup(self) -> dict[str, ABI_W_SELECTOR_T]:
def identifier_lookup(self) -> dict[str, "ABI_W_SELECTOR_T"]:
"""
Provides a mapping of method, error, and event selector identifiers to
ABI Types.
Expand Down Expand Up @@ -898,6 +896,9 @@ def repr_pretty_for_assignment(cls, *args, **kwargs):
info = _get_info()
error_type.info = error_type.__doc__ = info # type: ignore
if info:
# perf: Avoid forcing everyone to import from IPython.
from IPython.lib.pretty import for_type

error_type._repr_pretty_ = repr_pretty_for_assignment # type: ignore

# Register the dynamically-created type with IPython so it integrates.
Expand All @@ -922,8 +923,8 @@ class ContractInstance(BaseAddress, ContractTypeWrapper):

def __init__(
self,
address: AddressType,
contract_type: ContractType,
address: "AddressType",
contract_type: "ContractType",
txn_hash: Optional[Union[str, HexBytes]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -957,7 +958,9 @@ def __call__(self, *args, **kwargs) -> "ReceiptAPI":
return super().__call__(*args, **kwargs)

@classmethod
def from_receipt(cls, receipt: "ReceiptAPI", contract_type: ContractType) -> "ContractInstance":
def from_receipt(
cls, receipt: "ReceiptAPI", contract_type: "ContractType"
) -> "ContractInstance":
"""
Create a contract instance from the contract deployment receipt.
"""
Expand Down Expand Up @@ -997,7 +1000,7 @@ def __repr__(self) -> str:
return f"<{contract_name} {self.address}>"

@property
def address(self) -> AddressType:
def address(self) -> "AddressType":
"""
The address of the contract.
Expand All @@ -1009,7 +1012,7 @@ def address(self) -> AddressType:

@cached_property
def _view_methods_(self) -> dict[str, ContractCallHandler]:
view_methods: dict[str, list[MethodABI]] = dict()
view_methods: dict[str, list["MethodABI"]] = dict()

for abi in self.contract_type.view_methods:
if abi.name in view_methods:
Expand All @@ -1028,7 +1031,7 @@ def _view_methods_(self) -> dict[str, ContractCallHandler]:

@cached_property
def _mutable_methods_(self) -> dict[str, ContractTransactionHandler]:
mutable_methods: dict[str, list[MethodABI]] = dict()
mutable_methods: dict[str, list["MethodABI"]] = dict()

for abi in self.contract_type.mutable_methods:
if abi.name in mutable_methods:
Expand Down Expand Up @@ -1075,7 +1078,7 @@ def call_view_method(self, method_name: str, *args, **kwargs) -> Any:

else:
# Didn't find anything that matches
name = self.contract_type.name or ContractType.__name__
name = self.contract_type.name or "ContractType"
raise ApeAttributeError(f"'{name}' has no attribute '{method_name}'.")

def invoke_transaction(self, method_name: str, *args, **kwargs) -> "ReceiptAPI":
Expand Down Expand Up @@ -1110,7 +1113,7 @@ def invoke_transaction(self, method_name: str, *args, **kwargs) -> "ReceiptAPI":

else:
# Didn't find anything that matches
name = self.contract_type.name or ContractType.__name__
name = self.contract_type.name or "ContractType"
raise ApeAttributeError(f"'{name}' has no attribute '{method_name}'.")

def get_event_by_signature(self, signature: str) -> ContractEvent:
Expand Down Expand Up @@ -1168,7 +1171,7 @@ def get_error_by_signature(self, signature: str) -> type[CustomError]:

@cached_property
def _events_(self) -> dict[str, list[ContractEvent]]:
events: dict[str, list[EventABI]] = {}
events: dict[str, list["EventABI"]] = {}

for abi in self.contract_type.events:
if abi.name in events:
Expand Down Expand Up @@ -1339,7 +1342,7 @@ class ContractContainer(ContractTypeWrapper, ExtraAttributesMixin):
contract_container = project.MyContract # Assuming there is a contract named "MyContract"
"""

def __init__(self, contract_type: ContractType) -> None:
def __init__(self, contract_type: "ContractType") -> None:
self.contract_type = contract_type

@log_instead_of_fail(default="<ContractContainer>")
Expand Down Expand Up @@ -1404,7 +1407,7 @@ def deployments(self):
return self.chain_manager.contracts.get_deployments(self)

def at(
self, address: AddressType, txn_hash: Optional[Union[str, HexBytes]] = None
self, address: "AddressType", txn_hash: Optional[Union[str, HexBytes]] = None
) -> ContractInstance:
"""
Get a contract at the given address.
Expand Down Expand Up @@ -1473,7 +1476,7 @@ def deploy(self, *args, publish: bool = False, **kwargs) -> ContractInstance:
if kwargs.get("value") and not self.contract_type.constructor.is_payable:
raise MethodNonPayableError("Sending funds to a non-payable constructor.")

if "sender" in kwargs and isinstance(kwargs["sender"], AccountAPI):
if "sender" in kwargs and hasattr(kwargs["sender"], "call"):
# Handle account-related preparation if needed, such as signing
receipt = self._cache_wrap(lambda: kwargs["sender"].call(txn, **kwargs))

Expand Down Expand Up @@ -1533,7 +1536,7 @@ def declare(self, *args, **kwargs) -> "ReceiptAPI":
transaction = self.provider.network.ecosystem.encode_contract_blueprint(
self.contract_type, *args, **kwargs
)
if "sender" in kwargs and isinstance(kwargs["sender"], AccountAPI):
if "sender" in kwargs and hasattr(kwargs["sender"], "call"):
return kwargs["sender"].call(transaction)

receipt = self.provider.send_transaction(transaction)
Expand Down
2 changes: 1 addition & 1 deletion src/ape_node/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def disconnect(self):

def _clean(self):
if self._data_dir.is_dir():
shutil.rmtree(self._data_dir)
shutil.rmtree(self._data_dir, ignore_errors=True)

# dir must exist when initializing chain.
self._data_dir.mkdir(parents=True, exist_ok=True)
Expand Down
6 changes: 5 additions & 1 deletion tests/functional/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,12 @@ class TestProject:
def test_init(self, with_dependencies_project_path):
# Purpose not using `project_with_contracts` fixture.
project = Project(with_dependencies_project_path)
project.manifest_path.unlink(missing_ok=True)
assert project.path == with_dependencies_project_path
project.manifest_path.unlink(missing_ok=True)

# Re-init to show it doesn't create the manifest file.
project = Project(with_dependencies_project_path)

# Manifest should have been created by default.
assert not project.manifest_path.is_file()

Expand Down

0 comments on commit b71c810

Please sign in to comment.