Skip to content

Commit

Permalink
Add HandshakeSuccess interface
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Sep 8, 2023
1 parent 03c21c0 commit b0849c4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
12 changes: 6 additions & 6 deletions common/bufio/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
err = destination.WriteBuffer(buffer)
if err != nil {
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
Expand Down Expand Up @@ -130,7 +130,7 @@ func CopyExtendedWithSrcBuffer(originSource io.Reader, destination N.ExtendedWri
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
Expand Down Expand Up @@ -175,7 +175,7 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
Expand Down Expand Up @@ -304,7 +304,7 @@ func CopyPacketWithSrcBuffer(originSource N.PacketReader, destinationConn N.Pack
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
Expand Down Expand Up @@ -345,7 +345,7 @@ func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWri
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
Expand Down Expand Up @@ -381,7 +381,7 @@ func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWr
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
Expand Down
4 changes: 2 additions & 2 deletions common/bufio/copy_direct_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
Expand Down Expand Up @@ -102,7 +102,7 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe
if err != nil {
buffer.Release()
if !notFirstTime {
err = N.HandshakeFailure(originSource, err)
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
Expand Down
17 changes: 14 additions & 3 deletions common/network/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,26 @@ import (
E "github.com/sagernet/sing/common/exceptions"
)

type HandshakeConn interface {
type HandshakeFailure interface {
HandshakeFailure(err error) error
}

func HandshakeFailure(conn any, err error) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeConn](conn); isHandshakeConn {
type HandshakeSuccess interface {
HandshakeSuccess() error
}

func ReportHandshakeFailure(conn any, err error) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](conn); isHandshakeConn {
return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error {
return E.Cause(err, "write handshake failure")
})
}
return err
}

func ReportHandshakeSuccess(conn any) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](conn); isHandshakeConn {
return handshakeConn.HandshakeSuccess()
}
return nil
}

0 comments on commit b0849c4

Please sign in to comment.