Skip to content

Commit

Permalink
add SASLBuffer utility
Browse files Browse the repository at this point in the history
  • Loading branch information
slingamn committed Feb 7, 2024
1 parent f1e8ead commit 5c25eee
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 0 deletions.
68 changes: 68 additions & 0 deletions ircutils/sasl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ package ircutils

import (
"encoding/base64"
"errors"
"strings"
)

var (
ErrSASLLimitExceeded = errors.New("SASL total response size exceeded configured limit")
ErrSASLTooLong = errors.New("SASL response chunk exceeded 400-byte limit")
)

// EncodeSASLResponse encodes a raw SASL response as parameters to successive
Expand Down Expand Up @@ -35,3 +42,64 @@ func EncodeSASLResponse(raw []byte) (result []string) {

return result
}

// SASLBuffer handles buffering and decoding SASL responses sent as parameters
// to AUTHENTICATE commands, as described in the IRCv3 SASL specification.
// Do not copy a SASLBuffer after first use.
type SASLBuffer struct {
maxLength int
buffer strings.Builder
}

// NewSASLBuffer returns a new SASLBuffer. maxLength is the maximum amount of
// base64'ed data to buffer (0 for no limit).
func NewSASLBuffer(maxLength int) *SASLBuffer {
result := new(SASLBuffer)
result.Initialize(maxLength)
return result
}

// Initialize initializes a SASLBuffer in place.
func (b *SASLBuffer) Initialize(maxLength int) {
b.maxLength = maxLength
}

// Add processes an additional SASL response chunk sent via AUTHENTICATE.
// If the response is complete, it returns the decoded response along with
// any decoding or protocol errors detected.
func (b *SASLBuffer) Add(value string) (done bool, output []byte, err error) {
if value == "+" {
output, err = b.getAndReset()
return true, output, err
}

if len(value) > 400 {
b.buffer.Reset()
return true, nil, ErrSASLTooLong
}

if b.maxLength != 0 && (b.buffer.Len()+len(value)) > b.maxLength {
b.buffer.Reset()
return true, nil, ErrSASLLimitExceeded
}

b.buffer.WriteString(value)
if len(value) < 400 {
output, err = b.getAndReset()
return true, output, err
} else {
// 400 bytes, wait for continuation line or +
return false, nil, nil
}
}

// Clear resets the buffer state.
func (b *SASLBuffer) Clear() {
b.buffer.Reset()
}

func (b *SASLBuffer) getAndReset() (output []byte, err error) {
output, err = base64.StdEncoding.DecodeString(b.buffer.String())
b.buffer.Reset()
return
}
62 changes: 62 additions & 0 deletions ircutils/sasl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,65 @@ func TestSplitResponse(t *testing.T) {
},
)
}

func TestBuffer(t *testing.T) {
b := NewSASLBuffer(1600)

// less than 400 bytes
done, output, err := b.Add("c2hpdmFyYW0Ac2hpdmFyYW0Ac2hpdmFyYW1wYXNzcGhyYXNl")
assertEqual(done, true)
assertEqual(output, []byte("shivaram\x00shivaram\x00shivarampassphrase"))
assertEqual(err, nil)

// 400 bytes exactly plus a continuation +:
done, output, err = b.Add("c2xpbmdhbW4Ac2xpbmdhbW4AMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMTExMQ==")
assertEqual(done, false)
assertEqual(output, []byte(nil))
assertEqual(err, nil)
done, output, err = b.Add("+")
assertEqual(done, true)
assertEqual(output, []byte("slingamn\x00slingamn\x001111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"))
assertEqual(err, nil)

// over 400 bytes
done, output, err = b.Add("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==")
assertEqual(done, true)
assertEqual(output, []byte(nil))
assertEqual(err, ErrSASLTooLong)

// a single +
done, output, err = b.Add("+")
assertEqual(done, true)
assertEqual(len(output), 0)
assertEqual(err, nil)

// length limit
for i := 0; i < 4; i++ {
done, output, err = b.Add("AGVtZXJzaW9uAEVzdCB1dCBiZWF0YWUgb21uaXMgaXBzYW0uIFF1aXMgZnVnaWF0IGRlbGVuaXRpIHRvdGFtIHF1aS4gSXBzdW0gcXVhbSBhIGRvbG9ydW0gdGVtcG9yYSB2ZWxpdCBsYWJvcnVtIG9kaXQuIEV0IHNhZXBlIHZvbHVwdGF0ZSBzZWQgY3VtcXVlIHZlbC4gVm9sdXB0YXMgc2ludCBhYiBwYXJpYXR1ciBsaWJlcm8gdmVyaXRhdGlzIGNvcnJ1cHRpLiBWZXJvIGl1cmUgb21uaXMgdWxsYW0uIFZlcm8gYmVhdGFlIGRvbG9yZXMgZmFjZXJlIGZ1Z2lhdCBpcHNhbS4gRWEgZXN0IHBhcmlhdHVyIG1pbmltYSBub2JpcyBz")
assertEqual(done, false)
assertEqual(output, []byte(nil))
assertEqual(err, nil)
}
done, output, err = b.Add("AA==")
assertEqual(done, true)
assertEqual(output, []byte(nil))
assertEqual(err, ErrSASLLimitExceeded)

// invalid base64
done, output, err = b.Add("!!!")
assertEqual(done, true)
assertEqual(len(output), 0)
if err == nil {
t.Errorf("expected non-nil error from invalid base64")
}

// two lines
done, output, err = b.Add("AGVtZXJzaW9uAEVzdCB1dCBiZWF0YWUgb21uaXMgaXBzYW0uIFF1aXMgZnVnaWF0IGRlbGVuaXRpIHRvdGFtIHF1aS4gSXBzdW0gcXVhbSBhIGRvbG9ydW0gdGVtcG9yYSB2ZWxpdCBsYWJvcnVtIG9kaXQuIEV0IHNhZXBlIHZvbHVwdGF0ZSBzZWQgY3VtcXVlIHZlbC4gVm9sdXB0YXMgc2ludCBhYiBwYXJpYXR1ciBsaWJlcm8gdmVyaXRhdGlzIGNvcnJ1cHRpLiBWZXJvIGl1cmUgb21uaXMgdWxsYW0uIFZlcm8gYmVhdGFlIGRvbG9yZXMgZmFjZXJlIGZ1Z2lhdCBpcHNhbS4gRWEgZXN0IHBhcmlhdHVyIG1pbmltYSBub2JpcyBz")
assertEqual(done, false)
assertEqual(output, []byte(nil))
assertEqual(err, nil)
done, output, err = b.Add("dW50IGF1dCB1dC4gRG9sb3JlcyB1dCBsYXVkYW50aXVtIG1haW9yZXMgdGVtcG9yaWJ1cyB2b2x1cHRhdGVzLiBSZWljaWVuZGlzIGltcGVkaXQgb21uaXMgZXQgdW5kZSBkZWxlY3R1cyBxdWFzIGFiLiBRdWFlIGVsaWdlbmRpIG5lY2Vzc2l0YXRpYnVzIGRvbG9yaWJ1cyBtb2xlc3RpYXMgdGVtcG9yYSBtYWduYW0gYXNzdW1lbmRhLg==")
assertEqual(done, true)
assertEqual(output, []byte("\x00emersion\x00Est ut beatae omnis ipsam. Quis fugiat deleniti totam qui. Ipsum quam a dolorum tempora velit laborum odit. Et saepe voluptate sed cumque vel. Voluptas sint ab pariatur libero veritatis corrupti. Vero iure omnis ullam. Vero beatae dolores facere fugiat ipsam. Ea est pariatur minima nobis sunt aut ut. Dolores ut laudantium maiores temporibus voluptates. Reiciendis impedit omnis et unde delectus quas ab. Quae eligendi necessitatibus doloribus molestias tempora magnam assumenda."))
assertEqual(err, nil)
}

0 comments on commit 5c25eee

Please sign in to comment.