Skip to content

Commit

Permalink
Add more type stubs for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
WyattBlue committed Sep 25, 2024
1 parent 339fc48 commit 6e9698f
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 174 deletions.
20 changes: 11 additions & 9 deletions tests/test_codec_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ class Options(TypedDict, total=False):

@overload
def iter_raw_frames(
path: str, packet_sizes: list, ctx: VideoCodecContext
path: str, packet_sizes: list[int], ctx: VideoCodecContext
) -> Iterator[VideoFrame]: ...
@overload
def iter_raw_frames(
path: str, packet_sizes: list, ctx: AudioCodecContext
path: str, packet_sizes: list[int], ctx: AudioCodecContext
) -> Iterator[AudioFrame]: ...
def iter_raw_frames(
path: str, packet_sizes: list, ctx: VideoCodecContext | AudioCodecContext
path: str, packet_sizes: list[int], ctx: VideoCodecContext | AudioCodecContext
) -> Iterator[VideoFrame | AudioFrame]:
with open(path, "rb") as f:
for i, size in enumerate(packet_sizes):
Expand Down Expand Up @@ -85,14 +85,16 @@ def test_codec_tag(self):
assert ctx.codec_tag == "xvid"

# wrong length
with self.assertRaises(ValueError) as cm:
with pytest.raises(
ValueError, match="Codec tag should be a 4 character string"
):
ctx.codec_tag = "bob"
assert str(cm.exception) == "Codec tag should be a 4 character string."

# wrong type
with self.assertRaises(ValueError) as cm:
with pytest.raises(
ValueError, match="Codec tag should be a 4 character string"
):
ctx.codec_tag = 123
assert str(cm.exception) == "Codec tag should be a 4 character string."

with av.open(fate_suite("h264/interlaced_crop.mp4")) as container:
assert container.streams[0].codec_tag == "avc1"
Expand Down Expand Up @@ -175,14 +177,14 @@ def test_bits_per_coded_sample(self):
with pytest.raises(ValueError):
stream.codec_context.bits_per_coded_sample = 32

def test_parse(self):
def test_parse(self) -> None:
# This one parses into a single packet.
self._assert_parse("mpeg4", fate_suite("h264/interlaced_crop.mp4"))

# This one parses into many small packets.
self._assert_parse("mpeg2video", fate_suite("mpeg2/mpeg2_field_encoding.ts"))

def _assert_parse(self, codec_name, path):
def _assert_parse(self, codec_name: str, path: str) -> None:
fh = av.open(path)
packets = []
for packet in fh.demux(video=0):
Expand Down
41 changes: 20 additions & 21 deletions tests/test_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,46 +57,45 @@ def test_decode_audio_sample_count(self) -> None:
)
assert sample_count == total_samples

def test_decoded_time_base(self):
def test_decoded_time_base(self) -> None:
container = av.open(fate_suite("h264/interlaced_crop.mp4"))
stream = container.streams.video[0]

assert stream.time_base == Fraction(1, 25)

for packet in container.demux(stream):
for frame in packet.decode():
assert not isinstance(frame, av.subtitles.subtitle.SubtitleSet)
assert packet.time_base == frame.time_base
assert stream.time_base == frame.time_base
return

def test_decoded_motion_vectors(self):
def test_decoded_motion_vectors(self) -> None:
container = av.open(fate_suite("h264/interlaced_crop.mp4"))
stream = container.streams.video[0]
codec_context = stream.codec_context
codec_context.options = {"flags2": "+export_mvs"}

for packet in container.demux(stream):
for frame in packet.decode():
vectors = frame.side_data.get("MOTION_VECTORS")
if frame.key_frame:
# Key frame don't have motion vectors
assert vectors is None
else:
assert len(vectors) > 0
return

def test_decoded_motion_vectors_no_flag(self):
for frame in container.decode(stream):
vectors = frame.side_data.get("MOTION_VECTORS")
if frame.key_frame:
# Key frame don't have motion vectors
assert vectors is None
else:
assert vectors is not None and len(vectors) > 0
return

def test_decoded_motion_vectors_no_flag(self) -> None:
container = av.open(fate_suite("h264/interlaced_crop.mp4"))
stream = container.streams.video[0]

for packet in container.demux(stream):
for frame in packet.decode():
vectors = frame.side_data.get("MOTION_VECTORS")
if not frame.key_frame:
assert vectors is None
return
for frame in container.decode(stream):
vectors = frame.side_data.get("MOTION_VECTORS")
if not frame.key_frame:
assert vectors is None
return

def test_decode_video_corrupt(self):
def test_decode_video_corrupt(self) -> None:
# write an empty file
path = self.sandboxed("empty.h264")
with open(path, "wb"):
Expand All @@ -114,7 +113,7 @@ def test_decode_video_corrupt(self):
assert packet_count == 1
assert frame_count == 0

def test_decode_close_then_use(self):
def test_decode_close_then_use(self) -> None:
container = av.open(fate_suite("h264/interlaced_crop.mp4"))
container.close()

Expand Down
25 changes: 8 additions & 17 deletions tests/test_file_probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ def test_container_probing(self):
assert self.file.duration == 6165333
assert str(self.file.format) == "<av.ContainerFormat 'mpegts'>"
assert self.file.format.name == "mpegts"
self.assertEqual(
self.file.format.long_name, "MPEG-TS (MPEG-2 Transport Stream)"
)
assert self.file.format.long_name == "MPEG-TS (MPEG-2 Transport Stream)"
assert self.file.metadata == {}
assert self.file.size == 207740
assert self.file.start_time == 1400000
Expand All @@ -25,11 +23,8 @@ def test_container_probing(self):
def test_stream_probing(self):
stream = self.file.streams[0]

# check __repr__
self.assertTrue(
str(stream).startswith(
"<av.AudioStream #0 aac_latm at 48000Hz, stereo, fltp at "
)
assert str(stream).startswith(
"<av.AudioStream #0 aac_latm at 48000Hz, stereo, fltp at "
)

# actual stream properties
Expand Down Expand Up @@ -65,9 +60,9 @@ def setUp(self):
with open(path, "wb"):
pass

self.file = av.open(path)
self.file = av.open(path, "r")

def test_container_probing(self):
def test_container_probing(self) -> None:
assert self.file.bit_rate == 0
assert self.file.duration is None
assert str(self.file.format) == "<av.ContainerFormat 'flac'>"
Expand All @@ -78,14 +73,11 @@ def test_container_probing(self):
assert self.file.start_time is None
assert len(self.file.streams) == 1

def test_stream_probing(self):
def test_stream_probing(self) -> None:
stream = self.file.streams[0]

# ensure __repr__ does not crash
self.assertTrue(
str(stream).startswith(
"<av.AudioStream #0 flac at 0Hz, 0 channels, None at "
)
assert str(stream).startswith(
"<av.AudioStream #0 flac at 0Hz, 0 channels, None at "
)

# actual stream properties
Expand Down Expand Up @@ -191,7 +183,6 @@ def test_container_probing(self) -> None:
def test_stream_probing(self) -> None:
stream = self.file.streams[0]

# check __repr__
assert str(stream).startswith("<av.SubtitleStream #0 subtitle/mov_text at ")

# actual stream properties
Expand Down
20 changes: 9 additions & 11 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@


def generate_audio_frame(
frame_num, input_format="s16", layout="stereo", sample_rate=44100, frame_size=1024
):
frame_num: int,
input_format: str = "s16",
layout: str = "stereo",
sample_rate: int = 44100,
frame_size: int = 1024,
) -> AudioFrame:
"""
Generate audio frame representing part of the sinusoidal wave
:param input_format: default: s16
:param layout: default: stereo
:param sample_rate: default: 44100
:param frame_size: default: 1024
:param frame_num: frame number
:return: audio frame for sinusoidal wave audio signal slice
"""
frame = AudioFrame(format=input_format, layout=layout, samples=frame_size)
frame.sample_rate = sample_rate
Expand All @@ -31,7 +29,7 @@ def generate_audio_frame(
data = np.zeros(frame_size, dtype=format_dtypes[input_format])
for j in range(frame_size):
data[j] = np.sin(2 * np.pi * (frame_num + j) * (i + 1) / float(frame_size))
frame.planes[i].update(data)
frame.planes[i].update(data) # type: ignore

return frame

Expand Down Expand Up @@ -79,8 +77,8 @@ def test_generator_graph(self):
lutrgb.link_to(sink)

# pads and links
self.assertIs(src.outputs[0].link.output, lutrgb.inputs[0])
self.assertIs(lutrgb.inputs[0].link.input, src.outputs[0])
assert src.outputs[0].link.output is lutrgb.inputs[0]
assert lutrgb.inputs[0].link.input is src.outputs[0]

frame = sink.pull()
assert isinstance(frame, VideoFrame)
Expand Down
Loading

0 comments on commit 6e9698f

Please sign in to comment.