Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support multiple extract token key #4328

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions rest/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,34 @@ type (
PrivateKeys []PrivateKeyConf
}

// JWTConf Key and expiration time configuration required for JWT authentication
JWTConf struct {
AccessSecret string
AccessExpire int64
// TokenLookup is a slice in the form of "<source>:<name>" that is used
// to extract token from the request.
// Optional.
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "form:<name>"
TokenLookup []string `json:",optional"`
}

// A JWTTransConf is a jwtTrans config.
JWTTransConf struct {
Secret string
PrevSecret string
// TokenLookup is a slice in the form of "<source>:<name>" that is used
// to extract token from the request.
// Optional.
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "form:<name>"
TokenLookup []string `json:",optional"`
}

// A RestConf is a http service config.
// Why not name it as Conf, because we need to consider usage like:
// type Config struct {
Expand Down
17 changes: 10 additions & 7 deletions rest/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,17 @@ func (ng *engine) addRoutes(r featuredRoutes) {
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
verifier func(chain.Chain) chain.Chain) chain.Chain {
if fr.jwt.enabled {
if len(fr.jwt.prevSecret) == 0 {
chn = chn.Append(handler.Authorize(fr.jwt.secret,
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
} else {
chn = chn.Append(handler.Authorize(fr.jwt.secret,
handler.WithPrevSecret(fr.jwt.prevSecret),
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
authOpts := []handler.AuthorizeOption{
handler.WithUnauthorizedCallback(ng.unauthorizedCallback),
}
if len(fr.jwt.prevSecret) > 0 {
authOpts = append(authOpts, handler.WithPrevSecret(fr.jwt.prevSecret))
}
if len(fr.jwt.tokenLookups) > 0 {
authOpts = append(authOpts, handler.WithTokenLookups(fr.jwt.tokenLookups))
}

chn = chn.Append(handler.Authorize(fr.jwt.secret, authOpts...))
}

return verifier(chn)
Expand Down
3 changes: 2 additions & 1 deletion rest/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ Verbose: true
{
priority: true,
jwt: jwtSetting{
enabled: true,
enabled: true,
tokenLookups: []string{"header:Token", "query:Token", "form:Token"},
},
signature: signatureSetting{},
routes: []Route{{
Expand Down
19 changes: 16 additions & 3 deletions rest/handler/authhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ var (
type (
// An AuthorizeOptions is authorize options.
AuthorizeOptions struct {
PrevSecret string
Callback UnauthorizedCallback
PrevSecret string
Callback UnauthorizedCallback
TokenLookups []string
}

// UnauthorizedCallback defines the method of unauthorized callback.
Expand All @@ -48,7 +49,12 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
opt(&authOpts)
}

parser := token.NewTokenParser()
var parseOpts []token.ParseOption
if len(authOpts.TokenLookups) > 0 {
parseOpts = append(parseOpts, token.WithExtractor(authOpts.TokenLookups))
}

parser := token.NewTokenParser(parseOpts...)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tok, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
Expand Down Expand Up @@ -97,6 +103,13 @@ func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption {
}
}

// WithTokenLookups used to set the source of the token
func WithTokenLookups(tokenLookups []string) AuthorizeOption {
return func(opts *AuthorizeOptions) {
opts.TokenLookups = tokenLookups
}
}

func detailAuthLog(r *http.Request, reason string) {
// discard dump error, only for debug purpose
details, _ := httputil.DumpRequest(r, true)
Expand Down
26 changes: 26 additions & 0 deletions rest/handler/authhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,32 @@ func TestAuthHandler(t *testing.T) {
assert.Equal(t, "content", resp.Body.String())
}

func TestAuthHandler_WithTokenLookups(t *testing.T) {
const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
token, err := buildToken(key, map[string]any{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("X-Token", token)
handler := Authorize(key, WithTokenLookups([]string{"header:X-Token"}))(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
assert.Nil(t, err)

flusher, ok := w.(http.Flusher)
assert.True(t, ok)
flusher.Flush()
}))

resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "test", resp.Header().Get("X-Test"))
assert.Equal(t, "content", resp.Body.String())
}

func TestAuthHandlerWithPrevSecret(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
Expand Down
17 changes: 10 additions & 7 deletions rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,24 +191,27 @@ func WithFileServer(path string, fs http.FileSystem) RunOption {
}

// WithJwt returns a func to enable jwt authentication in given route.
func WithJwt(secret string) RouteOption {
func WithJwt(jwt JWTConf) RouteOption {
return func(r *featuredRoutes) {
validateSecret(secret)
validateSecret(jwt.AccessSecret)
r.jwt.enabled = true
r.jwt.secret = secret
r.jwt.secret = jwt.AccessSecret
r.jwt.tokenLookups = jwt.TokenLookup
}
}

// WithJwtTransition returns a func to enable jwt authentication as well as jwt secret transition.
// Which means old and new jwt secrets work together for a period.
func WithJwtTransition(secret, prevSecret string) RouteOption {
func WithJwtTransition(jwt JWTTransConf) RouteOption {
return func(r *featuredRoutes) {
// why not validate prevSecret, because prevSecret is an already used one,
// even it not meet our requirement, we still need to allow the transition.
validateSecret(secret)
validateSecret(jwt.Secret)
r.jwt.enabled = true
r.jwt.secret = secret
r.jwt.prevSecret = prevSecret
r.jwt.secret = jwt.Secret
r.jwt.prevSecret = jwt.PrevSecret
r.jwt.tokenLookups = jwt.TokenLookup

}
}

Expand Down
4 changes: 2 additions & 2 deletions rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ Port: 0
Method: http.MethodGet,
Path: "/",
Handler: nil,
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
WithJwtTransition("preivous", "thenewone"))
}, WithJwt(JWTConf{AccessSecret: "thesecret"}), WithSignature(SignatureConf{}),
WithJwtTransition(JWTTransConf{Secret: "preivous", PrevSecret: "thenewone"}))

func() {
defer func() {
Expand Down
50 changes: 45 additions & 5 deletions rest/token/tokenparser.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
package token

import (
"fmt"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/golang-jwt/jwt/v4/request"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/timex"
)

const claimHistoryResetDuration = time.Hour * 24
const (
claimHistoryResetDuration = time.Hour * 24

jwtLookupHeader = "header"
jwtLookupQuery = "query"
jwtLookupForm = "form"
)

type (
// ParseOption defines the method to customize a TokenParser.
Expand All @@ -22,6 +31,7 @@
resetTime time.Duration
resetDuration time.Duration
history sync.Map
extractor request.MultiExtractor
}
)

Expand All @@ -30,6 +40,7 @@
parser := &TokenParser{
resetTime: timex.Now(),
resetDuration: claimHistoryResetDuration,
extractor: request.MultiExtractor{request.AuthorizationHeaderExtractor},
}

for _, opt := range opts {
Expand Down Expand Up @@ -79,10 +90,11 @@
}

func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) {
return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor,
func(token *jwt.Token) (any, error) {
return []byte(secret), nil
}, request.WithParser(newParser()))
keyFunc := func(token *jwt.Token) (any, error) {
return []byte(secret), nil
}

return request.ParseFromRequest(r, tp.extractor, keyFunc, request.WithParser(newParser()))
}

func (tp *TokenParser) incrementCount(secret string) {
Expand Down Expand Up @@ -119,6 +131,34 @@
}
}

// WithExtractor used to configure the token extraction method of the TokenParser.
func WithExtractor(tokenLookups []string) ParseOption {
return func(parser *TokenParser) {
var headerNames, argumentNames []string
for _, lookup := range tokenLookups {
parts := strings.Split(strings.TrimSpace(lookup), ":")
if len(parts) < 2 {
logx.Must(fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", lookup))

Check warning on line 141 in rest/token/tokenparser.go

View check run for this annotation

Codecov / codecov/patch

rest/token/tokenparser.go#L141

Added line #L141 was not covered by tests
}

source := strings.TrimSpace(parts[0])
name := strings.TrimSpace(parts[1])
switch source {
case jwtLookupHeader:
headerNames = append(headerNames, name)
case jwtLookupQuery, jwtLookupForm:
argumentNames = append(argumentNames, name)
}
}

parser.extractor = request.MultiExtractor{
request.HeaderExtractor(headerNames),
request.ArgumentExtractor(argumentNames),
request.AuthorizationHeaderExtractor,
}
}
}

func newParser() *jwt.Parser {
return jwt.NewParser(jwt.WithJSONNumber())
}
78 changes: 78 additions & 0 deletions rest/token/tokenparser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package token
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -45,6 +47,82 @@ func TestTokenParser(t *testing.T) {
}
}

func TestTokenParser_CustomHeader(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
token, err := buildToken(key, map[string]any{"key": "value"}, 3600)
assert.Nil(t, err)
req.Header.Set("Token", token)

parser := NewTokenParser(WithExtractor([]string{"header:Token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
parser.resetTime = timex.Now() - time.Hour
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
}

func TestTokenParser_URLArgument(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
token, err := buildToken(key, map[string]any{"key": "value"}, 3600)
assert.Nil(t, err)

req := httptest.NewRequest(http.MethodGet, "http://localhost?token="+token, http.NoBody)

parser := NewTokenParser(WithExtractor([]string{"query:token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
parser.resetTime = timex.Now() - time.Hour
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
}

func TestTokenParser_FormArgument(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
token, err := buildToken(key, map[string]any{"key": "value"}, 3600)
assert.Nil(t, err)

// create form data
form := url.Values{}
form.Add("form_token", token)

// Using httptest.NewRequest to create a fake POST request
req := httptest.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

parser := NewTokenParser(WithExtractor([]string{"form:form_token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
parser.resetTime = timex.Now() - time.Hour
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])

}

func TestTokenParser_Expired(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
Expand Down
7 changes: 4 additions & 3 deletions rest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ type (
RouteOption func(r *featuredRoutes)

jwtSetting struct {
enabled bool
secret string
prevSecret string
enabled bool
secret string
prevSecret string
tokenLookups []string
}

signatureSetting struct {
Expand Down
Loading