Skip to content

Commit

Permalink
Deduplicate fragment() and fix fragsize<8
Browse files Browse the repository at this point in the history
  • Loading branch information
gpotter2 committed Jul 30, 2023
1 parent d68bee0 commit 69a7d1e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
22 changes: 4 additions & 18 deletions scapy/layers/inet.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,24 +616,7 @@ def mysummary(self):

def fragment(self, fragsize=1480):
"""Fragment IP datagrams"""
fragsize = (fragsize + 7) // 8 * 8
lst = []
for p in self:
s = raw(p[IP].payload)
nb = (len(s) + fragsize - 1) // fragsize
for i in range(nb):
q = p.copy()
del q.payload
del q.chksum
del q.len
if i != nb - 1:
q.flags |= 1
q.frag += i * fragsize // 8
r = conf.raw_layer(load=s[i * fragsize:(i + 1) * fragsize])
r.overload_fields = p.payload.overload_fields.copy()
q.add_payload(r)
lst.append(q)
return lst
return fragment(self, fragsize=fragsize)


def in4_pseudoheader(proto, u, plen):
Expand Down Expand Up @@ -1122,6 +1105,9 @@ def inet_register_l3(l2, l3):
@conf.commands.register
def fragment(pkt, fragsize=1480):
"""Fragment a big IP datagram"""
if fragsize < 8:
warning("fragsize cannot be lower than 8")
fragsize = max(fragsize, 8)
lastfragsz = fragsize
fragsize -= fragsize % 8
lst = []
Expand Down
19 changes: 19 additions & 0 deletions test/scapy/layers/inet.uts
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,25 @@ assert len(frags2) == 2
assert len(frags2[0]) == 20 + paylen - paylen % 8
assert len(frags2[1]) == 20 + 1 + paylen % 8

= fragment() with fragsize lower than 8
paylen = 5
fragsize = paylen
frags1 = fragment(IP() / ("X" * paylen), paylen)
assert len(frags1) == 1
assert bytes(frags1[0].payload) == b"X" * paylen

fragsize = paylen + 1
frags2 = fragment(IP() / ("X" * paylen), fragsize)
assert len(frags2) == 1
assert bytes(frags2[0].payload) == b"X" * paylen

paylen = 16
fragsize = 5
frags3 = fragment(IP() / ("X" * paylen), fragsize)
assert len(frags3) == 2
assert bytes(frags3[0].payload) == b"X" * 8
assert bytes(frags3[1].payload) == b"X" * 8

= defrag()
nonfrag, unfrag, badfrag = defrag(frags)
assert not nonfrag
Expand Down

0 comments on commit 69a7d1e

Please sign in to comment.