diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 793f83142..f32ff7b85 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,19 +1,19 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.2.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer - id: check-docstring-first - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 + rev: 4.0.1 hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-autopep8 - rev: v1.5.7 + rev: v1.6.0 hooks: - id: autopep8 - repo: https://github.com/PyCQA/isort - rev: 5.9.3 + rev: 5.10.1 hooks: - id: isort diff --git a/server/config.py b/server/config.py index f416b9f3c..3cb4dea94 100644 --- a/server/config.py +++ b/server/config.py @@ -125,6 +125,11 @@ def __init__(self): # How many previous queue sizes to consider self.QUEUE_POP_TIME_MOVING_AVG_SIZE = 5 + # Whether floats should be rounded before json encoding + self.JSON_ROUND_FLOATS = True + # The maximum number of decimal places to use for float serialization + self.JSON_ROUND_FLOATS_PRECISION = 2 + self._defaults = { key: value for key, value in vars(self).items() if key.isupper() } diff --git a/server/protocol/protocol.py b/server/protocol/protocol.py index a525f5a52..43da7e508 100644 --- a/server/protocol/protocol.py +++ b/server/protocol/protocol.py @@ -4,10 +4,30 @@ from asyncio import StreamReader, StreamWriter import server.metrics as metrics +from server.config import config from ..asyncio_extensions import synchronizedmethod -json_encoder = json.JSONEncoder(separators=(",", ":")) + +class CustomJSONEncoder(json.JSONEncoder): + # taken from https://stackoverflow.com/a/53798633 + def encode(self, o): + def round_floats(o): + if isinstance(o, float): + return round(o, config.JSON_ROUND_FLOATS_PRECISION) + if isinstance(o, dict): + return {k: round_floats(v) for k, v in o.items()} + if isinstance(o, (list, tuple)): + return [round_floats(x) for x in o] + return o + + if config.JSON_ROUND_FLOATS: + return super().encode(round_floats(o)) + else: + return super().encode(o) + + +json_encoder = CustomJSONEncoder(separators=(",", ":")) class DisconnectedError(ConnectionError): diff --git a/server/servercontext.py b/server/servercontext.py index 726688d5c..6e8190d74 100644 --- a/server/servercontext.py +++ b/server/servercontext.py @@ -132,8 +132,8 @@ async def client_connected(self, stream_reader, stream_writer): asyncio.CancelledError, ): pass - except Exception: - self._logger.exception() + except Exception as e: + self._logger.exception(e) finally: del self.connections[connection] # Do not wait for buffers to empty here. This could stop the process diff --git a/tests/integration_tests/test_game.py b/tests/integration_tests/test_game.py index 7fa648aaf..1a224ed8d 100644 --- a/tests/integration_tests/test_game.py +++ b/tests/integration_tests/test_game.py @@ -307,7 +307,10 @@ async def test_game_ended_rates_game(lobby_server): @pytest.mark.rabbitmq @fast_forward(30) -async def test_game_ended_broadcasts_rating_update(lobby_server, channel): +async def test_game_ended_broadcasts_rating_update( + lobby_server, channel, mocker, +): + mocker.patch("server.config.JSON_ROUND_FLOATS_PRECISION", 4) mq_proto_all = await connect_mq_consumer( lobby_server, channel, @@ -611,12 +614,13 @@ async def test_partial_game_ended_rates_game(lobby_server, tmp_user): @fast_forward(100) -async def test_ladder_game_draw_bug(lobby_server, database): +async def test_ladder_game_draw_bug(lobby_server, database, mocker): """ This simulates the infamous "draw bug" where a player could self destruct their own ACU in order to kill the enemy ACU and be awarded a victory instead of a draw. """ + mocker.patch("server.config.JSON_ROUND_FLOATS_PRECISION", 13) player1_id, proto1, player2_id, proto2 = await queue_players_for_matchmaking(lobby_server) msg1, msg2 = await asyncio.gather(*[ diff --git a/tests/unit_tests/test_protocol.py b/tests/unit_tests/test_protocol.py index 69b37902a..be5f398cd 100644 --- a/tests/unit_tests/test_protocol.py +++ b/tests/unit_tests/test_protocol.py @@ -13,6 +13,7 @@ QDataStreamProtocol, SimpleJsonProtocol ) +from server.protocol.protocol import json_encoder @pytest.fixture(scope="session") @@ -256,3 +257,47 @@ async def test_read_when_disconnected(protocol): with pytest.raises(DisconnectedError): await protocol.read_message() + + +def test_json_encoder_float_serialization(): + assert json_encoder.encode(123.0) == "123.0" + assert json_encoder.encode(0.99) == "0.99" + assert json_encoder.encode(0.999) == "1.0" + + +@given(message=st_messages()) +def test_json_encoder_encodes_server_messages(message): + new_encode = json_encoder.encode + old_encode = json.JSONEncoder(separators=(",", ":")).encode + + assert new_encode(message) == old_encode(message) + + +def st_dictionaries(): + value_types = ( + st.booleans(), + st.text(), + st.integers(), + st.none(), + ) + key_types = (*value_types, st.floats()) + return st.dictionaries( + keys=st.one_of(*key_types), + values=st.one_of( + *value_types, + st.lists(st.one_of(*value_types)), + st.tuples(st.one_of(*value_types)), + ) + ) + + +@ given(dct=st_dictionaries()) +def test_json_encoder_encodes_dicts(dct): + old_encode = json.JSONEncoder(separators=(",", ":")).encode + new_encode = json_encoder.encode + + assert new_encode(dct) == old_encode(dct) + + wrong_dict_key = (1, 2) + with pytest.raises(TypeError): + json_encoder.encode({wrong_dict_key: "a"})