Skip to content

Commit

Permalink
Fix threaded sendrecv (#4538)
Browse files Browse the repository at this point in the history
* Restore sndrcv behaviour from before 53afe84

* Fix possible race condition of sndrcv

* Use much better timeout for threading

* Reduce abuse on public servers

* fix doip unit tests

* add testcase

* fix test case

* fix unit tests

* fix unit tests

* fix unit tests

* fix unit tests

---------

Co-authored-by: gpotter2 <[email protected]>
  • Loading branch information
polybassa and gpotter2 authored Sep 24, 2024
1 parent 19eeafe commit da9a952
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 35 deletions.
41 changes: 21 additions & 20 deletions scapy/sendrecv.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class debug:
Automatically enabled when a generator is passed as the packet
:param _flood:
:param threaded: if True, packets are sent in a thread and received in another.
defaults to False.
Defaults to True.
:param session: a flow decoder used to handle stream of packets
:param chainEX: if True, exceptions during send will be forwarded
:param stop_filter: Python function applied to each packet to determine if
Expand Down Expand Up @@ -128,7 +128,7 @@ def __init__(self,
rcv_pks=None, # type: Optional[SuperSocket]
prebuild=False, # type: bool
_flood=None, # type: Optional[_FloodGenerator]
threaded=False, # type: bool
threaded=True, # type: bool
session=None, # type: Optional[_GlobSessionType]
chainEX=False, # type: bool
stop_filter=None # type: Optional[Callable[[Packet], bool]]
Expand Down Expand Up @@ -158,7 +158,7 @@ def __init__(self,
self.noans = 0
self._flood = _flood
self.threaded = threaded
self.breakout = False
self.breakout = Event()
# Instantiate packet holders
if prebuild and not self._flood:
self.tobesent = list(pkt) # type: _PacketIterable
Expand All @@ -174,6 +174,7 @@ def __init__(self,
self.timeout = None

while retry >= 0:
self.breakout.clear()
self.hsent = {} # type: Dict[bytes, List[Packet]]

if threaded or self._flood:
Expand All @@ -190,7 +191,7 @@ def __init__(self,
except KeyboardInterrupt as ex:
interrupted = ex

self.breakout = True
self.breakout.set()

# Ended. Let's close gracefully
if self._flood:
Expand Down Expand Up @@ -251,28 +252,33 @@ def results(self):
# type: () -> Tuple[SndRcvList, PacketList]
return self.ans_result, self.unans_result

def _stop_sniffer_if_done(self) -> None:
"""Close the sniffer if all expected answers have been received"""
if self._send_done and self.noans >= self.notans and not self.multi:
if self.sniffer and self.sniffer.running:
self.sniffer.stop(join=False)

def _sndrcv_snd(self):
# type: () -> None
"""Function used in the sending thread of sndrcv()"""
i = 0
p = None
try:
if self.verbose:
print("Begin emission:")
os.write(1, b"Begin emission\n")
for p in self.tobesent:
# Populate the dictionary of _sndrcv_rcv
# _sndrcv_rcv won't miss the answer of a packet that
# has not been sent
self.hsent.setdefault(p.hashret(), []).append(p)
# Send packet
self.pks.send(p)
if self.inter:
time.sleep(self.inter)
if self.breakout:
time.sleep(self.inter)
if self.breakout.is_set():
break
i += 1
if self.verbose:
print("Finished sending %i packets." % i)
os.write(1, b"\nFinished sending %i packets\n" % i)
except SystemExit:
pass
except Exception:
Expand All @@ -291,13 +297,10 @@ def _sndrcv_snd(self):
elif not self._send_done:
self.notans = i
self._send_done = True
# In threaded mode, timeout.
if self.threaded and self.timeout is not None and not self.breakout:
t = time.monotonic() + self.timeout
while time.monotonic() < t:
if self.breakout:
break
time.sleep(0.1)
self._stop_sniffer_if_done()
# In threaded mode, timeout
if self.threaded and self.timeout is not None and not self.breakout.is_set():
self.breakout.wait(timeout=self.timeout)
if self.sniffer and self.sniffer.running:
self.sniffer.stop()

Expand All @@ -324,9 +327,7 @@ def _process_packet(self, r):
self.noans += 1
sentpkt._answered = 1
break
if self._send_done and self.noans >= self.notans and not self.multi:
if self.sniffer and self.sniffer.running:
self.sniffer.stop(join=False)
self._stop_sniffer_if_done()
if not ok:
if self.verbose > 1:
os.write(1, b".")
Expand All @@ -342,7 +343,7 @@ def _sndrcv_rcv(self, callback):
self.sniffer = AsyncSniffer()
self.sniffer._run(
prn=self._process_packet,
timeout=None if self.threaded else self.timeout,
timeout=None if self.threaded and not self._flood else self.timeout,
store=False,
opened_socket=self.rcv_pks,
session=self.session,
Expand Down
36 changes: 26 additions & 10 deletions test/contrib/automotive/doip.uts
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ import tempfile
= Test DoIPSocket

server_up = threading.Event()
sniff_up = threading.Event()
def server():
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -426,6 +427,7 @@ def server():
sock.listen(1)
server_up.set()
connection, address = sock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -437,7 +439,7 @@ server_thread.start()
server_up.wait(timeout=1)
sock = DoIPSocket(activate_routing=False)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

Expand All @@ -446,6 +448,7 @@ assert len(pkts) == 2
~ linux

server_up = threading.Event()
sniff_up = threading.Event()
def server():
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -456,6 +459,7 @@ def server():
sock.listen(1)
server_up.set()
connection, address = sock.accept()
sniff_up.wait(timeout=1)
for i in range(len(buffer)):
connection.send(buffer[i:i+1])
time.sleep(0.01)
Expand All @@ -469,13 +473,14 @@ server_thread.start()
server_up.wait(timeout=1)
sock = DoIPSocket(activate_routing=False)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

= Test DoIPSocket 3

server_up = threading.Event()
sniff_up = threading.Event()
def server():
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -486,6 +491,7 @@ def server():
sock.listen(1)
server_up.set()
connection, address = sock.accept()
sniff_up.wait(timeout=1)
while buffer:
randlen = random.randint(0, len(buffer))
connection.send(buffer[:randlen])
Expand All @@ -501,14 +507,15 @@ server_thread.start()
server_up.wait(timeout=1)
sock = DoIPSocket(activate_routing=False)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2


= Test DoIPSocket6

server_up = threading.Event()
sniff_up = threading.Event()
def server():
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
Expand All @@ -519,6 +526,7 @@ def server():
sock.listen(1)
server_up.set()
connection, address = sock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -530,7 +538,7 @@ server_thread.start()
server_up.wait(timeout=1)
sock = DoIPSocket(ip="::1", activate_routing=False)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

Expand Down Expand Up @@ -604,6 +612,7 @@ def _load_certificate_chain(context) -> None:


server_up = threading.Event()
sniff_up = threading.Event()
def server():
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
_load_certificate_chain(context)
Expand All @@ -619,6 +628,7 @@ def server():
ssock.listen(1)
server_up.set()
connection, address = ssock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -633,14 +643,15 @@ context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
sock = DoIPSocket(activate_routing=False, force_tls=True, context=context)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

= Test DoIPSslSocket6
~ broken_windows

server_up = threading.Event()
sniff_up = threading.Event()
def server():
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
_load_certificate_chain(context)
Expand All @@ -656,6 +667,7 @@ def server():
ssock.listen(1)
server_up.set()
connection, address = ssock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -670,14 +682,15 @@ context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
sock = DoIPSocket(ip="::1", activate_routing=False, force_tls=True, context=context)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

= Test UDS_DoIPSslSocket6
~ broken_windows

server_up = threading.Event()
sniff_up = threading.Event()
def server():
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
_load_certificate_chain(context)
Expand All @@ -693,6 +706,7 @@ def server():
ssock.listen(1)
server_up.set()
connection, address = ssock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -707,15 +721,16 @@ context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
sock = UDS_DoIPSocket(ip="::1", activate_routing=False, force_tls=True, context=context)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_thread.join(timeout=1)
assert len(pkts) == 2

= Test UDS_DualDoIPSslSocket6
~ broken_windows
~ broken_windows not_pypy

server_tcp_up = threading.Event()
server_tls_up = threading.Event()
sniff_up = threading.Event()
def server_tls():
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
_load_certificate_chain(context)
Expand All @@ -732,6 +747,7 @@ def server_tls():
ssock.listen(1)
server_tls_up.set()
connection, address = ssock.accept()
sniff_up.wait(timeout=1)
connection.send(buffer)
connection.close()
finally:
Expand All @@ -748,7 +764,7 @@ def server_tcp():
server_tcp_up.set()
connection, address = sock.accept()
connection.send(buffer)
connection.shutdown()
connection.shutdown(socket.SHUT_RDWR)
connection.close()
finally:
sock.close()
Expand All @@ -767,7 +783,7 @@ context.verify_mode = ssl.CERT_NONE

sock = UDS_DoIPSocket(ip="::1", context=context)

pkts = sock.sniff(timeout=1, count=2)
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
server_tcp_thread.join(timeout=1)
server_tls_thread.join(timeout=1)
assert len(pkts) == 2
12 changes: 12 additions & 0 deletions test/contrib/automotive/scanner/enumerator.uts
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,19 @@ class MockISOTPSocket(SuperSocket):
return len(sx)
@staticmethod
def select(sockets, remain=None):
time.sleep(0)
return sockets
def sr(self, *args, **kargs):
from scapy import sendrecv
return sendrecv.sndrcv(self, *args, threaded=False, **kargs)
def sr1(self, *args, **kargs):
from scapy import sendrecv
ans = sendrecv.sndrcv(self, *args, threaded=False, **kargs)[0] # type: SndRcvList
if len(ans) > 0:
pkt = ans[0][1] # type: Packet
return pkt
else:
return None

sock = MockISOTPSocket()
sock.rcvd_queue.put(b"\x41")
Expand Down
6 changes: 3 additions & 3 deletions test/regression.uts
Original file line number Diff line number Diff line change
Expand Up @@ -1832,7 +1832,7 @@ sck = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
ssck = StreamSocket(sck)

try:
r = ssck.sr1(ICMP(type='echo-request'), timeout=0.1, chainEX=True)
r = ssck.sr1(ICMP(type='echo-request'), timeout=0.1, chainEX=True, threaded=False)
assert False
except Exception:
assert True
Expand Down Expand Up @@ -2132,7 +2132,7 @@ retry_test(_test)
~ netaccess needs_root IP ICMP
def _test():
packet = IP(dst="8.8.8.8")/ICMP()
r = srflood(packet, timeout=2)
r = srflood(packet, timeout=0.5)
assert packet.sent_time is not None

retry_test(_test)
Expand All @@ -2142,7 +2142,7 @@ retry_test(_test)
def _test():
packet1 = IP(dst="8.8.8.8")/ICMP()
packet2 = IP(dst="8.8.4.4")/ICMP()
r = srflood([packet1, packet2], timeout=2)
r = srflood([packet1, packet2], timeout=0.5)
assert packet1.sent_time is not None
assert packet2.sent_time is not None

Expand Down
Loading

0 comments on commit da9a952

Please sign in to comment.