-
Notifications
You must be signed in to change notification settings - Fork 1
/
token_validator.go
154 lines (131 loc) · 4.64 KB
/
token_validator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package auth
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/MicahParks/keyfunc"
"github.com/buger/jsonparser"
"github.com/golang-jwt/jwt/v4"
log "github.com/sirupsen/logrus"
"strings"
)
// TokenValidatorInterface is the set of operations needed to validate a token
// and match its claims against rules. TokenValidator implements it.
type TokenValidatorInterface interface {
// RetrieveClaimsFromToken validates tokenInput and returns the claims it
// contains, or an error if the token is rejected.
RetrieveClaimsFromToken(ctx context.Context, tokenInput string) (*Claims, error)
// MatchClaims reports whether tokenClaims satisfy the JSON-encoded
// ruleClaims.
MatchClaims(ctx context.Context, tokenClaims *Claims, ruleClaims []byte) bool
// ValidateClaimsForRule looks for a rule in rules that applies to
// requestedRole and is satisfied by tokenClaims.
ValidateClaimsForRule(ctx context.Context, tokenClaims *Claims, requestedRole string, rules []Rule) (*Rule, error)
}
// NewTokenValidator builds a TokenValidator whose signing keys are fetched
// from jwksURL.
//
// boundIssuer and boundAudience are stored as-is; an empty string disables the
// corresponding check when tokens are validated later.
//
// If the JWKS cannot be fetched, the failure is reported through log.Fatalf
// and no validator is returned to the caller.
func NewTokenValidator(jwksURL, boundIssuer, boundAudience string) *TokenValidator {
	log.Debugf("Using %s for JWK retrival", jwksURL)

	keySet, err := keyfunc.Get(jwksURL, keyfunc.Options{})
	if err != nil {
		log.Fatalf("Failed to get the JWKS from the given URL.\nError: %v", err)
	}

	return &TokenValidator{
		jwks:          keySet,
		boundIssuer:   boundIssuer,
		boundAudience: boundAudience,
	}
}
// TokenValidator implements TokenValidatorInterface, validating JWTs against
// a key set obtained from a remote JWKS endpoint.
type TokenValidator struct {
// jwks is the key set whose Keyfunc is handed to the JWT parser.
jwks *keyfunc.JWKS
// boundIssuer, when non-empty, is the issuer every token must carry.
boundIssuer string
// boundAudience, when non-empty, is an audience every token must include.
boundAudience string
}
// RetrieveClaimsFromToken verifies tokenInput against the validator's JWKS and
// returns the claims it carries.
//
// The token is parsed twice with the same key function: once into
// jwt.MapClaims so that every claim (including custom ones) can be re-encoded
// as JSON for rule matching, and once into jwt.RegisteredClaims so that the
// standard claims are available in typed form. When boundIssuer or
// boundAudience is non-empty, the matching registered claim is required.
//
// An error is returned if parsing or verification fails, if a bound issuer or
// audience check fails, or if the claims cannot be re-encoded as JSON.
func (t *TokenValidator) RetrieveClaimsFromToken(ctx context.Context, tokenInput string) (*Claims, error) {
	allClaims, err := jwt.ParseWithClaims(tokenInput, &jwt.MapClaims{}, t.jwks.Keyfunc)
	if err != nil {
		return nil, err
	}

	token, err := jwt.ParseWithClaims(tokenInput, &jwt.RegisteredClaims{}, t.jwks.Keyfunc)
	if err != nil {
		return nil, err
	}
	if !token.Valid {
		return nil, fmt.Errorf("token invalid")
	}

	// Assert once, and report a mismatch as an error instead of panicking.
	registered, ok := token.Claims.(*jwt.RegisteredClaims)
	if !ok {
		return nil, fmt.Errorf("unexpected claims type %T", token.Claims)
	}

	if t.boundIssuer != "" && !registered.VerifyIssuer(t.boundIssuer, true) {
		return nil, fmt.Errorf("bound issuer %s expected", t.boundIssuer)
	}
	if t.boundAudience != "" && !registered.VerifyAudience(t.boundAudience, true) {
		return nil, fmt.Errorf("bound audience %s expected", t.boundAudience)
	}

	// No segment-count check is needed here: the parser has already rejected
	// any token that is not made of exactly three dot-separated parts.
	claimsJSON, err := json.Marshal(allClaims.Claims)
	if err != nil {
		return nil, fmt.Errorf("error decoding claims section: %w", err)
	}

	// Log the decoded claims rather than the raw token: the raw form includes
	// the signature and is a replayable credential that must not reach logs.
	Logger(ctx).Debugf("Token claims: %s", claimsJSON)

	return &Claims{
		ClaimsJSON:       claimsJSON,
		RegisteredClaims: registered,
	}, nil
}
// MatchClaimsInternal reports whether the JSON object in claims satisfies the
// JSON object in rules, working directly on the encoded bytes.
//
// Every key of rules is looked up in claims. A key matches when the claim
// exists with the same JSON type and, for strings, booleans and numbers, the
// same raw bytes; nested objects are compared recursively. A missing claim or
// a type mismatch yields false without an error. Array-typed rule values that
// reach the comparison step are reported as an error because they are not
// supported, as are null/unknown types.
//
// On any error the boolean result is false.
func MatchClaimsInternal(ctx context.Context, claims []byte, rules []byte) (bool, error) {
	allMatch := true

	// checkRule is invoked once per top-level key of rules. It only clears
	// allMatch; iteration continues so that later keys can still surface errors.
	checkRule := func(ruleKey []byte, ruleValue []byte, ruleType jsonparser.ValueType, _ int) error {
		claimValue, claimType, _, getErr := jsonparser.Get(claims, string(ruleKey))
		// A lookup error for an absent key is expected and handled as a
		// mismatch below; any other lookup error is a parsing failure.
		if getErr != nil && claimType != jsonparser.NotExist {
			return getErr
		}

		// The claim must have the same JSON type as the rule value.
		if claimType != ruleType {
			allMatch = false
			return nil
		}

		switch ruleType {
		case jsonparser.Array:
			return fmt.Errorf("handling for arraytypes not implemented yet")
		case jsonparser.NotExist, jsonparser.Unknown, jsonparser.Null:
			return fmt.Errorf("iterated over a key with type %s. This should not happen", ruleType.String())
		case jsonparser.Object:
			// Descend into the nested object with the nested rule set.
			nestedMatch, nestedErr := MatchClaimsInternal(ctx, claimValue, ruleValue)
			if nestedErr != nil {
				return nestedErr
			}
			if !nestedMatch {
				allMatch = false
			}
		case jsonparser.String, jsonparser.Boolean, jsonparser.Number:
			if !bytes.Equal(claimValue, ruleValue) {
				allMatch = false
			}
		}
		return nil
	}

	if err := jsonparser.ObjectEach(rules, checkRule); err != nil {
		return false, err
	}
	return allMatch, nil
}
// MatchClaims reports whether every claim required by ruleClaims is present
// in tokenClaims with a matching value.
//
// Matching errors are logged as warnings and treated as "no match".
func (t *TokenValidator) MatchClaims(ctx context.Context, tokenClaims *Claims, ruleClaims []byte) bool {
	Logger(ctx).Debugf("Rules JSON: %s", ruleClaims)

	matched, err := MatchClaimsInternal(ctx, tokenClaims.ClaimsJSON, ruleClaims)
	if err != nil {
		Logger(ctx).Warnf("error matching claims: %s", err)
		return false
	}
	return matched
}
// ValidateClaimsForRule returns the first rule in rules whose Role equals
// requestedRole and whose ClaimValues are satisfied by tokenClaims.
//
// The returned pointer refers to a copy of the matching rule, not to the
// element inside rules. When no rule qualifies, both results are nil.
func (t *TokenValidator) ValidateClaimsForRule(ctx context.Context, tokenClaims *Claims, requestedRole string, rules []Rule) (*Rule, error) {
	for i := range rules {
		candidate := rules[i]
		// Skip rules for other roles before doing the claim comparison.
		if strings.Compare(candidate.Role, requestedRole) != 0 {
			continue
		}
		if t.MatchClaims(ctx, tokenClaims, candidate.ClaimValues) {
			return &candidate, nil
		}
	}
	return nil, nil
}