diff --git a/example.py b/example.py
index 0e0ca725..4b9b9d84 100644
--- a/example.py
+++ b/example.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Annotated
 
 from ape import chain
@@ -14,6 +15,9 @@
 # We can add parameters, which are values in state that can be updated by external triggers
 app.add_parameter("bad_number", default=3)
 
+# Cannot call `app.state` outside of an app function handler
+# app.state.something  # NOTE: raises AttributeError
+
 # NOTE: Don't do any networking until after initializing app
 USDC = tokens["USDC"]
 YFI = tokens["YFI"]
@@ -21,9 +25,11 @@
 @app.on_startup()
 def app_startup(startup_state: StateSnapshot):
-    # NOTE: This is called just as the app is put into "run" state,
-    #       and handled by the first available worker
-    # raise Exception  # NOTE: Any exception raised on startup aborts immediately
+    # This is called just as the app is put into "run" state,
+    # and handled by the first available worker
+
+    # Any exception raised on startup aborts immediately:
+    # raise Exception  # NOTE: raises StartupFailure
 
     # This is a great place to set `app.state` values that aren't parameters
     # NOTE: Non-parameter state is `None` by default
@@ -36,7 +42,7 @@ def app_startup(startup_state: StateSnapshot):
 # Can handle some resource initialization for each worker, like LLMs or database connections
 class MyDB:
     def execute(self, query: str):
-        pass
+        pass  # Handle query somehow...
 
 
 @app.on_worker_startup()
@@ -45,9 +51,11 @@ def worker_startup(worker_state: TaskiqState):  # NOTE: You need the type hint t
     # NOTE: Can put anything here, any python object works
     worker_state.db = MyDB()
     worker_state.block_count = 0
-    # raise Exception  # NOTE: Any exception raised on worker startup aborts immediately
-    # Cannot call `app.state` because it is not set up yet
+    # Any exception raised on worker startup aborts immediately:
+    # raise Exception  # NOTE: raises StartupFailure
+
+    # Cannot call `app.state` because it is not set up yet when worker startup functions run
     # app.state.something  # NOTE: raises AttributeError
 
@@ -64,7 +72,7 @@ def exec_block(block: BlockAPI, context: Annotated[Context, TaskiqDepends()]):
 # Set new_block_timeout to adjust the expected block time.
 @app.on_(USDC.Transfer, start_block=19784367, new_block_timeout=25)
 # NOTE: Typing isn't required, it will still be an Ape `ContractLog` type
-def exec_event1(log):
+async def exec_event1(log):
     if log.log_index % 7 == app.state.bad_number:
         # If you raise any exception, Silverback will track the failure and keep running
         # NOTE: By default, if you have 3 tasks fail in a row, the app will shutdown itself
@@ -73,12 +81,14 @@ def exec_event1(log):
     # You can update state whenever you want
     app.state.logs_processed += 1
 
+    # Do any other long running tasks...
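+    # NOTE: Because this handler is now async, awaiting below yields the
+    #       worker's event loop to other tasks instead of blocking it; the
+    #       sleep is just a stand-in for real async work (an HTTP call, a
+    #       database write, etc.)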
+    await asyncio.sleep(5)
     return {"amount": log.amount}
 
 
 @app.on_(YFI.Approval)
 # Any handler function can be async too
-async def exec_event2(log: ContractLog):
+def exec_event2(log: ContractLog):
     if log.log_index % 7 == 6:
         # If you ever want the app to immediately shutdown under some scenario, raise this exception
         raise CircuitBreaker("Oopsie!")
@@ -91,12 +101,16 @@ async def exec_event2(log: ContractLog):
 # A final job to execute on Silverback shutdown
 @app.on_shutdown()
 def app_shutdown():
-    # raise Exception  # NOTE: Any exception raised on shutdown is ignored
+    # NOTE: Any exception raised on shutdown is ignored:
+    # raise Exception
     return {"some_metric": 123}
 
 
 # Just in case you need to release some resources or something inside each worker
 @app.on_worker_shutdown()
 def worker_shutdown(state: TaskiqState):  # NOTE: You need the type hint here
+    # This is a good time to release resources
     state.db = None
-    # raise Exception  # NOTE: Any exception raised on worker shutdown is ignored
+
+    # NOTE: Any exception raised on worker shutdown is ignored:
+    # raise Exception
diff --git a/silverback/_cli.py b/silverback/_cli.py
index f3a48f1d..89ccbd15 100644
--- a/silverback/_cli.py
+++ b/silverback/_cli.py
@@ -1,6 +1,8 @@
 import asyncio
 import os
 from concurrent.futures import ThreadPoolExecutor
+from decimal import Decimal
+from uuid import uuid4
 
 import click
 from ape.cli import (
@@ -12,11 +14,15 @@
 )
 from ape.exceptions import Abort
 from taskiq import AsyncBroker
+from taskiq.brokers.inmemory_broker import InMemoryBroker
 from taskiq.cli.worker.run import shutdown_broker
+from taskiq.kicker import AsyncKicker
 from taskiq.receiver import Receiver
 
 from silverback._importer import import_from_string
 from silverback.runner import PollingRunner
+from silverback.settings import Settings
+from silverback.types import ScalarType, TaskType, is_scalar_type
 
 
 @click.group()
@@ -130,3 +136,68 @@ def run(cli_ctx, account, runner, recorder, max_exceptions, path):
 def worker(cli_ctx, account, workers, max_exceptions, shutdown_timeout, path):
     app = import_from_string(path)
     asyncio.run(run_worker(app.broker, worker_count=workers, shutdown_timeout=shutdown_timeout))
+
+
+class ScalarParam(click.ParamType):
+    name = "scalar"
+
+    def convert(self, val, param, ctx) -> ScalarType:
+        if not isinstance(val, str) or is_scalar_type(val):
+            return val
+
+        elif val.lower() in ("f", "false"):
+            return False
+
+        elif val.lower() in ("t", "true"):
+            return True
+
+        try:
+            return int(val)
+        except Exception:
+            pass
+
+        try:
+            return float(val)
+        except Exception:
+            pass
+
+        # NOTE: Decimal allows the most values, so leave last
+        return Decimal(val)
+
+
+@cli.command(cls=ConnectedProviderCommand, help="Set parameters against a running silverback app")
+@network_option(
+    default=os.environ.get("SILVERBACK_NETWORK_CHOICE", "auto"),
+    callback=_network_callback,
+)
+@click.option(
+    "-p",
+    "--param",
+    "param_updates",
+    type=(str, ScalarParam()),
+    multiple=True,
+)
+def set_param(param_updates):
+
+    if len(param_updates) > 1:
+        task_name = str(TaskType._SET_PARAM_BATCH)
+        arg = dict(param_updates)
+    else:
+        param_name, arg = param_updates[0]
+        task_name = f"{TaskType._SET_PARAM}:{param_name}"
+
+    async def set_parameters():
+        broker = Settings().get_broker()
+        if isinstance(broker, InMemoryBroker):
+            raise RuntimeError("Cannot use with default in-memory broker")
+
+        kicker = AsyncKicker(task_name, broker, labels={})
+        task = await kicker.kiq(arg)
+        result = await task.wait_result()
+
+        if result.is_err:
+            click.echo(result.error)
+        else:
+            click.echo(result.return_value)
+
+    asyncio.run(set_parameters())
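+
+
+# Example invocation (hypothetical network and parameter values; with click's
+# default command naming this is exposed as `set-param`):
+#
+#     silverback set-param --network ethereum:local -p bad_number 5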
diff --git a/silverback/application.py b/silverback/application.py
index 4caaccb0..d807f919 100644
--- a/silverback/application.py
+++ b/silverback/application.py
@@ -23,6 +23,12 @@ class TaskData:
     handler: AsyncTaskiqDecoratedTask
 
 
+@dataclass
+class ParameterInfo:
+    default: ScalarType | None
+    update_handler: AsyncTaskiqDecoratedTask | None
+
+
 class SharedState(defaultdict):
     def __init__(self):
         # Any unknown key returns None
@@ -109,10 +115,31 @@ def __init__(self, settings: Settings | None = None):
 
         # NOTE: The runner needs to know the set of things that the app is tracking as a parameter
         # NOTE: We also need to know the defaults in case the parameters are not in the backup
-        self.parameter_defaults: dict[str, ScalarType | None] = dict()
-
+        self.__parameters: dict[str, ParameterInfo] = {
+            # System state parameters
+            "system:last_block_seen": ParameterInfo(
+                default=-1,
+                # NOTE: Don't allow external updates
+                update_handler=None,
+            ),
+            "system:last_block_processed": ParameterInfo(
+                default=-1,
+                # NOTE: Don't allow external updates
+                update_handler=None,
+            ),
+        }
+
+        # Register system tasks
         self._create_system_startup_task()
+        # TODO: Make backup optional and settings-driven
+        self.backup = AppDatastore(app_id=self.identifier)
         self._create_system_backup_task()
+        self._create_batch_parameter_task()
+
+    @property
+    def parameters(self) -> dict[str, ParameterInfo]:
+        # NOTE: makes this variable read-only
+        return self.__parameters
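+
+    # NOTE: Read-only view; the runner uses this to learn which values the
+    #       app tracks as parameters, and `update_handler=None` (as on the
+    #       `system:` keys above) means the value rejects external updates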
 
     def _create_system_startup_task(self):
         # Add a task to load all parameters from state at startup
@@ -124,22 +151,24 @@ async def startup_handler() -> StateSnapshot:
 
             # NOTE: attribute does not exist before this task is executed,
             #       ensuring no one uses it during worker startup
-            self.backup = AppDatastore()
-            if not (startup_state := await self.backup.init(app_id=self.identifier)):
-                return StateSnapshot(last_block_seen=-1, last_block_processed=-1)
+            if not (startup_state := await self.backup.load()):
+                logger.warning("No state snapshot detected, using empty snapshot")
+                startup_state = StateSnapshot()  # Use empty snapshot
 
-            logger.info("Finding cached parameters: [" + ", ".join(self.parameter_defaults) + "]")
+            for param_name, param_info in self.parameters.items():
 
-            for param_name, default in self.parameter_defaults.items():
-                if (cached_value := startup_state.parameter_values.get(param_name)) is not None:
-                    logger.debug(f"Found cached value for '{param_name}': {cached_value}")
+                if (cached_value := startup_state.parameters.get(param_name)) is not None:
+                    logger.info(f"Found cached value for app.state['{param_name}']: {cached_value}")
                     self.state[param_name] = cached_value
 
-                else:
-                    logger.debug(
-                        f"Cached value not found for '{param_name}', using default: {default}"
+                elif param_info.default is not None:
+                    logger.info(
+                        f"Cached value not found for app.state['{param_name}']"
+                        f", using default: {param_info.default}"
                     )
-                    self.state[param_name] = default
+                    self.state[param_name] = param_info.default
+
+                # NOTE: `None` default doesn't need to be set because that's how SharedState works
 
             return startup_state
 
@@ -150,18 +179,27 @@ async def startup_handler() -> StateSnapshot:
             task_type=str(TaskType._RESTORE),
         )
 
+    def _create_snapshot(self) -> StateSnapshot:
+        return StateSnapshot(
+            parameters={param_name: self.state[param_name] for param_name in self.parameters},
+        )
+
     def _create_system_backup_task(self):
         # TODO: Make backups optional
         # TODO: Allow configuring backup class
         # Add a task to backup state before/after every non-system runtime task and at shutdown
-        async def backup_handler(snapshot: StateSnapshot):
-            for param_name in self.parameter_defaults:
-                # Save our current parameter values, if set
-                if (current_value := self.state[param_name]) is not None:
-                    snapshot.parameter_values[param_name] = current_value
+        async def backup_handler(
+            last_block_seen: int | None = None,
+            last_block_processed: int | None = None,
+        ):
+            if last_block_seen is not None:
+                self.state["system:last_block_seen"] = last_block_seen
+
+            if last_block_processed is not None:
+                self.state["system:last_block_processed"] = last_block_processed
 
-            return await self.backup.save(snapshot)
+            return await self.backup.save(self._create_snapshot())
 
         self.backup_task = self.broker.register_task(
             backup_handler,
@@ -170,17 +208,47 @@ async def backup_handler(snapshot: StateSnapshot):
             task_type=str(TaskType._BACKUP),
         )
 
+    def _create_batch_parameter_task(self):
+        async def batch_parameters_handler(parameter_updates: dict):
+            # NOTE: This is one blocking task, so the whole batch of updates is applied atomically
+            datapoints = {}
+            for param_name, new_value in parameter_updates.items():
+                if "system:" in param_name:
+                    logger.error(f"Cannot update system parameter '{param_name}'")
+
+                elif param_name not in self.parameters:
+                    logger.error(f"Unrecognized parameter '{param_name}'")
+
+                else:
+                    datapoints[param_name] = ParamChangeDatapoint(
+                        old=self.state[param_name], new=new_value
+                    )
+                    logger.success(f"Update: app.state['{param_name}'] = {new_value}")
+                    self.state[param_name] = new_value
+
+            await self.backup.save(self._create_snapshot())
+            return datapoints
+
+        self.batch_parameters_task = self.broker.register_task(
+            batch_parameters_handler,
+            # NOTE: Name makes it impossible to conflict with user's handler fn names
+            task_name=str(TaskType._SET_PARAM_BATCH),
+            task_type=str(TaskType._SET_PARAM_BATCH),
+        )
+
     def add_parameter(self, param_name: str, default: ScalarType | None = None):
-        if param_name in self.parameter_defaults:
-            raise ValueError(f"{param_name} already added!")
+        if "system:" in param_name:
+            raise ValueError("Cannot override system parameters")
 
-        # Update this to track parameter existance/default value
-        self.parameter_defaults[param_name] = default
+        if param_name in self.parameters:
+            raise ValueError(f"{param_name} already added!")
 
         # This handler will handle parameter changes during runtime
         async def update_handler(new_value):
             datapoint = ParamChangeDatapoint(old=self.state[param_name], new=new_value)
+            logger.success(f"Update: app.state['{param_name}'] = {new_value}")
             self.state[param_name] = new_value
+            await self.backup.save(self._create_snapshot())
             return datapoint
 
         broker_task = self.broker.register_task(
@@ -190,8 +258,8 @@ async def update_handler(new_value):
             task_type=str(TaskType._SET_PARAM),
         )
 
-        self.tasks[TaskType._SET_PARAM].append(TaskData(container=None, handler=broker_task))
-        # TODO: Allow accepting parameter updates to .kiq this task somehow
+        # Update this to track parameter existence/default value/update handler
+        self.__parameters[param_name] = ParameterInfo(default=default, update_handler=broker_task)
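+        # NOTE: Unlike the old `self.tasks[TaskType._SET_PARAM]` bookkeeping,
+        #       the task registered above is addressable by name, so an
+        #       external process (e.g. the `set_param` CLI command) can kick it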
 
     def broker_task_decorator(
         self,
diff --git a/silverback/middlewares.py b/silverback/middlewares.py
index d81431df..a620293e 100644
--- a/silverback/middlewares.py
+++ b/silverback/middlewares.py
@@ -88,7 +88,7 @@ def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
             message.labels["transaction_hash"] = log.transaction_hash
             message.labels["log_index"] = str(log.log_index)
 
-        elif task_type in (TaskType.STARTUP, TaskType._BACKUP):
+        elif task_type is TaskType.STARTUP:
             message.args[0] = StateSnapshot.model_validate(message.args[0])
 
         # Record task start (appears on worker in distributed mode)
diff --git a/silverback/runner.py b/silverback/runner.py
index 941c8e61..431c8296 100644
--- a/silverback/runner.py
+++ b/silverback/runner.py
@@ -11,6 +11,7 @@
 from .application import SilverbackApp
 from .exceptions import Halt, NoTasksAvailableError, NoWebsocketAvailableError, StartupFailure
 from .recorder import BaseRecorder, TaskResult
+from .state import StateSnapshot
 from .subscriptions import SubscriptionType, Web3SubscriptionsManager
 from .types import TaskType
 from .utils import async_wrap_iter, hexbytes_dict
@@ -27,10 +28,8 @@ def __init__(
     ):
         self.app = app
         self.recorder = recorder
-        self.state = None
-
-        self.max_exceptions = max_exceptions
         self.exceptions = 0
+        self.max_exceptions = max_exceptions
 
         logger.info(f"Using {self.__class__.__name__}: max_exceptions={self.max_exceptions}")
 
@@ -56,25 +55,13 @@ async def _checkpoint(
         last_block_processed: int | None = None,
     ):
         """Set latest checkpoint block number"""
-        assert self.state, f"{self.__class__.__name__}.run() not triggered."
-
-        logger.debug(
-            (
-                f"Checkpoint block [seen={self.state.last_block_seen}, "
-                f"procssed={self.state.last_block_processed}]"
-            )
-        )
-
-        if last_block_seen:
-            self.state.last_block_seen = last_block_seen
-        if last_block_processed:
-            self.state.last_block_processed = last_block_processed
-
-        task = await self.app.backup_task.kiq(self.state)
+        task = await self.app.backup_task.kiq(last_block_seen, last_block_processed)
         result = await task.wait_result()
 
         if result.is_err:
             logger.error(f"Error setting state: {result.error}")
+
+        elif not result.return_value:
+            logger.error("State backup unsuccessful")
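+        # NOTE: `AppDatastore.save()` reports success as a bool, so a falsy
+        #       result is treated as a failed write even without an exception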
 
     @abstractmethod
     async def _block_task(self, block_handler: AsyncTaskiqDecoratedTask):
@@ -121,14 +108,15 @@ async def run(self):
         if result.is_err:
             raise StartupFailure(f"System startup failure: {result.error}")
 
-        self.state = result.return_value
+        if not (startup_state := result.return_value):
+            raise StartupFailure("System startup failed to return any state")
+
         # NOTE: State snapshot is immediately out of date after init
-        #       (except for our block seen/processed values, which are not updated in worker)
-        logger.debug(f"Startup state: {self.state}")
+        logger.debug(f"Startup state: {startup_state}")
 
         # Execute Silverback startup task(s) before entering into runtime mode
         if startup_tasks := await asyncio.gather(
-            *(task_def.handler.kiq(self.state) for task_def in self.app.tasks[TaskType.STARTUP])
+            *(task_def.handler.kiq(startup_state) for task_def in self.app.tasks[TaskType.STARTUP])
         ):
             results = await asyncio.gather(
                 *(startup_task.wait_result() for startup_task in startup_tasks)
@@ -189,7 +177,7 @@ async def run(self):
         # NOTE: No need to handle results otherwise
 
         # NOTE: Do one last backup
-        backup_task = await self.app.backup_task.kiq(self.state)
+        backup_task = await self.app.backup_task.kiq()
         result = await backup_task.wait_result()
 
         if result.is_err or not result.return_value:
diff --git a/silverback/state.py b/silverback/state.py
index 4ab7c20b..fcfcc114 100644
--- a/silverback/state.py
+++ b/silverback/state.py
@@ -1,23 +1,58 @@
 from pathlib import Path
+from typing import Any
 
-from pydantic import BaseModel, Field
+from ape.logging import get_logger
+from pydantic import BaseModel, Field, field_validator
 
-from .types import ScalarType, SilverbackID, UTCTimestamp, utc_now
+from .types import ScalarType, SilverbackID, UTCTimestamp, is_scalar_type, utc_now
 
+logger = get_logger(__name__)
 
-class StateSnapshot(BaseModel):
-    # Last block number seen by runner
-    last_block_seen: int
-
-    # Last block number processed by a worker
-    last_block_processed: int
 
+class StateSnapshot(BaseModel):
     # Last time the state was updated
     # NOTE: intended to use default when creating a model with this type
     last_updated: UTCTimestamp = Field(default_factory=utc_now)
 
     # Stored parameters from last session
-    parameter_values: dict[str, ScalarType | None] = {}
+    parameters: dict[str, ScalarType] = {}
+
+    @field_validator("parameters", mode="before")
+    def parse_parameters(cls, parameters: dict) -> dict:
+        # NOTE: Filter out any values that we cannot serialize
+        successfully_parsed_parameters = {}
+        for param_name, param_value in parameters.items():
+            if is_scalar_type(param_value):
+                successfully_parsed_parameters[param_name] = param_value
+            else:
+                logger.error(
+                    f"Cannot backup '{param_name}' of type '{type(param_value)}': {param_value}"
+                )
+
+        return successfully_parsed_parameters
+
+    @property
+    def last_block_seen(self) -> int:
+        # Last block number seen by runner
+        return self.parameters.get("system:last_block_seen", -1)  # type: ignore[return-value]
+
+    @property
+    def last_block_processed(self) -> int:
+        # Last block number processed by a worker
+        return self.parameters.get("system:last_block_processed", -1)  # type: ignore[return-value]
+
+    def __dir__(self) -> list[str]:
+        return [
+            *(param for param in self.parameters if "system:" not in param),
+            "last_block_processed",
+            "last_block_seen",
+        ]
+
+    def __getattr__(self, attr: str) -> Any:
+        try:
+            return super().__getattr__(attr)  # type: ignore[misc]
+        except AttributeError:
+            return self.parameters.get(attr)
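+
+    # NOTE: Unknown attribute lookups resolve to `None` via `dict.get`,
+    #       mirroring how `SharedState` returns `None` for unset keys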
 
 
 class AppDatastore:
@@ -41,14 +76,14 @@ class AppDatastore:
     - `SILVERBACK_APP_NAME`: Any alphabetical string valid as a folder name
     """
 
-    async def init(self, app_id: SilverbackID) -> StateSnapshot | None:
+    def __init__(self, app_id: SilverbackID):
         data_folder = (
             Path.cwd() / ".silverback-sessions" / app_id.name / app_id.ecosystem / app_id.network
         )
         data_folder.mkdir(parents=True, exist_ok=True)
-
         self.state_backup_file = data_folder / "state.json"
 
+    async def load(self) -> StateSnapshot | None:
         return (
             StateSnapshot.parse_file(self.state_backup_file)
             if self.state_backup_file.exists()
@@ -56,14 +91,5 @@ async def init(self, app_id: SilverbackID) -> StateSnapshot | None:
         )
 
     async def save(self, snapshot: StateSnapshot) -> bool:
-        if self.state_backup_file.exists():
-            old_snapshot = StateSnapshot.parse_file(self.state_backup_file)
-            if old_snapshot.last_block_seen > snapshot.last_block_seen:
-                snapshot.last_block_seen = old_snapshot.last_block_seen
-            if old_snapshot.last_block_processed > snapshot.last_block_processed:
-                snapshot.last_block_processed = old_snapshot.last_block_processed
-
-        snapshot.last_updated = utc_now()
         self.state_backup_file.write_text(snapshot.model_dump_json())
-
         return True  # Successful
diff --git a/silverback/types.py b/silverback/types.py
index 1b8fe333..8c756aad 100644
--- a/silverback/types.py
+++ b/silverback/types.py
@@ -1,10 +1,10 @@
 from datetime import datetime, timezone
 from decimal import Decimal
 from enum import Enum  # NOTE: `enum.StrEnum` only in Python 3.11+
-from typing import Literal, get_args
+from typing import Any, Literal, get_args
 
 from ape.logging import get_logger
-from pydantic import BaseModel, Field, RootModel, ValidationError, model_validator
+from pydantic import BaseModel, Field, RootModel, model_validator
 from pydantic.functional_serializers import PlainSerializer
 from typing_extensions import Annotated
 
@@ -16,6 +16,7 @@ class TaskType(str, Enum):
     _RESTORE = "system:load-snapshot"
     _BACKUP = "system:save-snapshot"
     _SET_PARAM = "system:set-param"
+    _SET_PARAM_BATCH = "system:batch-param"
 
     # User-accessible tasks
     STARTUP = "startup"
@@ -60,6 +61,13 @@ class _BaseDatapoint(BaseModel):
 #       This is okay, preferable actually, because it means we can store ints outside that range
 
 
+def is_scalar_type(val: Any) -> bool:
+    return any(
+        isinstance(val, d_type.__origin__ if hasattr(d_type, "__origin__") else d_type)
+        for d_type in get_args(ScalarType)
+    )
+
+
 class ScalarDatapoint(_BaseDatapoint):
     type: Literal["scalar"] = "scalar"
     data: ScalarType
@@ -78,28 +86,31 @@ class ParamChangeDatapoint(_BaseDatapoint):
 Datapoint = ScalarDatapoint | ParamChangeDatapoint
 
 
+def is_datapoint(val: Any) -> bool:
+    return any(isinstance(val, d_type) for d_type in get_args(Datapoint))
+
+
 class Datapoints(RootModel):
     root: dict[str, Datapoint]
 
     @model_validator(mode="before")
     def parse_datapoints(cls, datapoints: dict) -> dict:
-        names_to_remove: dict[str, ValidationError] = {}
-
-        # Automatically convert raw scalar types
-        for name in datapoints:
-            if not any(isinstance(datapoints[name], d_type) for d_type in get_args(Datapoint)):
-                try:
-                    datapoints[name] = ScalarDatapoint(data=datapoints[name])
-                except ValidationError as e:
-                    names_to_remove[name] = e
-
-        # Prune and raise a warning about unconverted datapoints
-        for name in names_to_remove:
-            data = datapoints.pop(name)
-            logger.warning(
-                f"Cannot convert datapoint '{name}' of type '{type(data)}': {names_to_remove[name]}"
-            )
-
-        return datapoints
+        successfully_parsed_datapoints = {}
+        for name, datapoint in datapoints.items():
+            if is_datapoint(datapoint):
+                successfully_parsed_datapoints[name] = datapoint
+
+            elif is_scalar_type(datapoint):
+                # Automatically convert raw scalar types into datapoints
+                successfully_parsed_datapoints[name] = ScalarDatapoint(data=datapoint)
+
+            else:
+                # Prune and raise a warning about unconverted datapoints
+                logger.warning(
+                    f"Cannot convert datapoint '{name}' of type '{type(datapoint)}': {datapoint}"
+                )
+
+        return successfully_parsed_datapoints
 
     # Add dict methods
     def get(self, key: str, default: Datapoint | None = None) -> Datapoint | None: