Skip to content

Commit

Permalink
fix mypy issues and refactor TRCWriter a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
zariiii9003 committed Jan 14, 2024
1 parent 9ae767e commit e60393f
Showing 1 changed file with 44 additions and 33 deletions.
77 changes: 44 additions & 33 deletions can/io/trc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,18 @@
import os
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Callable, Dict, Generator, List, Optional, TextIO, Union
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Mapping,
Optional,
TextIO,
Tuple,
Union,
)

from ..message import Message
from ..typechecking import StringPathLike
Expand Down Expand Up @@ -257,18 +268,19 @@ class TRCWriter(TextIOMessageWriter):
file: TextIO
first_timestamp: Optional[float]

FORMAT_MESSAGE = (
"{msgnr:>7} {time:13.3f} DT {channel:>2} {id:>8} {dir:>2} - {dlc:<4} {data}"
)
FORMAT_MESSAGE_V1_1 = "{msgnr:>6}){time:12.3f} Rx {id:>8} {dlc:<1} {data}"

FORMAT_MESSAGE_V1_0 = "{msgnr:>6}) {time:7.0f} {id:>8} {dlc:<1} {data}"
MESSAGE_FORMAT_MAP: Mapping[TRCFileVersion, str] = {
TRCFileVersion.V1_0: "{msgnr:>6}) {time:7.0f} {id:>8} {dlc:<1} {data}",
TRCFileVersion.V1_1: "{msgnr:>6}){time:12.3f} Rx {id:>8} {dlc:<1} {data}",
TRCFileVersion.V2_1: (
"{msgnr:>7} {time:13.3f} DT {channel:>2} {id:>8} {dir:>2} - {dlc:<4} {data}"
),
}

def __init__(
self,
file: Union[StringPathLike, TextIO],
channel: int = 1,
fileversion : int = TRCFileVersion.V2_1,
file_version: Union[int, TRCFileVersion] = TRCFileVersion.V1_0,
**kwargs: Any,
) -> None:
"""
Expand All @@ -278,6 +290,12 @@ def __init__(
:param channel: a default channel to use when the message does not
have a channel set
"""
if kwargs.get("append", False):
raise ValueError(
f"{self.__class__.__name__} is currently not equipped to "
f"append messages to an existing file."
)

super().__init__(file, mode="w")
self.channel = channel

Expand All @@ -289,10 +307,19 @@ def __init__(
self.filepath = os.path.abspath(self.file.name)
self.header_written = False
self.msgnr = 0
self.first_timestamp = None
self.file_version = fileversion
self._msg_fmt_string = self.FORMAT_MESSAGE_V1_0
self._format_message = self._format_message_init
self.first_timestamp: Optional[float] = None
self.file_version, self._msg_fmt_string = self._parse_version(file_version)

def _parse_version(
self, file_version: Union[int, TRCFileVersion]
) -> Tuple[TRCFileVersion, str]:
try:
version = TRCFileVersion(file_version)
msg_fmt_string = self.MESSAGE_FORMAT_MAP[version]
return version, msg_fmt_string
except (KeyError, ValueError) as exc:
err_msg = f"File version is not supported: {file_version}"
raise NotImplementedError(err_msg) from exc

def _write_header_v1_0(self, start_time: datetime) -> None:
lines = [
Expand All @@ -316,7 +343,7 @@ def _write_header_v1_0(self, start_time: datetime) -> None:
self.file.writelines(line + "\n" for line in lines)

def _write_header_v1_1(self, start_time: datetime) -> None:
header_time = start_time - datetime(year=1899, month=12, day=30)
header_time = start_time - datetime(year=1899, month=12, day=30)
lines = [
";$FILEVERSION=1.1",
f";$STARTTIME={header_time/timedelta(days=1)}",
Expand All @@ -337,11 +364,10 @@ def _write_header_v1_1(self, start_time: datetime) -> None:
"; | | | | Data Length",
"; | | | | | Data Bytes (hex) ...",
"; | | | | | |",
";---+-- ----+---- --+-- ----+--- + -+ -- -- -- -- -- -- --" ,
";---+-- ----+---- --+-- ----+--- + -+ -- -- -- -- -- -- --",
]
self.file.writelines(line + "\n" for line in lines)


def _write_header_v2_1(self, start_time: datetime) -> None:
header_time = start_time - datetime(year=1899, month=12, day=30)
lines = [
Expand All @@ -366,7 +392,7 @@ def _write_header_v2_1(self, start_time: datetime) -> None:
]
self.file.writelines(line + "\n" for line in lines)

def _format_message_by_format(self, msg, channel):
def _format_message(self, msg: Message, channel: int) -> str:
if msg.is_extended_id:
arb_id = f"{msg.arbitration_id:07X}"
else:
Expand All @@ -376,7 +402,7 @@ def _format_message_by_format(self, msg, channel):

serialized = self._msg_fmt_string.format(
msgnr=self.msgnr,
time=(msg.timestamp - self.first_timestamp) * 1000,
time=(msg.timestamp - (self.first_timestamp or 0.0)) * 1000,
channel=channel,
id=arb_id,
dir="Rx" if msg.is_rx else "Tx",
Expand All @@ -385,28 +411,13 @@ def _format_message_by_format(self, msg, channel):
)
return serialized

def _format_message_init(self, msg, channel):
if self.file_version == TRCFileVersion.V1_0:
self._format_message = self._format_message_by_format
self._msg_fmt_string = self.FORMAT_MESSAGE_V1_0
elif self.file_version == TRCFileVersion.V2_1:
self._format_message = self._format_message_by_format
self._msg_fmt_string = self.FORMAT_MESSAGE
elif self.file_version == TRCFileVersion.V1_1:
self._format_message = self._format_message_by_format
self._msg_fmt_string = self.FORMAT_MESSAGE_V1_1
else:
raise NotImplementedError("File format is not supported")

return self._format_message_by_format(msg, channel)

def write_header(self, timestamp: float) -> None:
# write start of file header
start_time = datetime.utcfromtimestamp(timestamp)

if self.file_version == TRCFileVersion.V1_0:
self._write_header_v1_0(start_time)
elif self.file_version == TRCFileVersion.V1_1 :
elif self.file_version == TRCFileVersion.V1_1:
self._write_header_v1_1(start_time)
elif self.file_version == TRCFileVersion.V2_1:
self._write_header_v2_1(start_time)
Expand Down

0 comments on commit e60393f

Please sign in to comment.