Skip to content

Commit

Permalink
feat(python): rework channel database
Browse files Browse the repository at this point in the history
[no changelog]
  • Loading branch information
M1nd3r committed Nov 19, 2024
1 parent dc832ab commit 641611b
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 95 deletions.
1 change: 1 addition & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ typing_extensions>=4.7.1
construct-classes>=0.1.2
appdirs>=1.4.4
cryptography >=43.0.3
platformdirs >=2
8 changes: 4 additions & 4 deletions python/src/trezorlib/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..client import TrezorClient
from ..messages import Capability
from ..transport import Transport
from ..transport.thp import channel_database
from ..transport.thp.channel_database import get_channel_db

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,7 +102,7 @@ def get_passphrase(


def get_client(transport: Transport) -> TrezorClient:
stored_channels = channel_database.load_stored_channels()
stored_channels = get_channel_db().load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
Expand All @@ -115,7 +115,7 @@ def get_client(transport: Transport) -> TrezorClient:
)
except Exception:
LOG.debug("Failed to resume a channel. Replacing by a new one.")
channel_database.remove_channel(path)
get_channel_db().remove_channel(path)
client = TrezorClient(transport)
else:
client = TrezorClient(transport)
Expand Down Expand Up @@ -355,7 +355,7 @@ def trezorctl_command_with_client(
try:
return func(client, *args, **kwargs)
finally:
channel_database.save_channel(client.protocol)
get_channel_db().save_channel(client.protocol)
# if not session_was_resumed:
# try:
# client.end_session()
Expand Down
27 changes: 19 additions & 8 deletions python/src/trezorlib/cli/trezorctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session
from ..transport.thp import channel_database
from ..transport.thp.channel_database import get_channel_db
from ..transport.udp import UdpTransport
from . import (
AliasedGroup,
Expand Down Expand Up @@ -196,6 +197,13 @@ def configure_logging(verbose: int) -> None:
"--record",
help="Record screen changes into a specified directory.",
)
@click.option(
"-n",
"--no-store",
is_flag=True,
help="Do not store channels data between commands.",
default=False,
)
@click.version_option(version=__version__)
@click.pass_context
def cli_main(
Expand All @@ -207,9 +215,10 @@ def cli_main(
script: bool,
session_id: Optional[str],
record: Optional[str],
no_store: bool,
) -> None:
configure_logging(verbose)

channel_database.set_channel_database(should_not_store=no_store)
bytes_session_id: Optional[bytes] = None
if session_id is not None:
try:
Expand Down Expand Up @@ -296,10 +305,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
try:
client = get_client(transport)
description = format_device_name(client.features)
# json_string = channel_database.channel_to_str(client.protocol)
# print(json_string)
channel_database.save_channel(client.protocol)
# client.end_session()
get_channel_db().save_channel(client.protocol)
except DeviceIsBusy:
description = "Device is in use by another process"
except Exception:
Expand Down Expand Up @@ -376,9 +382,14 @@ def clear_session(session: "Session") -> None:


@cli.command()
def new_clear_session() -> None:
"""New Clear session (remove cached channels from trezorlib)."""
channel_database.clear_stored_channels()
def delete_channels() -> None:
"""
Delete cached channels.
Do not use together with the `-n` (`--no-store`) flag,
as the JSON database will not be deleted.
"""
get_channel_db().clear_stored_channels()


@cli.command()
Expand Down
183 changes: 105 additions & 78 deletions python/src/trezorlib/transport/thp/channel_database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import logging
import os
Expand All @@ -8,39 +10,104 @@

LOG = logging.getLogger(__name__)

if True:
from platformdirs import user_cache_dir, user_config_dir
db: "ChannelDatabase | None" = None


def get_channel_db() -> ChannelDatabase:
if db is None:
set_channel_database(should_not_store=True)
assert db is not None
return db


class ChannelDatabase:

APP_NAME = "@trezor" # TODO
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
CONFIG_PATH = os.path.join(user_config_dir(appname=APP_NAME), "config.json")
else:
DATA_PATH = os.path.join("./channel_data.json")
CONFIG_PATH = os.path.join("./config.json")
def load_stored_channels(self) -> t.List[ChannelData]: ...
def clear_stored_channels(self) -> None: ...
def read_all_channels(self) -> t.List: ...
def save_all_channels(self, channels: t.List[t.Dict]) -> None: ...
def save_channel(self, new_channel: ProtocolAndChannel): ...
def remove_channel(self, transport_path: str) -> None: ...


class ChannelDatabase: # TODO not finished
should_store: bool = False
class DummyChannelDatabase(ChannelDatabase):

def __init__(
self, config_path: str = CONFIG_PATH, data_path: str = DATA_PATH
) -> None:
if not os.path.exists(CONFIG_PATH):
with open(CONFIG_PATH, "w") as f:
json.dump([], f)
def load_stored_channels(self) -> t.List[ChannelData]:
return []

def clear_stored_channels(self) -> None:
pass

def load_stored_channels() -> t.List[ChannelData]:
dicts = read_all_channels()
return [dict_to_channel_data(d) for d in dicts]
def read_all_channels(self) -> t.List:
return []

def save_all_channels(self, channels: t.List[t.Dict]) -> None:
return

def channel_to_str(channel: ProtocolAndChannel) -> str:
return json.dumps(channel.get_channel_data().to_dict())
def save_channel(self, new_channel: ProtocolAndChannel):
pass

def remove_channel(self, transport_path: str) -> None:
pass

def str_to_channel_data(channel_data: str) -> ChannelData:
return dict_to_channel_data(json.loads(channel_data))

class JsonChannelDatabase(ChannelDatabase):
def __init__(self, data_path: str) -> None:
self.data_path = data_path
super().__init__()

def load_stored_channels(self) -> t.List[ChannelData]:
dicts = self.read_all_channels()
return [dict_to_channel_data(d) for d in dicts]

def clear_stored_channels(self) -> None:
LOG.debug("Clearing contents of %s", self.data_path)
with open(self.data_path, "w") as f:
json.dump([], f)
try:
os.remove(self.data_path)
except Exception as e:
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))

def read_all_channels(self) -> t.List:
ensure_file_exists(self.data_path)
with open(self.data_path, "r") as f:
return json.load(f)

def save_all_channels(self, channels: t.List[t.Dict]) -> None:
LOG.debug("saving all channels")
with open(self.data_path, "w") as f:
json.dump(channels, f, indent=4)

def save_channel(self, new_channel: ProtocolAndChannel):

LOG.debug("save channel")
channels = self.read_all_channels()
transport_path = new_channel.transport.get_path()

# If the channel is found in database: replace the old entry by the new
for i, channel in enumerate(channels):
if channel["transport_path"] == transport_path:
LOG.debug("Modified channel entry for %s", transport_path)
channels[i] = new_channel.get_channel_data().to_dict()
self.save_all_channels(channels)
return

# Channel was not found: add a new channel entry
LOG.debug("Created a new channel entry on path %s", transport_path)
channels.append(new_channel.get_channel_data().to_dict())
self.save_all_channels(channels)

def remove_channel(self, transport_path: str) -> None:
LOG.debug(
"Removing channel with path %s from the channel database.",
transport_path,
)
channels = self.read_all_channels()
remaining_channels = [
ch for ch in channels if ch["transport_path"] != transport_path
]
self.save_all_channels(remaining_channels)


def dict_to_channel_data(dict: t.Dict) -> ChannelData:
Expand All @@ -57,63 +124,23 @@ def dict_to_channel_data(dict: t.Dict) -> ChannelData:
)


def ensure_file_exists() -> None:
LOG.debug("checking if file %s exists", DATA_PATH)
if not os.path.exists(DATA_PATH):
os.makedirs(os.path.dirname(DATA_PATH), exist_ok=True)
LOG.debug("File %s does not exist. Creating a new one.", DATA_PATH)
with open(DATA_PATH, "w") as f:
def ensure_file_exists(file_path: str) -> None:
LOG.debug("checking if file %s exists", file_path)
if not os.path.exists(file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
LOG.debug("File %s does not exist. Creating a new one.", file_path)
with open(file_path, "w") as f:
json.dump([], f)


def clear_stored_channels() -> None:
LOG.debug("Clearing contents of %s", DATA_PATH)
with open(DATA_PATH, "w") as f:
json.dump([], f)
try:
os.remove(DATA_PATH)
except Exception as e:
LOG.exception("Failed to delete %s (%s)", DATA_PATH, str(type(e)))
def set_channel_database(should_not_store: bool):
global db
if should_not_store:
db = DummyChannelDatabase()
else:
from platformdirs import user_cache_dir

APP_NAME = "@trezor" # TODO
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")

def read_all_channels() -> t.List:
ensure_file_exists()
with open(DATA_PATH, "r") as f:
return json.load(f)


def save_all_channels(channels: t.List[t.Dict]) -> None:
LOG.debug("saving all channels")
with open(DATA_PATH, "w") as f:
json.dump(channels, f, indent=4)


def save_channel(new_channel: ProtocolAndChannel):
LOG.debug("save channel")
channels = read_all_channels()
transport_path = new_channel.transport.get_path()

# If the channel is found in database: replace the old entry by the new
for i, channel in enumerate(channels):
if channel["transport_path"] == transport_path:
LOG.debug("Modified channel entry for %s", transport_path)
channels[i] = new_channel.get_channel_data().to_dict()
save_all_channels(channels)
return

# Channel was not found: add a new channel entry
LOG.debug("Created a new channel entry on path %s", transport_path)
channels.append(new_channel.get_channel_data().to_dict())
save_all_channels(channels)


def remove_channel(transport_path: str) -> None:
LOG.debug(
"Removing channel with path %s from the channel database.",
transport_path,
)
channels = read_all_channels()
remaining_channels = [
ch for ch in channels if ch["transport_path"] != transport_path
]
save_all_channels(remaining_channels)
db = JsonChannelDatabase(DATA_PATH)
8 changes: 5 additions & 3 deletions python/src/trezorlib/transport/thp/protocol_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from ..thp.channel_data import ChannelData
from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.message_header import MessageHeader
from . import channel_database, control_byte
from . import control_byte
from .channel_database import ChannelDatabase, get_channel_db
from .protocol_and_channel import ProtocolAndChannel

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send
self._has_valid_channel = True
self.channel_database: ChannelDatabase = get_channel_db()

def get_channel(self) -> ProtocolV2:
if not self._has_valid_channel:
Expand All @@ -99,13 +101,13 @@ def read(self, session_id: int) -> t.Any:
sid, msg_type, msg_data = self.read_and_decrypt()
if sid != session_id:
raise Exception("Received messsage on a different session.")
channel_database.save_channel(self)
self.channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data)

def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data)
channel_database.save_channel(self)
self.channel_database.save_channel(self)

def get_features(self) -> messages.Features:
if not self._has_valid_channel:
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,9 @@ def client(
# Get a new client
_raw_client = _get_raw_client(request)

from trezorlib.transport.thp import channel_database
from trezorlib.transport.thp.channel_database import get_channel_db

channel_database.clear_stored_channels()
get_channel_db().clear_stored_channels()
_raw_client.protocol = None
_raw_client.__init__(
transport=_raw_client.transport,
Expand Down

0 comments on commit 641611b

Please sign in to comment.