Skip to content

Commit

Permalink
fix: adjust returns of .get() method
Browse files Browse the repository at this point in the history
  • Loading branch information
manuba95 committed Sep 7, 2024
1 parent 95cfaef commit b135a72
Showing 1 changed file with 41 additions and 26 deletions.
67 changes: 41 additions & 26 deletions floodlight/io/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,19 +899,19 @@ def __init__(self, dataset_dir_name="idsse_dataset", match_id="J03WMX"):
f"{self._IDSSE_SCHEMA}://"
f"{self._IDSSE_BASE_URL}/"
f"{self._IDSSE_FILE_IDS_INFO[match_id]}"
f"?private_link={self._IDSSE_PRIVAT_LINK}"
f"?private_link={self._IDSSE_PRIVATE_LINK}"
)
self._IDSSE_HOST_URL_EVENT = (
f"{self._IDSSE_SCHEMA}://"
f"{self._IDSSE_BASE_URL}/"
f"{self._IDSSE_FILE_IDS_EVENT[match_id]}"
f"?private_link={self._IDSSE_PRIVAT_LINK}"
f"?private_link={self._IDSSE_PRIVATE_LINK}"
)
self._IDSSE_HOST_URL_POSITION = (
f"{self._IDSSE_SCHEMA}://"
f"{self._IDSSE_BASE_URL}/"
f"{self._IDSSE_FILE_IDS_POSITION[match_id]}"
f"?private_link={self._IDSSE_PRIVAT_LINK}"
f"?private_link={self._IDSSE_PRIVATE_LINK}"
)
elif match_id == "all":
pass
Expand Down Expand Up @@ -972,19 +972,19 @@ def __init__(self, dataset_dir_name="idsse_dataset", match_id="J03WMX"):
f"{self._IDSSE_SCHEMA}://"
f"{self._IDSSE_BASE_URL}/"
f"{self._IDSSE_FILE_IDS_INFO[file_id]}"
f"?private_link={self._IDSSE_PRIVAT_LINK}"
f"?private_link={self._IDSSE_PRIVATE_LINK}"
)
self._IDSSE_HOST_URL_EVENT = (
f"{self._IDSSE_SCHEMA}://"
f"{self._IDSSE_BASE_URL}/"
f"{self._IDSSE_FILE_IDS_EVENT[file_id]}"
f"?private_link={self._IDSSE_PRIVAT_LINK}"
f"?private_link={self._IDSSE_PRIVATE_LINK}"
)
self._IDSSE_HOST_URL_POSITION = (
f"{self._IDSSE_SCHEMA}://"
f"{self._IDSSE_BASE_URL}/"
f"{self._IDSSE_FILE_IDS_POSITION[file_id]}"
f"?private_link={self._IDSSE_PRIVAT_LINK}"
f"?private_link={self._IDSSE_PRIVATE_LINK}"
)

self._IDSSE_FILE_NAME_INFO = (
Expand Down Expand Up @@ -1043,14 +1043,12 @@ def get(
events=True,
positions=True,
) -> Tuple[
Tuple[Dict[str, Dict[str, Events]], Dict[str, Teamsheet], Pitch],
Tuple[
Dict[str, Dict[str, XY]],
Dict[str, Code],
Dict[str, Code],
Dict[str, Teamsheet],
Pitch,
],
Dict[str, Dict[str, Events]],
Dict[str, Dict[str, XY]],
Dict[str, Code],
Dict[str, Code],
Dict[str, Teamsheet],
Pitch,
]:
"""Get event and position data from the IDSSE dataset.
Expand All @@ -1069,22 +1067,25 @@ def get(
extracted from the data.
events: bool, optional
Specifies whether the event data should be returned. Default is True. If
false None will be returned instead of the event data objects.
false None will be returned instead of the events-objects.
positions: bool, optional
Specifies whether the position data should be returned. Default is True. If
false None will be returned instead of the position data objects. This will
improve performance considerably if only event data is required.
false None will be returned instead of the XY-objects, possession-objects,
and ballstatus-objects. This will improve performance considerably if only
event data is required.
Returns
-------
match_data: Tuple[Tuple[Events, Teamsheets, Pitch], Tuple[Dict[XY], Dict[Code],
Dict[Code], Dict[Teamsheets], Pitch]
Returns a tuple of shape (event_data, position_data) as returned by the
match_data: Tuple[Dict[str, Dict[str, Events]], Dict[str, Dict[str, XY]],
Dict[str, Code], Dict[str, Code], Dict[str, Teamsheet],Pitch
Returns a tuple of shape (events_objects, xy_objects, possession_objects,
ballstatus_objects, teamsheets_objects, pitch_object) as returned by the
``floodlight.io.dfl.read_event_data_xml()`` and
``floodlight.io.dfl.read_position_data_xml()`` functions for the requested
match. If any of the arguments ``events`` or ``positions`` are set to False,
None is returned instead of `event_data` or `position_data`, respectively.
None is returned instead of `event_data` or `xy_objects`,
`possession_objects`, and `ballstatus_objects`, respectively.
"""

if match_id in ["J03WMX", "J03WN1"]:
Expand Down Expand Up @@ -1121,20 +1122,34 @@ def get(

# parse event data
if events is True:
data_objects_events = read_event_data_xml(
events_objects, teamsheets_objects, pitch_object = read_event_data_xml(
file_name_events, file_name_infos, teamsheet_home, teamsheet_away
)
else:
data_objects_events = None
events_objects, teamsheets_objects, pitch_object = (None, None, None)

# parse position data
if positions is True:
data_objects_positions = read_position_data_xml(
(
xy_objects,
possession_objects,
ballstatus_objects,
teamsheets_objects,
pitch_object,
) = read_position_data_xml(
file_name_positions, file_name_infos, teamsheet_home, teamsheet_away
)
else:
data_objects_events = None
xy_objects, possession_objects, ballstatus_objects = (None, None, None)

# assemble
match_data = (data_objects_events, data_objects_positions)
match_data = (
events_objects,
xy_objects,
possession_objects,
ballstatus_objects,
teamsheets_objects,
pitch_object,
)

return match_data

0 comments on commit b135a72

Please sign in to comment.