Unverified Commit 80849d79 authored by James Rasell's avatar James Rasell
Browse files

core: add ACL token expiry state, struct, and RPC handling.

The ACL token state schema has been updated to utilise two new
indexes which track expiration of tokens that are configured with
an expiration TTL or time. A new state function allows listing
ACL expired tokens which will be used by internal garbage
collection.

The ACL endpoint has been modified so that all validation happens
within a single function call. This is easier to understand and
see at a glance. The ACL token validation now also includes logic
for expiry ttl and times. The ACL endpoint upsert tests have been
condensed into a single, table driven test.

There is a new token canonicalize which provides a single place
for token canonicalization, rather than logic spread in the RPC
handler.
parent 08845cef
Showing with 1091 additions and 203 deletions
+1091 -203
......@@ -130,6 +130,12 @@ func TimeToPtr(t time.Duration) *time.Duration {
return &t
}
// TimeTimeToPtr returns the pointer to a time.Time. Its name is unfortunate,
// but needed due to the TimeToPtr function.
func TimeTimeToPtr(t time.Time) *time.Time {
return &t
}
// CompareTimePtrs return true if a is the same as b.
func CompareTimePtrs(a, b *time.Duration) bool {
if a == nil || b == nil {
......
......@@ -468,7 +468,7 @@ func (a *ACL) UpsertTokens(args *structs.ACLTokenUpsertRequest, reply *structs.A
// Validate non-zero set of tokens
if len(args.Tokens) == 0 {
return structs.NewErrRPCCoded(400, "must specify as least one token")
return structs.NewErrRPCCoded(http.StatusBadRequest, "must specify as least one token")
}
// Force the request to the authoritative region if we are creating global tokens
......@@ -486,14 +486,15 @@ func (a *ACL) UpsertTokens(args *structs.ACLTokenUpsertRequest, reply *structs.A
// the entire request as a single batch.
if hasGlobal {
if !allGlobal {
return structs.NewErrRPCCoded(400, "cannot upsert mixed global and non-global tokens")
return structs.NewErrRPCCoded(http.StatusBadRequest,
"cannot upsert mixed global and non-global tokens")
}
// Force the request to the authoritative region if it has global
args.Region = a.srv.config.AuthoritativeRegion
}
if done, err := a.srv.forward("ACL.UpsertTokens", args, args, reply); done {
if done, err := a.srv.forward(structs.ACLUpsertTokensRPCMethod, args, args, reply); done {
return err
}
defer metrics.MeasureSince([]string{"nomad", "acl", "upsert_tokens"}, time.Now())
......@@ -505,38 +506,41 @@ func (a *ACL) UpsertTokens(args *structs.ACLTokenUpsertRequest, reply *structs.A
return structs.ErrPermissionDenied
}
// Snapshot the state
state, err := a.srv.State().Snapshot()
// Snapshot the state so we can perform lookups against the accessor ID if
// needed. Do it here, so we only need to do this once no matter how many
// tokens we are upserting.
stateSnapshot, err := a.srv.State().Snapshot()
if err != nil {
return err
}
// Validate each token
for idx, token := range args.Tokens {
if err := token.Validate(); err != nil {
return structs.NewErrRPCCodedf(400, "token %d invalid: %v", idx, err)
}
// Generate an accessor and secret ID if new
if token.AccessorID == "" {
token.AccessorID = uuid.Generate()
token.SecretID = uuid.Generate()
token.CreateTime = time.Now().UTC()
// Store any existing token found, so we can perform the correct update
// validation.
var existingToken *structs.ACLToken
} else {
// Verify the token exists
out, err := state.ACLTokenByAccessorID(nil, token.AccessorID)
// If the token is being updated, perform a lookup so can can validate
// the new changes against the old.
if token.AccessorID != "" {
out, err := stateSnapshot.ACLTokenByAccessorID(nil, token.AccessorID)
if err != nil {
return structs.NewErrRPCCodedf(400, "token lookup failed: %v", err)
return structs.NewErrRPCCodedf(http.StatusBadRequest, "token lookup failed: %v", err)
}
if out == nil {
return structs.NewErrRPCCodedf(404, "cannot find token %s", token.AccessorID)
return structs.NewErrRPCCodedf(http.StatusBadRequest, "cannot find token %s", token.AccessorID)
}
existingToken = out
}
// Cannot toggle the "Global" mode
if token.Global != out.Global {
return structs.NewErrRPCCodedf(400, "cannot toggle global mode of %s", token.AccessorID)
}
// Canonicalize sets information needed by the validation function, so
// this order must be maintained.
token.Canonicalize()
if err := token.Validate(a.srv.config.ACLTokenMinExpirationTTL,
a.srv.config.ACLTokenMaxExpirationTTL, existingToken); err != nil {
return structs.NewErrRPCCodedf(http.StatusBadRequest, "token %d invalid: %v", idx, err)
}
// Compute the token hash
......@@ -549,14 +553,14 @@ func (a *ACL) UpsertTokens(args *structs.ACLTokenUpsertRequest, reply *structs.A
return err
}
// Populate the response. We do a lookup against the state to
// pickup the proper create / modify times.
state, err = a.srv.State().Snapshot()
// Populate the response. We do a lookup against the state to pick up the
// proper create / modify times.
stateSnapshot, err = a.srv.State().Snapshot()
if err != nil {
return err
}
for _, token := range args.Tokens {
out, err := state.ACLTokenByAccessorID(nil, token.AccessorID)
out, err := stateSnapshot.ACLTokenByAccessorID(nil, token.AccessorID)
if err != nil {
return structs.NewErrRPCCodedf(400, "token lookup failed: %v", err)
}
......
......@@ -1453,85 +1453,175 @@ func TestACLEndpoint_Bootstrap_Reset(t *testing.T) {
func TestACLEndpoint_UpsertTokens(t *testing.T) {
ci.Parallel(t)
s1, root, cleanupS1 := TestACLServer(t, nil)
defer cleanupS1()
codec := rpcClient(t, s1)
testutil.WaitForLeader(t, s1.RPC)
// Create the register request
p1 := mock.ACLToken()
p1.AccessorID = "" // Blank to create
// Lookup the tokens
req := &structs.ACLTokenUpsertRequest{
Tokens: []*structs.ACLToken{p1},
WriteRequest: structs.WriteRequest{
Region: "global",
AuthToken: root.SecretID,
// Each sub-test uses the same server to avoid creating a new one for each
// test. This means some care has to be taken with resource naming, but
// does avoid lots of calls to systems such as freeport.
testServer, rootACLToken, testServerCleanup := TestACLServer(t, nil)
defer testServerCleanup()
codec := rpcClient(t, testServer)
testutil.WaitForLeader(t, testServer.RPC)
testCases := []struct {
name string
testFn func(testServer *Server, aclToken *structs.ACLToken)
}{
{
name: "valid client token",
testFn: func(testServer *Server, aclToken *structs.ACLToken) {
// Create the register request with a mocked token. We must set
// an empty accessorID, otherwise Nomad treats this as an
// update request.
p1 := mock.ACLToken()
p1.AccessorID = ""
req := &structs.ACLTokenUpsertRequest{
Tokens: []*structs.ACLToken{p1},
WriteRequest: structs.WriteRequest{
Region: DefaultRegion,
AuthToken: aclToken.SecretID,
},
}
var resp structs.ACLTokenUpsertResponse
require.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertTokensRPCMethod, req, &resp))
require.Greater(t, resp.Index, uint64(0))
// Get the token out from the response.
created := resp.Tokens[0]
require.NotEqual(t, "", created.AccessorID)
require.NotEqual(t, "", created.SecretID)
require.NotEqual(t, time.Time{}, created.CreateTime)
require.Equal(t, p1.Type, created.Type)
require.Equal(t, p1.Policies, created.Policies)
require.Equal(t, p1.Name, created.Name)
// Check we created the token.
out, err := testServer.fsm.State().ACLTokenByAccessorID(nil, created.AccessorID)
require.Nil(t, err)
require.Equal(t, created, out)
// Update the token type and policy list so we can try updating
// it.
req.Tokens[0] = created
created.Type = "management"
created.Policies = nil
// Track the first upsert index, so we can test the next
// response against this and perform the update.
originalIndex := resp.Index
require.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertTokensRPCMethod, req, &resp))
assert.Greater(t, resp.Index, originalIndex)
// Read the token from state and perform an equality check to
// ensure everything matches as we expect.
out, err = testServer.fsm.State().ACLTokenByAccessorID(nil, created.AccessorID)
assert.Nil(t, err)
assert.Equal(t, created, out)
},
},
}
var resp structs.ACLTokenUpsertResponse
if err := msgpackrpc.CallWithCodec(codec, "ACL.UpsertTokens", req, &resp); err != nil {
t.Fatalf("err: %v", err)
}
assert.NotEqual(t, uint64(0), resp.Index)
// Get the token out from the response
created := resp.Tokens[0]
assert.NotEqual(t, "", created.AccessorID)
assert.NotEqual(t, "", created.SecretID)
assert.NotEqual(t, time.Time{}, created.CreateTime)
assert.Equal(t, p1.Type, created.Type)
assert.Equal(t, p1.Policies, created.Policies)
assert.Equal(t, p1.Name, created.Name)
// Check we created the token
out, err := s1.fsm.State().ACLTokenByAccessorID(nil, created.AccessorID)
assert.Nil(t, err)
assert.Equal(t, created, out)
// Update the token type
req.Tokens[0] = created
created.Type = "management"
created.Policies = nil
// Upsert again
if err := msgpackrpc.CallWithCodec(codec, "ACL.UpsertTokens", req, &resp); err != nil {
t.Fatalf("err: %v", err)
}
assert.NotEqual(t, uint64(0), resp.Index)
// Check we modified the token
out, err = s1.fsm.State().ACLTokenByAccessorID(nil, created.AccessorID)
assert.Nil(t, err)
assert.Equal(t, created, out)
}
func TestACLEndpoint_UpsertTokens_Invalid(t *testing.T) {
ci.Parallel(t)
s1, root, cleanupS1 := TestACLServer(t, nil)
defer cleanupS1()
codec := rpcClient(t, s1)
testutil.WaitForLeader(t, s1.RPC)
// Create the register request
p1 := mock.ACLToken()
p1.Type = "blah blah"
// Lookup the tokens
req := &structs.ACLTokenUpsertRequest{
Tokens: []*structs.ACLToken{p1},
WriteRequest: structs.WriteRequest{
Region: "global",
AuthToken: root.SecretID,
{
name: "valid management token with expiration",
testFn: func(testServer *Server, aclToken *structs.ACLToken) {
// Create our RPC request object which includes a management
// token with a TTL.
req := &structs.ACLTokenUpsertRequest{
Tokens: []*structs.ACLToken{
{
Name: "my-management-token-" + uuid.Generate(),
Type: structs.ACLManagementToken,
ExpirationTTL: 10 * time.Minute,
},
},
WriteRequest: structs.WriteRequest{
Region: DefaultRegion,
AuthToken: aclToken.SecretID,
},
}
// Send the RPC request and ensure the expiration time is as
// expected.
var resp structs.ACLTokenUpsertResponse
require.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertTokensRPCMethod, req, &resp))
require.Equal(t, 10*time.Minute, resp.Tokens[0].ExpirationTime.Sub(resp.Tokens[0].CreateTime))
},
},
{
name: "valid client token with expiration",
testFn: func(testServer *Server, aclToken *structs.ACLToken) {
// Create an ACL policy so this can be associated to our client
// token.
policyReq := &structs.ACLPolicyUpsertRequest{
Policies: []*structs.ACLPolicy{mock.ACLPolicy()},
WriteRequest: structs.WriteRequest{
Region: DefaultRegion,
AuthToken: aclToken.SecretID,
},
}
var policyResp structs.GenericResponse
require.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertPoliciesRPCMethod, policyReq, &policyResp))
// Create our RPC request object which includes a client token
// with a TTL that is associated to policies above.
tokenReq := &structs.ACLTokenUpsertRequest{
Tokens: []*structs.ACLToken{
{
Name: "my-client-token-" + uuid.Generate(),
Type: structs.ACLClientToken,
Policies: []string{policyReq.Policies[0].Name},
ExpirationTTL: 10 * time.Minute,
},
},
WriteRequest: structs.WriteRequest{
Region: DefaultRegion,
AuthToken: aclToken.SecretID,
},
}
// Send the RPC request and ensure the expiration time is as
// expected.
var tokenResp structs.ACLTokenUpsertResponse
require.NoError(t, msgpackrpc.CallWithCodec(codec, structs.ACLUpsertTokensRPCMethod, tokenReq, &tokenResp))
require.Equal(t, 10*time.Minute, tokenResp.Tokens[0].ExpirationTime.Sub(tokenResp.Tokens[0].CreateTime))
},
},
{
name: "invalid token type",
testFn: func(testServer *Server, aclToken *structs.ACLToken) {
// Create our RPC request object which includes a token with an
// unknown type. This allows us to ensure the RPC handler calls
// the validation func.
tokenReq := &structs.ACLTokenUpsertRequest{
Tokens: []*structs.ACLToken{
{
Name: "my-blah-token-" + uuid.Generate(),
Type: "blah",
},
},
WriteRequest: structs.WriteRequest{
Region: DefaultRegion,
AuthToken: aclToken.SecretID,
},
}
// Send the RPC request and ensure the expiration time is as
// expected.
var tokenResp structs.ACLTokenUpsertResponse
err := msgpackrpc.CallWithCodec(codec, structs.ACLUpsertTokensRPCMethod, tokenReq, &tokenResp)
require.ErrorContains(t, err, "token type must be client or management")
require.Empty(t, tokenResp.Tokens)
},
},
}
var resp structs.GenericResponse
err := msgpackrpc.CallWithCodec(codec, "ACL.UpsertTokens", req, &resp)
assert.NotNil(t, err)
if !strings.Contains(err.Error(), "client or management") {
t.Fatalf("bad: %s", err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.testFn(testServer, rootACLToken)
})
}
}
......
package indexer
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"time"
"github.com/hashicorp/go-memdb"
)
var (
// Ensure the required memdb interfaces are met at compile time.
_ memdb.Indexer = SingleIndexer{}
_ memdb.SingleIndexer = SingleIndexer{}
)
// SingleIndexer implements both memdb.Indexer and memdb.SingleIndexer. It may
// be used in a memdb.IndexSchema to specify functions that generate the index
// value for memdb.Txn operations.
type SingleIndexer struct {
// readIndex is used by memdb for Txn.Get, Txn.First, and other operations
// that read data.
ReadIndex
// writeIndex is used by memdb for Txn.Insert, Txn.Delete, and other
// operations that write data to the index.
WriteIndex
}
// ReadIndex implements memdb.Indexer. It exists so that a function can be used
// to provide the interface.
//
// Unlike memdb.Indexer, a readIndex function accepts only a single argument. To
// generate an index from multiple values, use a struct type with multiple fields.
type ReadIndex func(arg interface{}) ([]byte, error)
func (f ReadIndex) FromArgs(args ...interface{}) ([]byte, error) {
if len(args) != 1 {
return nil, fmt.Errorf("index supports only a single arg")
}
return f(args[0])
}
var ErrMissingValueForIndex = fmt.Errorf("object is missing a value for this index")
// WriteIndex implements memdb.SingleIndexer. It exists so that a function
// can be used to provide this interface.
//
// Instead of a bool return value, writeIndex expects errMissingValueForIndex to
// indicate that an index could not be build for the object. It will translate
// this error into a false value to satisfy the memdb.SingleIndexer interface.
type WriteIndex func(raw interface{}) ([]byte, error)
func (f WriteIndex) FromObject(raw interface{}) (bool, []byte, error) {
v, err := f(raw)
if errors.Is(err, ErrMissingValueForIndex) {
return false, nil, nil
}
return err == nil, v, err
}
// IndexBuilder is a buffer used to construct memdb index values.
type IndexBuilder bytes.Buffer
// Bytes returns the stored IndexBuilder value as a byte array.
func (b *IndexBuilder) Bytes() []byte { return (*bytes.Buffer)(b).Bytes() }
// Time is used to write the passed time into the IndexBuilder for use as a
// memdb index value.
func (b *IndexBuilder) Time(t time.Time) {
val := t.Unix()
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(val))
(*bytes.Buffer)(b).Write(buf)
}
package indexer
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func Test_IndexBuilder_Time(t *testing.T) {
builder := &IndexBuilder{}
testTime := time.Date(1987, time.April, 13, 8, 3, 0, 0, time.UTC)
builder.Time(testTime)
require.Equal(t, []byte{0, 0, 0, 0, 32, 128, 155, 180}, builder.Bytes())
}
package indexer
import (
"fmt"
"time"
)
type TimeQuery struct {
Value time.Time
}
// IndexFromTimeQuery can be used as a memdb.Indexer query via ReadIndex and
// allows querying by time.
func IndexFromTimeQuery(arg interface{}) ([]byte, error) {
p, ok := arg.(*TimeQuery)
if !ok {
return nil, fmt.Errorf("unexpected type %T for TimeQuery index", arg)
}
// Construct the index value and return the byte array representation of
// the time value.
var b IndexBuilder
b.Time(p.Value)
return b.Bytes(), nil
}
package indexer
import (
"testing"
"time"
"github.com/hashicorp/nomad/ci"
"github.com/stretchr/testify/require"
)
func Test_IndexFromTimeQuery(t *testing.T) {
ci.Parallel(t)
testCases := []struct {
inputArg interface{}
expectedOutputBytes []byte
expectedOutputError error
name string
}{
{
inputArg: &TimeQuery{
Value: time.Date(1987, time.April, 13, 8, 3, 0, 0, time.UTC),
},
expectedOutputBytes: []byte{0x0, 0x0, 0x0, 0x0, 0x20, 0x80, 0x9b, 0xb4},
expectedOutputError: nil,
name: "generic test 1",
},
{
inputArg: &TimeQuery{
Value: time.Date(2022, time.April, 27, 14, 12, 0, 0, time.UTC),
},
expectedOutputBytes: []byte{0x0, 0x0, 0x0, 0x0, 0x62, 0x69, 0x4f, 0x30},
expectedOutputError: nil,
name: "generic test 2",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualOutput, actualError := IndexFromTimeQuery(tc.inputArg)
require.Equal(t, tc.expectedOutputError, actualError)
require.Equal(t, tc.expectedOutputBytes, actualOutput)
})
}
}
......@@ -5,7 +5,7 @@ import (
"sync"
memdb "github.com/hashicorp/go-memdb"
"github.com/hashicorp/nomad/nomad/state/indexer"
"github.com/hashicorp/nomad/nomad/structs"
)
......@@ -17,11 +17,13 @@ const (
)
const (
indexID = "id"
indexJob = "job"
indexNodeID = "node_id"
indexAllocID = "alloc_id"
indexServiceName = "service_name"
indexID = "id"
indexJob = "job"
indexNodeID = "node_id"
indexAllocID = "alloc_id"
indexServiceName = "service_name"
indexExpiresGlobal = "expires-global"
indexExpiresLocal = "expires-local"
)
var (
......@@ -816,10 +818,60 @@ func aclTokenTableSchema() *memdb.TableSchema {
Field: "Global",
},
},
indexExpiresGlobal: {
Name: indexExpiresGlobal,
AllowMissing: true,
Unique: false,
Indexer: indexer.SingleIndexer{
ReadIndex: indexer.ReadIndex(indexer.IndexFromTimeQuery),
WriteIndex: indexer.WriteIndex(indexExpiresGlobalFromACLToken),
},
},
indexExpiresLocal: {
Name: indexExpiresLocal,
AllowMissing: true,
Unique: false,
Indexer: indexer.SingleIndexer{
ReadIndex: indexer.ReadIndex(indexer.IndexFromTimeQuery),
WriteIndex: indexer.WriteIndex(indexExpiresLocalFromACLToken),
},
},
},
}
}
func indexExpiresLocalFromACLToken(raw interface{}) ([]byte, error) {
return indexExpiresFromACLToken(raw, false)
}
func indexExpiresGlobalFromACLToken(raw interface{}) ([]byte, error) {
return indexExpiresFromACLToken(raw, true)
}
// indexExpiresFromACLToken implements the indexer.WriteIndex interface and
// allows us to use an ACL tokens ExpirationTime as an index, if it is a
// non-default value. This allows for efficient lookups when trying to deal
// with removal of expired tokens from state.
func indexExpiresFromACLToken(raw interface{}, global bool) ([]byte, error) {
p, ok := raw.(*structs.ACLToken)
if !ok {
return nil, fmt.Errorf("unexpected type %T for structs.ACLToken index", raw)
}
if p.Global != global {
return nil, indexer.ErrMissingValueForIndex
}
if !p.HasExpirationTime() {
return nil, indexer.ErrMissingValueForIndex
}
if p.ExpirationTime.Unix() < 0 {
return nil, fmt.Errorf("token expiration time cannot be before the unix epoch: %s", p.ExpirationTime)
}
var b indexer.IndexBuilder
b.Time(*p.ExpirationTime)
return b.Bytes(), nil
}
// oneTimeTokenTableSchema returns the MemDB schema for the tokens table.
// This table is used to store one-time tokens for ACL tokens
func oneTimeTokenTableSchema() *memdb.TableSchema {
......
package state
import (
"fmt"
"time"
"github.com/hashicorp/nomad/nomad/structs"
)
// ACLTokensByExpired returns an array accessor IDs of expired ACL tokens.
// Their expiration is determined against the passed time.Time value.
//
// The function handles global and local tokens independently as determined by
// the global boolean argument. The number of returned IDs can be limited by
// the max integer, which is useful to limit the number of tokens we attempt to
// delete in a single transaction.
func (s *StateStore) ACLTokensByExpired(global bool, now time.Time, max int) ([]string, error) {
tnx := s.db.ReadTxn()
iter, err := tnx.Get("acl_token", expiresIndexName(global))
if err != nil {
return nil, fmt.Errorf("failed acl token listing: %v", err)
}
var (
accessorIDs []string
num int
)
for raw := iter.Next(); raw != nil; raw = iter.Next() {
token := raw.(*structs.ACLToken)
// The indexes mean if we come across an unexpired token, we can exit
// as we have found all currently expired tokens.
if !token.IsExpired(now) {
return accessorIDs, nil
}
accessorIDs = append(accessorIDs, token.AccessorID)
// Increment the counter. If this is at or above our limit, we return
// what we have so far.
num++
if num >= max {
return accessorIDs, nil
}
}
return accessorIDs, nil
}
// expiresIndexName is a helper function to identify the correct ACL token
// table expiry index to use.
func expiresIndexName(global bool) string {
if global {
return indexExpiresGlobal
}
return indexExpiresLocal
}
package state
import (
"testing"
"time"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/stretchr/testify/require"
)
func TestStateStore_ACLTokensByExpired(t *testing.T) {
ci.Parallel(t)
testState := testStateStore(t)
// This time is the threshold for all expiry calls to be based on. All
// tokens with expiry can use this as their base and use Add().
expiryTimeThreshold := time.Date(2022, time.April, 27, 14, 50, 0, 0, time.UTC)
// Generate two tokens without an expiry time. These tokens should never
// show up in calls to ACLTokensByExpired.
neverExpireLocalToken := mock.ACLToken()
neverExpireGlobalToken := mock.ACLToken()
neverExpireLocalToken.Global = true
// Upsert the tokens into state and perform a global and local read of
// the state.
err := testState.UpsertACLTokens(structs.MsgTypeTestSetup, 10, []*structs.ACLToken{
neverExpireLocalToken, neverExpireGlobalToken})
require.NoError(t, err)
ids, err := testState.ACLTokensByExpired(true, expiryTimeThreshold, 10)
require.NoError(t, err)
require.Len(t, ids, 0)
ids, err = testState.ACLTokensByExpired(false, expiryTimeThreshold, 10)
require.NoError(t, err)
require.Len(t, ids, 0)
// Generate, upsert, and test an expired local token. This token expired
// long ago and therefore before all others coming in the tests. It should
// therefore always be the first out.
expiredLocalToken := mock.ACLToken()
expiredLocalToken.ExpirationTime = helper.TimeTimeToPtr(expiryTimeThreshold.Add(-48 * time.Hour))
err = testState.UpsertACLTokens(structs.MsgTypeTestSetup, 20, []*structs.ACLToken{expiredLocalToken})
require.NoError(t, err)
ids, err = testState.ACLTokensByExpired(false, expiryTimeThreshold, 10)
require.NoError(t, err)
require.Len(t, ids, 1)
require.Equal(t, expiredLocalToken.AccessorID, ids[0])
// Generate, upsert, and test an expired global token. This token expired
// long ago and therefore before all others coming in the tests. It should
// therefore always be the first out.
expiredGlobalToken := mock.ACLToken()
expiredGlobalToken.Global = true
expiredGlobalToken.ExpirationTime = helper.TimeTimeToPtr(expiryTimeThreshold.Add(-48 * time.Hour))
err = testState.UpsertACLTokens(structs.MsgTypeTestSetup, 30, []*structs.ACLToken{expiredGlobalToken})
require.NoError(t, err)
ids, err = testState.ACLTokensByExpired(true, expiryTimeThreshold, 10)
require.NoError(t, err)
require.Len(t, ids, 1)
require.Equal(t, expiredGlobalToken.AccessorID, ids[0])
// This test function allows us to run the same test for local and global
// tokens.
testFn := func(oldID string, global bool) {
// Track all the expected expired accessor IDs including the long
// expired token.
var expiredLocalAccessorIDs []string
expiredLocalAccessorIDs = append(expiredLocalAccessorIDs, oldID)
// Generate and upsert a number of mixed expired, non-expired local tokens.
mixedLocalTokens := make([]*structs.ACLToken, 20)
for i := 0; i < 20; i++ {
mockedToken := mock.ACLToken()
mockedToken.Global = global
if i%2 == 0 {
expiredLocalAccessorIDs = append(expiredLocalAccessorIDs, mockedToken.AccessorID)
mockedToken.ExpirationTime = helper.TimeTimeToPtr(expiryTimeThreshold.Add(-24 * time.Hour))
} else {
mockedToken.ExpirationTime = helper.TimeTimeToPtr(expiryTimeThreshold.Add(24 * time.Hour))
}
mixedLocalTokens[i] = mockedToken
}
err = testState.UpsertACLTokens(structs.MsgTypeTestSetup, 40, mixedLocalTokens)
require.NoError(t, err)
// Use a max value higher than the number we have to check the full listing
// works as expected. Ensure our oldest expired token is first in the list.
ids, err = testState.ACLTokensByExpired(global, expiryTimeThreshold, 100)
require.NoError(t, err)
require.ElementsMatch(t, ids, expiredLocalAccessorIDs)
require.Equal(t, ids[0], oldID)
// Use a lower max value than the number of known expired tokens to ensure
// this is working.
ids, err = testState.ACLTokensByExpired(global, expiryTimeThreshold, 3)
require.NoError(t, err)
require.Len(t, ids, 3)
require.Equal(t, ids[0], oldID)
}
testFn(expiredLocalToken.AccessorID, false)
testFn(expiredGlobalToken.AccessorID, true)
}
func Test_expiresIndexName(t *testing.T) {
testCases := []struct {
globalInput bool
expectedOutput string
name string
}{
{
globalInput: false,
expectedOutput: indexExpiresLocal,
name: "local",
},
{
globalInput: true,
expectedOutput: indexExpiresGlobal,
name: "global",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualOutput := expiresIndexName(tc.globalInput)
require.Equal(t, tc.expectedOutput, actualOutput)
})
}
}
package structs
import (
"errors"
"fmt"
"time"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/helper/uuid"
)
const (
// ACLUpsertPoliciesRPCMethod is the RPC method for batch creating or
// modifying ACL policies.
//
// Args: ACLPolicyUpsertRequest
// Reply: GenericResponse
ACLUpsertPoliciesRPCMethod = "ACL.UpsertPolicies"
// ACLUpsertTokensRPCMethod is the RPC method for batch creating or
// modifying ACL tokens.
//
// Args: ACLTokenUpsertRequest
// Reply: ACLTokenUpsertResponse
ACLUpsertTokensRPCMethod = "ACL.UpsertTokens"
)
// Canonicalize performs basic canonicalization on the ACL token object. It is
// important for callers to understand certain fields such as AccessorID are
// set if it is empty, so copies should be taken if needed before calling this
// function.
func (a *ACLToken) Canonicalize() {
// If the accessor ID is empty, it means this is creation of a new token,
// therefore we need to generate base information.
if a.AccessorID == "" {
a.AccessorID = uuid.Generate()
a.SecretID = uuid.Generate()
a.CreateTime = time.Now().UTC()
// If the user has not set the expiration time, but has provided a TTL, we
// calculate and populate the former filed.
if a.ExpirationTime == nil && a.ExpirationTTL != 0 {
a.ExpirationTime = helper.TimeTimeToPtr(a.CreateTime.Add(a.ExpirationTTL))
}
}
}
// Validate is used to check a token for reasonableness
func (a *ACLToken) Validate(minTTL, maxTTL time.Duration, existing *ACLToken) error {
var mErr multierror.Error
// The human friendly name of an ACL token cannot exceed 256 characters.
if len(a.Name) > maxTokenNameLength {
mErr.Errors = append(mErr.Errors, errors.New("token name too long"))
}
// The type of an ACL token must be set. An ACL token of type client must
// have associated policies, whereas a management token cannot be
// associated with policies.
switch a.Type {
case ACLClientToken:
if len(a.Policies) == 0 {
mErr.Errors = append(mErr.Errors, errors.New("client token missing policies"))
}
case ACLManagementToken:
if len(a.Policies) != 0 {
mErr.Errors = append(mErr.Errors, errors.New("management token cannot be associated with policies"))
}
default:
mErr.Errors = append(mErr.Errors, errors.New("token type must be client or management"))
}
// There are different validation rules depending on whether the ACL token
// is being created or updated.
switch existing {
case nil:
if a.ExpirationTTL < 0 {
mErr.Errors = append(mErr.Errors,
fmt.Errorf("token expiration TTL '%s' should not be negative", a.ExpirationTTL))
}
if a.ExpirationTime != nil && !a.ExpirationTime.IsZero() {
if a.CreateTime.After(*a.ExpirationTime) {
mErr.Errors = append(mErr.Errors, errors.New("expiration time cannot be before create time"))
}
// Create a time duration which details the time-til-expiry, so we can
// check this against the regions max and min values.
expiresIn := a.ExpirationTime.Sub(a.CreateTime)
if expiresIn > maxTTL {
mErr.Errors = append(mErr.Errors,
fmt.Errorf("expiration time cannot be more than %s in the future (was %s)",
maxTTL, expiresIn))
} else if expiresIn < minTTL {
mErr.Errors = append(mErr.Errors,
fmt.Errorf("expiration time cannot be less than %s in the future (was %s)",
minTTL, expiresIn))
}
}
default:
if existing.Global != a.Global {
mErr.Errors = append(mErr.Errors, errors.New("cannot toggle global mode"))
}
if existing.ExpirationTTL != a.ExpirationTTL {
mErr.Errors = append(mErr.Errors, errors.New("cannot update expiration TTL"))
}
if existing.ExpirationTime != a.ExpirationTime {
mErr.Errors = append(mErr.Errors, errors.New("cannot update expiration time"))
}
}
return mErr.ErrorOrNil()
}
// HasExpirationTime checks whether the ACL token has an expiration time value
// set.
func (a *ACLToken) HasExpirationTime() bool {
if a == nil || a.ExpirationTime == nil {
return false
}
return !a.ExpirationTime.IsZero()
}
// IsExpired compares the ACLToken.ExpirationTime against the passed t to
// identify whether the token is considered expired. The function can be called
// without checking whether the ACL token has an expiry time.
func (a *ACLToken) IsExpired(t time.Time) bool {
// This is the fastest exit point and means we don't need to perform any
// expiration checks.
if t.IsZero() || !a.HasExpirationTime() {
return false
}
// Check and ensure the time location is set to UTC. This is vital for
// consistency with multi-region global tokens.
if t.Location() != time.UTC {
t = t.UTC()
}
return a.ExpirationTime.Before(t)
}
package structs
import (
"testing"
"time"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/stretchr/testify/require"
)
func TestACLToken_Canonicalize(t *testing.T) {
testCases := []struct {
name string
testFn func()
}{
{
name: "token with accessor",
testFn: func() {
mockToken := &ACLToken{
AccessorID: uuid.Generate(),
SecretID: uuid.Generate(),
Name: "my cool token " + uuid.Generate(),
Type: "client",
Policies: []string{"foo", "bar"},
Global: false,
CreateTime: time.Now().UTC(),
CreateIndex: 10,
ModifyIndex: 20,
}
mockToken.SetHash()
copiedMockToken := mockToken.Copy()
mockToken.Canonicalize()
require.Equal(t, copiedMockToken, mockToken)
},
},
{
name: "token without accessor",
testFn: func() {
mockToken := &ACLToken{
Name: "my cool token " + uuid.Generate(),
Type: "client",
Policies: []string{"foo", "bar"},
Global: false,
}
mockToken.Canonicalize()
require.NotEmpty(t, mockToken.AccessorID)
require.NotEmpty(t, mockToken.SecretID)
require.NotEmpty(t, mockToken.CreateTime)
},
},
{
name: "token with ttl without accessor",
testFn: func() {
mockToken := &ACLToken{
Name: "my cool token " + uuid.Generate(),
Type: "client",
Policies: []string{"foo", "bar"},
Global: false,
ExpirationTTL: 10 * time.Hour,
}
mockToken.Canonicalize()
require.NotEmpty(t, mockToken.AccessorID)
require.NotEmpty(t, mockToken.SecretID)
require.NotEmpty(t, mockToken.CreateTime)
require.NotEmpty(t, mockToken.ExpirationTime)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.testFn()
})
}
}
func TestACLTokenValidate(t *testing.T) {
ci.Parallel(t)
testCases := []struct {
name string
inputACLToken *ACLToken
inputExistingACLToken *ACLToken
expectedErrorContains string
}{
{
name: "missing type",
inputACLToken: &ACLToken{},
inputExistingACLToken: nil,
expectedErrorContains: "client or management",
},
{
name: "missing policies",
inputACLToken: &ACLToken{
Type: ACLClientToken,
},
inputExistingACLToken: nil,
expectedErrorContains: "missing policies",
},
{
name: "invalid policies",
inputACLToken: &ACLToken{
Type: ACLManagementToken,
Policies: []string{"foo"},
},
inputExistingACLToken: nil,
expectedErrorContains: "associated with policies",
},
{
name: "name too long",
inputACLToken: &ACLToken{
Type: ACLManagementToken,
Name: uuid.Generate() + uuid.Generate() + uuid.Generate() + uuid.Generate() +
uuid.Generate() + uuid.Generate() + uuid.Generate() + uuid.Generate(),
},
inputExistingACLToken: nil,
expectedErrorContains: "name too long",
},
{
name: "negative TTL",
inputACLToken: &ACLToken{
Type: ACLManagementToken,
Name: "foo",
ExpirationTTL: -1 * time.Hour,
},
inputExistingACLToken: nil,
expectedErrorContains: "should not be negative",
},
{
name: "TTL too small",
inputACLToken: &ACLToken{
Type: ACLManagementToken,
Name: "foo",
CreateTime: time.Date(2022, time.July, 11, 16, 23, 0, 0, time.UTC),
ExpirationTime: helper.TimeTimeToPtr(time.Date(2022, time.July, 11, 16, 23, 10, 0, time.UTC)),
},
inputExistingACLToken: nil,
expectedErrorContains: "expiration time cannot be less than",
},
{
name: "TTL too large",
inputACLToken: &ACLToken{
Type: ACLManagementToken,
Name: "foo",
CreateTime: time.Date(2022, time.July, 11, 16, 23, 0, 0, time.UTC),
ExpirationTime: helper.TimeTimeToPtr(time.Date(2042, time.July, 11, 16, 23, 0, 0, time.UTC)),
},
inputExistingACLToken: nil,
expectedErrorContains: "expiration time cannot be more than",
},
{
name: "valid management",
inputACLToken: &ACLToken{
Type: ACLManagementToken,
Name: "foo",
},
inputExistingACLToken: nil,
expectedErrorContains: "",
},
{
name: "valid client",
inputACLToken: &ACLToken{
Type: ACLClientToken,
Name: "foo",
Policies: []string{"foo"},
},
inputExistingACLToken: nil,
expectedErrorContains: "",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualOutputError := tc.inputACLToken.Validate(1*time.Minute, 24*time.Hour, tc.inputExistingACLToken)
if tc.expectedErrorContains != "" {
require.ErrorContains(t, actualOutputError, tc.expectedErrorContains)
} else {
require.NoError(t, actualOutputError)
}
})
}
}
func TestACLToken_HasExpirationTime(t *testing.T) {
testCases := []struct {
name string
inputACLToken *ACLToken
expectedOutput bool ``
}{
{
name: "nil acl token",
inputACLToken: nil,
expectedOutput: false,
},
{
name: "default empty value",
inputACLToken: &ACLToken{},
expectedOutput: false,
},
{
name: "expiration set to now",
inputACLToken: &ACLToken{
ExpirationTime: helper.TimeTimeToPtr(time.Now().UTC()),
},
expectedOutput: true,
},
{
name: "expiration set to past",
inputACLToken: &ACLToken{
ExpirationTime: helper.TimeTimeToPtr(time.Date(2022, time.February, 21, 19, 35, 0, 0, time.UTC)),
},
expectedOutput: true,
},
{
name: "expiration set to future",
inputACLToken: &ACLToken{
ExpirationTime: helper.TimeTimeToPtr(time.Date(2087, time.April, 25, 12, 0, 0, 0, time.UTC)),
},
expectedOutput: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualOutput := tc.inputACLToken.HasExpirationTime()
require.Equal(t, tc.expectedOutput, actualOutput)
})
}
}
func TestACLToken_IsExpired(t *testing.T) {
testCases := []struct {
name string
inputACLToken *ACLToken
inputTime time.Time
expectedOutput bool
}{
{
name: "token without expiry",
inputACLToken: &ACLToken{},
inputTime: time.Now().UTC(),
expectedOutput: false,
},
{
name: "empty input time",
inputACLToken: &ACLToken{},
inputTime: time.Time{},
expectedOutput: false,
},
{
name: "token not expired",
inputACLToken: &ACLToken{
ExpirationTime: helper.TimeTimeToPtr(time.Date(2022, time.May, 9, 10, 27, 0, 0, time.UTC)),
},
inputTime: time.Date(2022, time.May, 9, 10, 26, 0, 0, time.UTC),
expectedOutput: false,
},
{
name: "token expired",
inputACLToken: &ACLToken{
ExpirationTime: helper.TimeTimeToPtr(time.Date(2022, time.May, 9, 10, 27, 0, 0, time.UTC)),
},
inputTime: time.Date(2022, time.May, 9, 10, 28, 0, 0, time.UTC),
expectedOutput: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualOutput := tc.inputACLToken.IsExpired(tc.inputTime)
require.Equal(t, tc.expectedOutput, actualOutput)
})
}
}
......@@ -11721,14 +11721,26 @@ type ACLPolicyUpsertRequest struct {
// ACLToken represents a client token which is used to Authenticate
type ACLToken struct {
AccessorID string // Public Accessor ID (UUID)
SecretID string // Secret ID, private (UUID)
Name string // Human friendly name
Type string // Client or Management
Policies []string // Policies this token ties to
Global bool // Global or Region local
Hash []byte
CreateTime time.Time // Time of creation
AccessorID string // Public Accessor ID (UUID)
SecretID string // Secret ID, private (UUID)
Name string // Human friendly name
Type string // Client or Management
Policies []string // Policies this token ties to
Global bool // Global or Region local
Hash []byte
CreateTime time.Time // Time of creation
// ExpirationTime represents the point after which a token should be
// considered revoked and is eligible for destruction. This time should
// always use UTC to account for multi-region global tokens. It is a
// pointer, so we can store nil, rather than the zero value of time.Time.
ExpirationTime *time.Time
// ExpirationTTL is a convenience field for helping set ExpirationTime to a
// value of CreateTime+ExpirationTTL. This can only be set during token
// creation. This is a string version of a time.Duration like "2m".
ExpirationTTL time.Duration
CreateIndex uint64
ModifyIndex uint64
}
......@@ -11775,18 +11787,21 @@ var (
)
type ACLTokenListStub struct {
AccessorID string
Name string
Type string
Policies []string
Global bool
Hash []byte
CreateTime time.Time
CreateIndex uint64
ModifyIndex uint64
AccessorID string
Name string
Type string
Policies []string
Global bool
Hash []byte
CreateTime time.Time
ExpirationTime *time.Time
CreateIndex uint64
ModifyIndex uint64
}
// SetHash is used to compute and set the hash of the ACL token
// SetHash is used to compute and set the hash of the ACL token. It only hashes
// fields which can be updated, and as such, does not hash fields such as
// ExpirationTime.
func (a *ACLToken) SetHash() []byte {
// Initialize a 256bit Blake2 hash (32 bytes)
hash, err := blake2b.New256(nil)
......@@ -11816,37 +11831,17 @@ func (a *ACLToken) SetHash() []byte {
func (a *ACLToken) Stub() *ACLTokenListStub {
return &ACLTokenListStub{
AccessorID: a.AccessorID,
Name: a.Name,
Type: a.Type,
Policies: a.Policies,
Global: a.Global,
Hash: a.Hash,
CreateTime: a.CreateTime,
CreateIndex: a.CreateIndex,
ModifyIndex: a.ModifyIndex,
}
}
// Validate is used to check a token for reasonableness
func (a *ACLToken) Validate() error {
var mErr multierror.Error
if len(a.Name) > maxTokenNameLength {
mErr.Errors = append(mErr.Errors, fmt.Errorf("token name too long"))
AccessorID: a.AccessorID,
Name: a.Name,
Type: a.Type,
Policies: a.Policies,
Global: a.Global,
Hash: a.Hash,
CreateTime: a.CreateTime,
ExpirationTime: a.ExpirationTime,
CreateIndex: a.CreateIndex,
ModifyIndex: a.ModifyIndex,
}
switch a.Type {
case ACLClientToken:
if len(a.Policies) == 0 {
mErr.Errors = append(mErr.Errors, fmt.Errorf("client token missing policies"))
}
case ACLManagementToken:
if len(a.Policies) != 0 {
mErr.Errors = append(mErr.Errors, fmt.Errorf("management token cannot be associated with policies"))
}
default:
mErr.Errors = append(mErr.Errors, fmt.Errorf("token type must be client or management"))
}
return mErr.ErrorOrNil()
}
// PolicySubset checks if a given set of policies is a subset of the token
......
......@@ -5994,53 +5994,6 @@ func TestIsRecoverable(t *testing.T) {
}
}
func TestACLTokenValidate(t *testing.T) {
ci.Parallel(t)
tk := &ACLToken{}
// Missing a type
err := tk.Validate()
assert.NotNil(t, err)
if !strings.Contains(err.Error(), "client or management") {
t.Fatalf("bad: %v", err)
}
// Missing policies
tk.Type = ACLClientToken
err = tk.Validate()
assert.NotNil(t, err)
if !strings.Contains(err.Error(), "missing policies") {
t.Fatalf("bad: %v", err)
}
// Invalid policies
tk.Type = ACLManagementToken
tk.Policies = []string{"foo"}
err = tk.Validate()
assert.NotNil(t, err)
if !strings.Contains(err.Error(), "associated with policies") {
t.Fatalf("bad: %v", err)
}
// Name too long policies
tk.Name = ""
for i := 0; i < 8; i++ {
tk.Name += uuid.Generate()
}
tk.Policies = nil
err = tk.Validate()
assert.NotNil(t, err)
if !strings.Contains(err.Error(), "too long") {
t.Fatalf("bad: %v", err)
}
// Make it valid
tk.Name = "foo"
err = tk.Validate()
assert.Nil(t, err)
}
func TestACLTokenPolicySubset(t *testing.T) {
ci.Parallel(t)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment