Commit 67481cdb authored by Alex Dadgar, committed by GitHub

Merge pull request #1659 from hashicorp/f-revoke-accessors

Token revocation and keeping only a single Vault client active among servers
Parents: daf3ca2c 743cbeba
Showing with 1305 additions and 183 deletions
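The changes pair two mechanisms: servers revoke Vault token accessors once their owning allocation or node is terminal, and only the elected leader keeps its Vault client active (SetActive is toggled in establishLeadership and revokeLeadership below). A minimal sketch of the gating idea follows; the struct, field, and helper names are assumptions modeled on the SetActive calls and the "not active" error the tests expect, not code from this commit.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// vaultClient sketches the leader-gating pattern. Only SetActive and the
// "not active" error text appear in this diff; the rest is assumed.
type vaultClient struct {
	active int32 // 1 while this server is the cluster leader
}

// SetActive is called with true in establishLeadership and false in
// revokeLeadership, so at most one server talks to Vault at a time.
func (v *vaultClient) SetActive(active bool) {
	var val int32
	if active {
		val = 1
	}
	atomic.StoreInt32(&v.active, val)
}

// LookupToken refuses requests while inactive, mirroring the test below
// that expects a "not active" error before SetActive(true) is called.
func (v *vaultClient) LookupToken(token string) (string, error) {
	if atomic.LoadInt32(&v.active) != 1 {
		return "", fmt.Errorf("Vault client not active")
	}
	return "secret-for-" + token, nil // stand-in for a real Vault call
}

func main() {
	c := &vaultClient{}
	if _, err := c.LookupToken("123"); err != nil {
		fmt.Println("follower:", err) // follower: Vault client not active
	}
	c.SetActive(true) // leadership gained
	s, _ := c.LookupToken("123")
	fmt.Println("leader:", s)
}
```

This keeps follower servers from issuing redundant Vault calls while still letting each of them hold a configured client that is ready to activate on leadership transition.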
@@ -140,6 +140,8 @@ func (n *nomadFSM) Apply(log *raft.Log) interface{} {
return n.applyReconcileSummaries(buf[1:], log.Index)
case structs.VaultAccessorRegisterRequestType:
return n.applyUpsertVaultAccessor(buf[1:], log.Index)
case structs.VaultAccessorDegisterRequestType:
return n.applyDeregisterVaultAccessor(buf[1:], log.Index)
default:
if ignoreUnknown {
n.logger.Printf("[WARN] nomad.fsm: ignoring unknown message type (%d), upgrade to newer version", msgType)
@@ -472,7 +474,7 @@ func (n *nomadFSM) applyReconcileSummaries(buf []byte, index uint64) interface{}
// and task
func (n *nomadFSM) applyUpsertVaultAccessor(buf []byte, index uint64) interface{} {
defer metrics.MeasureSince([]string{"nomad", "fsm", "upsert_vault_accessor"}, time.Now())
var req structs.VaultAccessorRegisterRequest
var req structs.VaultAccessorsRequest
if err := structs.Decode(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode request: %v", err))
}
@@ -485,6 +487,22 @@ func (n *nomadFSM) applyUpsertVaultAccessor(buf []byte, index uint64) interface{
return nil
}
// applyDeregisterVaultAccessor deregisters a set of Vault accessors
func (n *nomadFSM) applyDeregisterVaultAccessor(buf []byte, index uint64) interface{} {
defer metrics.MeasureSince([]string{"nomad", "fsm", "deregister_vault_accessor"}, time.Now())
var req structs.VaultAccessorsRequest
if err := structs.Decode(buf, &req); err != nil {
panic(fmt.Errorf("failed to decode request: %v", err))
}
if err := n.state.DeleteVaultAccessors(index, req.Accessors); err != nil {
n.logger.Printf("[ERR] nomad.fsm: DeregisterVaultAccessor failed: %v", err)
return err
}
return nil
}
func (n *nomadFSM) Snapshot() (raft.FSMSnapshot, error) {
// Create a new snapshot
snap, err := n.state.Snapshot()
......
@@ -777,7 +777,7 @@ func TestFSM_UpsertVaultAccessor(t *testing.T) {
va := mock.VaultAccessor()
va2 := mock.VaultAccessor()
req := structs.VaultAccessorRegisterRequest{
req := structs.VaultAccessorsRequest{
Accessors: []*structs.VaultAccessor{va, va2},
}
buf, err := structs.Encode(structs.VaultAccessorRegisterRequestType, req)
@@ -819,6 +819,47 @@ func TestFSM_UpsertVaultAccessor(t *testing.T) {
}
}
func TestFSM_DeregisterVaultAccessor(t *testing.T) {
fsm := testFSM(t)
fsm.blockedEvals.SetEnabled(true)
va := mock.VaultAccessor()
va2 := mock.VaultAccessor()
accessors := []*structs.VaultAccessor{va, va2}
// Insert the accessors
if err := fsm.State().UpsertVaultAccessor(1000, accessors); err != nil {
t.Fatalf("bad: %v", err)
}
req := structs.VaultAccessorsRequest{
Accessors: accessors,
}
buf, err := structs.Encode(structs.VaultAccessorDegisterRequestType, req)
if err != nil {
t.Fatalf("err: %v", err)
}
resp := fsm.Apply(makeLog(buf))
if resp != nil {
t.Fatalf("resp: %v", resp)
}
out1, err := fsm.State().VaultAccessor(va.Accessor)
if err != nil {
t.Fatalf("err: %v", err)
}
if out1 != nil {
t.Fatalf("not deleted!")
}
tt := fsm.TimeTable()
index := tt.NearestIndex(time.Now().UTC())
if index != 1 {
t.Fatalf("bad: %d", index)
}
}
func testSnapshotRestore(t *testing.T, fsm *nomadFSM) *nomadFSM {
// Snapshot
snap, err := fsm.Snapshot()
......
package nomad
import (
"context"
"errors"
"fmt"
"time"
@@ -132,6 +133,12 @@ func (s *Server) establishLeadership(stopCh chan struct{}) error {
return err
}
// Activate the vault client
s.vault.SetActive(true)
if err := s.restoreRevokingAccessors(); err != nil {
return err
}
// Enable the periodic dispatcher, since we are now the leader.
s.periodicDispatcher.SetEnabled(true)
s.periodicDispatcher.Start()
@@ -205,6 +212,57 @@ func (s *Server) restoreEvals() error {
return nil
}
// restoreRevokingAccessors is used to restore Vault accessors that should be
// revoked.
func (s *Server) restoreRevokingAccessors() error {
// An accessor should be revoked if its allocation or node is terminal
state := s.fsm.State()
iter, err := state.VaultAccessors()
if err != nil {
return fmt.Errorf("failed to get vault accessors: %v", err)
}
var revoke []*structs.VaultAccessor
for {
raw := iter.Next()
if raw == nil {
break
}
va := raw.(*structs.VaultAccessor)
// Check the allocation
alloc, err := state.AllocByID(va.AllocID)
if err != nil {
return fmt.Errorf("failed to lookup allocation: %v", va.AllocID, err)
}
if alloc == nil || alloc.Terminated() {
// No longer running and should be revoked
revoke = append(revoke, va)
continue
}
// Check the node
node, err := state.NodeByID(va.NodeID)
if err != nil {
return fmt.Errorf("failed to lookup node %q: %v", va.NodeID, err)
}
if node == nil || node.TerminalStatus() {
// Node is terminal so any accessor from it should be revoked
revoke = append(revoke, va)
continue
}
}
if len(revoke) != 0 {
if err := s.vault.RevokeTokens(context.Background(), revoke, true); err != nil {
return fmt.Errorf("failed to revoke tokens: %v", err)
}
}
return nil
}
// restorePeriodicDispatcher is used to restore all periodic jobs into the
// periodic dispatcher. It also determines if a periodic job should have been
// created during the leadership transition and force runs them. The periodic
@@ -409,6 +467,9 @@ func (s *Server) revokeLeadership() error {
// Disable the periodic dispatcher, since it is only useful as a leader
s.periodicDispatcher.SetEnabled(false)
// Disable the Vault client as it is only useful as a leader.
s.vault.SetActive(false)
// Clear the heartbeat timers on either shutdown or step down,
// since we are no longer responsible for TTL expirations.
if err := s.clearAllHeartbeatTimers(); err != nil {
......
@@ -544,3 +544,31 @@ func TestLeader_ReapDuplicateEval(t *testing.T) {
t.Fatalf("err: %v", err)
})
}
func TestLeader_RestoreVaultAccessors(t *testing.T) {
s1 := testServer(t, func(c *Config) {
c.NumSchedulers = 0
})
defer s1.Shutdown()
testutil.WaitForLeader(t, s1.RPC)
// Insert a vault accessor that should be revoked
state := s1.fsm.State()
va := mock.VaultAccessor()
if err := state.UpsertVaultAccessor(100, []*structs.VaultAccessor{va}); err != nil {
t.Fatalf("bad: %v", err)
}
// Swap the Vault client
tvc := &TestVaultClient{}
s1.vault = tvc
// Do a restore
if err := s1.restoreRevokingAccessors(); err != nil {
t.Fatalf("Failed to restore: %v", err)
}
if len(tvc.RevokedTokens) != 1 && tvc.RevokedTokens[0].Accessor != va.Accessor {
t.Fatalf("Bad revoked accessors: %v", tvc.RevokedTokens)
}
}
@@ -11,6 +11,7 @@ import (
"github.com/armon/go-metrics"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/nomad/nomad/state"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/watch"
@@ -215,7 +216,7 @@ func (n *Node) constructNodeServerInfoResponse(snap *state.StateSnapshot, reply
return nil
}
// Deregister is used to remove a client from the client. If a client should
// Deregister is used to remove a client from the cluster. If a client should
// just be made unavailable for scheduling, a status update is preferred.
func (n *Node) Deregister(args *structs.NodeDeregisterRequest, reply *structs.NodeUpdateResponse) error {
if done, err := n.srv.forward("Node.Deregister", args, args, reply); done {
@@ -245,6 +246,20 @@ func (n *Node) Deregister(args *structs.NodeDeregisterRequest, reply *structs.No
return err
}
// Determine if there are any Vault accessors on the node
accessors, err := n.srv.State().VaultAccessorsByNode(args.NodeID)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for node %q failed: %v", args.NodeID, err)
return err
}
if len(accessors) != 0 {
if err := n.srv.vault.RevokeTokens(context.Background(), accessors, true); err != nil {
n.srv.logger.Printf("[ERR] nomad.client: revoking accessors for node %q failed: %v", args.NodeID, err)
return err
}
}
// Setup the reply
reply.EvalIDs = evalIDs
reply.EvalCreateIndex = evalIndex
@@ -311,7 +326,22 @@ func (n *Node) UpdateStatus(args *structs.NodeUpdateStatusRequest, reply *struct
}
// Check if we need to setup a heartbeat
if args.Status != structs.NodeStatusDown {
switch args.Status {
case structs.NodeStatusDown:
// Determine if there are any Vault accessors on the node
accessors, err := n.srv.State().VaultAccessorsByNode(args.NodeID)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for node %q failed: %v", args.NodeID, err)
return err
}
if len(accessors) != 0 {
if err := n.srv.vault.RevokeTokens(context.Background(), accessors, true); err != nil {
n.srv.logger.Printf("[ERR] nomad.client: revoking accessors for node %q failed: %v", args.NodeID, err)
return err
}
}
default:
ttl, err := n.srv.resetHeartbeatTimer(args.NodeID)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: heartbeat reset failed: %v", err)
@@ -686,13 +716,41 @@ func (n *Node) batchUpdate(future *batchFuture, updates []*structs.Allocation) {
}
// Commit this update via Raft
var mErr multierror.Error
_, index, err := n.srv.raftApply(structs.AllocClientUpdateRequestType, batch)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: alloc update failed: %v", err)
mErr.Errors = append(mErr.Errors, err)
}
// For each allocation we are updating check if we should revoke any
// Vault Accessors
var revoke []*structs.VaultAccessor
for _, alloc := range updates {
// Skip any allocation that isn't dead on the client
if !alloc.Terminated() {
continue
}
// Determine if there are any Vault accessors for the allocation
accessors, err := n.srv.State().VaultAccessorsByAlloc(alloc.ID)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: looking up accessors for alloc %q failed: %v", alloc.ID, err)
mErr.Errors = append(mErr.Errors, err)
}
revoke = append(revoke, accessors...)
}
if len(revoke) != 0 {
if err := n.srv.vault.RevokeTokens(context.Background(), revoke, true); err != nil {
n.srv.logger.Printf("[ERR] nomad.client: batched accessor revocation failed: %v", err)
mErr.Errors = append(mErr.Errors, err)
}
}
// Respond to the future
future.Respond(index, err)
future.Respond(index, mErr.ErrorOrNil())
}
// List is used to list the available nodes
@@ -1011,10 +1069,6 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
// Wait for everything to complete or for an error
err = g.Wait()
if err != nil {
// TODO Revoke any created token
return err
}
// Commit to Raft before returning any of the tokens
accessors := make([]*structs.VaultAccessor, 0, len(results))
@@ -1037,7 +1091,17 @@ func (n *Node) DeriveVaultToken(args *structs.DeriveVaultTokenRequest,
accessors = append(accessors, accessor)
}
req := structs.VaultAccessorRegisterRequest{Accessors: accessors}
// If there was an error revoke the created tokens
if err != nil {
var mErr multierror.Error
mErr.Errors = append(mErr.Errors, err)
if err := n.srv.vault.RevokeTokens(context.Background(), accessors, false); err != nil {
mErr.Errors = append(mErr.Errors, err)
}
return mErr.ErrorOrNil()
}
req := structs.VaultAccessorsRequest{Accessors: accessors}
_, index, err := n.srv.raftApply(structs.VaultAccessorRegisterRequestType, &req)
if err != nil {
n.srv.logger.Printf("[ERR] nomad.client: Register Vault accessors failed: %v", err)
......
@@ -170,6 +170,65 @@ func TestClientEndpoint_Deregister(t *testing.T) {
}
}
func TestClientEndpoint_Deregister_Vault(t *testing.T) {
s1 := testServer(t, nil)
defer s1.Shutdown()
codec := rpcClient(t, s1)
testutil.WaitForLeader(t, s1.RPC)
// Create the register request
node := mock.Node()
reg := &structs.NodeRegisterRequest{
Node: node,
WriteRequest: structs.WriteRequest{Region: "global"},
}
// Fetch the response
var resp structs.GenericResponse
if err := msgpackrpc.CallWithCodec(codec, "Node.Register", reg, &resp); err != nil {
t.Fatalf("err: %v", err)
}
// Swap the server's Vault client
tvc := &TestVaultClient{}
s1.vault = tvc
// Put some Vault accessors in the state store for that node
state := s1.fsm.State()
va1 := mock.VaultAccessor()
va1.NodeID = node.ID
va2 := mock.VaultAccessor()
va2.NodeID = node.ID
state.UpsertVaultAccessor(100, []*structs.VaultAccessor{va1, va2})
// Deregister
dereg := &structs.NodeDeregisterRequest{
NodeID: node.ID,
WriteRequest: structs.WriteRequest{Region: "global"},
}
var resp2 structs.GenericResponse
if err := msgpackrpc.CallWithCodec(codec, "Node.Deregister", dereg, &resp2); err != nil {
t.Fatalf("err: %v", err)
}
if resp2.Index == 0 {
t.Fatalf("bad index: %d", resp2.Index)
}
// Check for the node in the FSM
out, err := state.NodeByID(node.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if out != nil {
t.Fatalf("unexpected node")
}
// Check that the endpoint revoked the tokens
if l := len(tvc.RevokedTokens); l != 2 {
t.Fatalf("Deregister revoked %d tokens; want 2", l)
}
}
func TestClientEndpoint_UpdateStatus(t *testing.T) {
s1 := testServer(t, nil)
defer s1.Shutdown()
@@ -229,6 +288,63 @@ func TestClientEndpoint_UpdateStatus(t *testing.T) {
}
}
func TestClientEndpoint_UpdateStatus_Vault(t *testing.T) {
s1 := testServer(t, nil)
defer s1.Shutdown()
codec := rpcClient(t, s1)
testutil.WaitForLeader(t, s1.RPC)
// Create the register request
node := mock.Node()
reg := &structs.NodeRegisterRequest{
Node: node,
WriteRequest: structs.WriteRequest{Region: "global"},
}
// Fetch the response
var resp structs.NodeUpdateResponse
if err := msgpackrpc.CallWithCodec(codec, "Node.Register", reg, &resp); err != nil {
t.Fatalf("err: %v", err)
}
// Check for heartbeat interval
ttl := resp.HeartbeatTTL
if ttl < s1.config.MinHeartbeatTTL || ttl > 2*s1.config.MinHeartbeatTTL {
t.Fatalf("bad: %#v", ttl)
}
// Swap the server's Vault client
tvc := &TestVaultClient{}
s1.vault = tvc
// Put some Vault accessors in the state store for that node
state := s1.fsm.State()
va1 := mock.VaultAccessor()
va1.NodeID = node.ID
va2 := mock.VaultAccessor()
va2.NodeID = node.ID
state.UpsertVaultAccessor(100, []*structs.VaultAccessor{va1, va2})
// Update the status to be down
dereg := &structs.NodeUpdateStatusRequest{
NodeID: node.ID,
Status: structs.NodeStatusDown,
WriteRequest: structs.WriteRequest{Region: "global"},
}
var resp2 structs.NodeUpdateResponse
if err := msgpackrpc.CallWithCodec(codec, "Node.UpdateStatus", dereg, &resp2); err != nil {
t.Fatalf("err: %v", err)
}
if resp2.Index == 0 {
t.Fatalf("bad index: %d", resp2.Index)
}
// Check that the endpoint revoked the tokens
if l := len(tvc.RevokedTokens); l != 2 {
t.Fatalf("Deregister revoked %d tokens; want 2", l)
}
}
func TestClientEndpoint_Register_GetEvals(t *testing.T) {
s1 := testServer(t, nil)
defer s1.Shutdown()
@@ -1235,6 +1351,81 @@ func TestClientEndpoint_BatchUpdate(t *testing.T) {
}
}
func TestClientEndpoint_UpdateAlloc_Vault(t *testing.T) {
s1 := testServer(t, nil)
defer s1.Shutdown()
codec := rpcClient(t, s1)
testutil.WaitForLeader(t, s1.RPC)
// Create the register request
node := mock.Node()
reg := &structs.NodeRegisterRequest{
Node: node,
WriteRequest: structs.WriteRequest{Region: "global"},
}
// Fetch the response
var resp structs.GenericResponse
if err := msgpackrpc.CallWithCodec(codec, "Node.Register", reg, &resp); err != nil {
t.Fatalf("err: %v", err)
}
// Swap the server's Vault client
tvc := &TestVaultClient{}
s1.vault = tvc
// Inject fake allocation and vault accessor
alloc := mock.Alloc()
alloc.NodeID = node.ID
state := s1.fsm.State()
state.UpsertJobSummary(99, mock.JobSummary(alloc.JobID))
if err := state.UpsertAllocs(100, []*structs.Allocation{alloc}); err != nil {
t.Fatalf("err: %v", err)
}
va := mock.VaultAccessor()
va.NodeID = node.ID
va.AllocID = alloc.ID
if err := state.UpsertVaultAccessor(101, []*structs.VaultAccessor{va}); err != nil {
t.Fatalf("err: %v", err)
}
// Attempt update
clientAlloc := new(structs.Allocation)
*clientAlloc = *alloc
clientAlloc.ClientStatus = structs.AllocClientStatusFailed
// Update the alloc
update := &structs.AllocUpdateRequest{
Alloc: []*structs.Allocation{clientAlloc},
WriteRequest: structs.WriteRequest{Region: "global"},
}
var resp2 structs.NodeAllocsResponse
start := time.Now()
if err := msgpackrpc.CallWithCodec(codec, "Node.UpdateAlloc", update, &resp2); err != nil {
t.Fatalf("err: %v", err)
}
if resp2.Index == 0 {
t.Fatalf("Bad index: %d", resp2.Index)
}
if diff := time.Since(start); diff < batchUpdateInterval {
t.Fatalf("too fast: %v", diff)
}
// Lookup the alloc
out, err := state.AllocByID(alloc.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if out.ClientStatus != structs.AllocClientStatusFailed {
t.Fatalf("Bad: %#v", out)
}
if l := len(tvc.RevokedTokens); l != 1 {
t.Fatalf("Deregister revoked %d tokens; want 1", l)
}
}
func TestClientEndpoint_CreateNodeEvals(t *testing.T) {
s1 := testServer(t, nil)
defer s1.Shutdown()
......
@@ -567,7 +567,7 @@ func (s *Server) setupConsulSyncer() error {
// setupVaultClient is used to set up the Vault API client.
func (s *Server) setupVaultClient() error {
v, err := NewVaultClient(s.config.VaultConfig, s.logger)
v, err := NewVaultClient(s.config.VaultConfig, s.logger, s.purgeVaultAccessors)
if err != nil {
return err
}
......
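The purgeVaultAccessors function passed into NewVaultClient is not shown here (the nomad/vault.go diff is collapsed). Given the FSM handler above and the raftApply pattern used elsewhere in this PR, it plausibly submits the new deregister Raft message; the following is a hedged sketch, not the verbatim code:

```go
// Sketch (package nomad): once the Vault client has fully revoked a batch of
// accessors, purge them from state by submitting the deregister message
// handled by applyDeregisterVaultAccessor in the FSM.
func (s *Server) purgeVaultAccessors(accessors []*structs.VaultAccessor) error {
	req := structs.VaultAccessorsRequest{Accessors: accessors}
	_, _, err := s.raftApply(structs.VaultAccessorDegisterRequestType, &req)
	return err
}
```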
@@ -1154,24 +1154,19 @@ func (s *StateStore) UpsertVaultAccessor(index uint64, accessors []*structs.Vaul
return nil
}
// DeleteVaultAccessor is used to delete a Vault Accessor
func (s *StateStore) DeleteVaultAccessor(index uint64, accessor string) error {
// DeleteVaultAccessors is used to delete a set of Vault Accessors
func (s *StateStore) DeleteVaultAccessors(index uint64, accessors []*structs.VaultAccessor) error {
txn := s.db.Txn(true)
defer txn.Abort()
// Lookup the accessor
existing, err := txn.First("vault_accessors", "id", accessor)
if err != nil {
return fmt.Errorf("accessor lookup failed: %v", err)
}
if existing == nil {
return fmt.Errorf("vault_accessor not found")
for _, accessor := range accessors {
// Delete the accessor
if err := txn.Delete("vault_accessors", accessor); err != nil {
return fmt.Errorf("accessor delete failed: %v", err)
}
}
// Delete the accessor
if err := txn.Delete("vault_accessors", existing); err != nil {
return fmt.Errorf("accessor delete failed: %v", err)
}
if err := txn.Insert("index", &IndexEntry{"vault_accessors", index}); err != nil {
return fmt.Errorf("index update failed: %v", err)
}
......
@@ -2973,27 +2973,35 @@ func TestStateStore_UpsertVaultAccessors(t *testing.T) {
}
}
func TestStateStore_DeleteVaultAccessor(t *testing.T) {
func TestStateStore_DeleteVaultAccessors(t *testing.T) {
state := testStateStore(t)
accessor := mock.VaultAccessor()
a1 := mock.VaultAccessor()
a2 := mock.VaultAccessor()
accessors := []*structs.VaultAccessor{a1, a2}
err := state.UpsertVaultAccessor(1000, []*structs.VaultAccessor{accessor})
err := state.UpsertVaultAccessor(1000, accessors)
if err != nil {
t.Fatalf("err: %v", err)
}
err = state.DeleteVaultAccessor(1001, accessor.Accessor)
err = state.DeleteVaultAccessors(1001, accessors)
if err != nil {
t.Fatalf("err: %v", err)
}
out, err := state.VaultAccessor(accessor.Accessor)
out, err := state.VaultAccessor(a1.Accessor)
if err != nil {
t.Fatalf("err: %v", err)
}
if out != nil {
t.Fatalf("bad: %#v %#v", a1, out)
}
out, err = state.VaultAccessor(a2.Accessor)
if err != nil {
t.Fatalf("err: %v", err)
}
if out != nil {
t.Fatalf("bad: %#v %#v", accessor, out)
t.Fatalf("bad: %#v %#v", a2, out)
}
index, err := state.Index("vault_accessors")
......
@@ -48,6 +48,7 @@ const (
AllocClientUpdateRequestType
ReconcileJobSummariesRequestType
VaultAccessorRegisterRequestType
VaultAccessorDegisterRequestType
)
const (
@@ -365,8 +366,8 @@ type DeriveVaultTokenRequest struct {
QueryOptions
}
// VaultAccessorRegisterRequest is used to register a set of Vault accessors
type VaultAccessorRegisterRequest struct {
// VaultAccessorsRequest is used to operate on a set of Vault accessors
type VaultAccessorsRequest struct {
Accessors []*VaultAccessor
}
......
This diff is collapsed.
@@ -31,24 +31,19 @@ func TestVaultClient_BadConfig(t *testing.T) {
logger := log.New(os.Stderr, "", log.LstdFlags)
// Should be no error since Vault is not enabled
client, err := NewVaultClient(conf, logger)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer client.Stop()
if client.ConnectionEstablished() {
t.Fatalf("bad")
_, err := NewVaultClient(nil, logger, nil)
if err == nil || !strings.Contains(err.Error(), "valid") {
t.Fatalf("expected config error: %v", err)
}
conf.Enabled = true
_, err = NewVaultClient(conf, logger)
_, err = NewVaultClient(conf, logger, nil)
if err == nil || !strings.Contains(err.Error(), "token must be set") {
t.Fatalf("Expected token unset error: %v", err)
}
conf.Token = "123"
_, err = NewVaultClient(conf, logger)
_, err = NewVaultClient(conf, logger, nil)
if err == nil || !strings.Contains(err.Error(), "address must be set") {
t.Fatalf("Expected address unset error: %v", err)
}
@@ -62,7 +57,7 @@ func TestVaultClient_EstablishConnection(t *testing.T) {
logger := log.New(os.Stderr, "", log.LstdFlags)
v.Config.ConnectionRetryIntv = 100 * time.Millisecond
client, err := NewVaultClient(v.Config, logger)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
@@ -79,11 +74,69 @@
v.Start()
waitForConnection(client, t)
}
func TestVaultClient_SetActive(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
// Ensure that since we are using a root token that we haven't started the
// renewal loop.
if client.renewalRunning {
t.Fatalf("No renewal loop should be running")
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
defer client.Stop()
waitForConnection(client, t)
// Do a lookup and expect an error about not being active
_, err = client.LookupToken(context.Background(), "123")
if err == nil || !strings.Contains(err.Error(), "not active") {
t.Fatalf("Expected not-active error: %v", err)
}
client.SetActive(true)
// Do a lookup of ourselves
_, err = client.LookupToken(context.Background(), v.RootToken)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
}
// Test that we can update the config and things keep working
func TestVaultClient_SetConfig(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
v2 := testutil.NewTestVault(t).Start()
defer v2.Stop()
// Set the configs token in a new test role
v2.Config.Token = testVaultRoleAndToken(v2, t, 20)
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
defer client.Stop()
waitForConnection(client, t)
if client.tokenData == nil || len(client.tokenData.Policies) != 1 {
t.Fatalf("unexpected token: %v", client.tokenData)
}
// Update the config
if err := client.SetConfig(v2.Config); err != nil {
t.Fatalf("SetConfig failed: %v", err)
}
waitForConnection(client, t)
if client.tokenData == nil || len(client.tokenData.Policies) != 2 {
t.Fatalf("unexpected token: %v", client.tokenData)
}
}
@@ -128,7 +181,7 @@ func TestVaultClient_RenewalLoop(t *testing.T) {
// Start the client
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
@@ -177,53 +230,36 @@ func parseTTLFromLookup(s *vapi.Secret, t *testing.T) int64 {
func TestVaultClient_LookupToken_Invalid(t *testing.T) {
conf := &config.VaultConfig{
Enabled: false,
Enabled: true,
Addr: "http://foobar:12345",
Token: structs.GenerateUUID(),
}
// Enable vault but use a bad address so it never establishes a conn
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(conf, logger)
client, err := NewVaultClient(conf, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
client.SetActive(true)
defer client.Stop()
_, err = client.LookupToken(context.Background(), "foo")
if err == nil || !strings.Contains(err.Error(), "disabled") {
t.Fatalf("Expected error because Vault is disabled: %v", err)
}
// Enable vault but use a bad address so it never establishes a conn
conf.Enabled = true
conf.Addr = "http://foobar:12345"
conf.Token = structs.GenerateUUID()
client, err = NewVaultClient(conf, logger)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
_, err = client.LookupToken(context.Background(), "foo")
if err == nil || !strings.Contains(err.Error(), "established") {
t.Fatalf("Expected error because connection to Vault hasn't been made: %v", err)
}
}
func waitForConnection(v *vaultClient, t *testing.T) {
testutil.WaitForResult(func() (bool, error) {
return v.ConnectionEstablished(), nil
}, func(err error) {
t.Fatalf("Connection not established")
})
}
func TestVaultClient_LookupToken(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
client.SetActive(true)
defer client.Stop()
waitForConnection(client, t)
@@ -280,10 +316,11 @@ func TestVaultClient_LookupToken_RateLimit(t *testing.T) {
defer v.Stop()
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
client.SetActive(true)
defer client.Stop()
client.setLimit(rate.Limit(1.0))
@@ -334,10 +371,11 @@ func TestVaultClient_CreateToken_Root(t *testing.T) {
defer v.Stop()
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
client.SetActive(true)
defer client.Stop()
waitForConnection(client, t)
@@ -380,10 +418,11 @@ func TestVaultClient_CreateToken_Role(t *testing.T) {
//testVaultRoleAndToken(v, t, 5)
// Start the client
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
client.SetActive(true)
defer client.Stop()
waitForConnection(client, t)
@@ -416,3 +455,110 @@
t.Fatalf("Bad ttl: %v", s.WrapInfo.WrappedAccessor)
}
}
func TestVaultClient_RevokeTokens_PreEstablishs(t *testing.T) {
v := testutil.NewTestVault(t)
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger, nil)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
client.SetActive(true)
defer client.Stop()
// Create some VaultAccessors
vas := []*structs.VaultAccessor{
mock.VaultAccessor(),
mock.VaultAccessor(),
}
if err := client.RevokeTokens(context.Background(), vas, false); err != nil {
t.Fatalf("RevokeTokens failed: %v", err)
}
// Wasn't committed
if len(client.revoking) != 0 {
t.Fatalf("didn't add to revoke loop")
}
if err := client.RevokeTokens(context.Background(), vas, true); err != nil {
t.Fatalf("RevokeTokens failed: %v", err)
}
// Was committed
if len(client.revoking) != 2 {
t.Fatalf("didn't add to revoke loop")
}
}
func TestVaultClient_RevokeTokens(t *testing.T) {
v := testutil.NewTestVault(t).Start()
defer v.Stop()
purged := 0
purge := func(accessors []*structs.VaultAccessor) error {
purged += len(accessors)
return nil
}
logger := log.New(os.Stderr, "", log.LstdFlags)
client, err := NewVaultClient(v.Config, logger, purge)
if err != nil {
t.Fatalf("failed to build vault client: %v", err)
}
client.SetActive(true)
defer client.Stop()
waitForConnection(client, t)
// Create some vault tokens
auth := v.Client.Auth().Token()
req := vapi.TokenCreateRequest{
Policies: []string{"default"},
}
t1, err := auth.Create(&req)
if err != nil {
t.Fatalf("Failed to create vault token: %v", err)
}
if t1 == nil || t1.Auth == nil {
t.Fatalf("bad secret response: %+v", t1)
}
t2, err := auth.Create(&req)
if err != nil {
t.Fatalf("Failed to create vault token: %v", err)
}
if t2 == nil || t2.Auth == nil {
t.Fatalf("bad secret response: %+v", t2)
}
// Create two VaultAccessors
vas := []*structs.VaultAccessor{
&structs.VaultAccessor{Accessor: t1.Auth.Accessor},
&structs.VaultAccessor{Accessor: t2.Auth.Accessor},
}
// Issue a token revocation
if err := client.RevokeTokens(context.Background(), vas, true); err != nil {
t.Fatalf("RevokeTokens failed: %v", err)
}
// Lookup the token and make sure we get an error
if s, err := auth.Lookup(t1.Auth.ClientToken); err == nil {
t.Fatalf("Revoked token lookup didn't fail: %+v", s)
}
if s, err := auth.Lookup(t2.Auth.ClientToken); err == nil {
t.Fatalf("Revoked token lookup didn't fail: %+v", s)
}
if purged != 2 {
t.Fatalf("Expected purged 2; got %d", purged)
}
}
func waitForConnection(v *vaultClient, t *testing.T) {
testutil.WaitForResult(func() (bool, error) {
return v.ConnectionEstablished(), nil
}, func(err error) {
t.Fatalf("Connection not established")
})
}
@@ -4,6 +4,7 @@ import (
"context"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/nomad/structs/config"
vapi "github.com/hashicorp/vault/api"
)
@@ -26,6 +27,8 @@ type TestVaultClient struct {
// CreateTokenSecret maps a token to the Vault secret that will be returned
// by the CreateToken call
CreateTokenSecret map[string]map[string]*vapi.Secret
RevokedTokens []*structs.VaultAccessor
}
func (v *TestVaultClient) LookupToken(ctx context.Context, token string) (*vapi.Secret, error) {
@@ -126,4 +129,11 @@ func (v *TestVaultClient) SetCreateTokenSecret(allocID, task string, secret *vap
v.CreateTokenSecret[allocID][task] = secret
}
func (v *TestVaultClient) Stop() {}
func (v *TestVaultClient) RevokeTokens(ctx context.Context, accessors []*structs.VaultAccessor, committed bool) error {
v.RevokedTokens = append(v.RevokedTokens, accessors...)
return nil
}
func (v *TestVaultClient) Stop() {}
func (v *TestVaultClient) SetActive(enabled bool) {}
func (v *TestVaultClient) SetConfig(config *config.VaultConfig) error { return nil }
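Read together, the stubs above trace the server-side VaultClient interface. The authoritative definition lives in the collapsed nomad/vault.go; the reconstruction below is inferred from the signatures in this diff, and a CreateToken method (implied by the CreateTokenSecret map) is omitted because its full signature is not visible here:

```go
// Inferred sketch of the interface TestVaultClient satisfies; not verbatim.
package nomad

import (
	"context"

	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/nomad/nomad/structs/config"
	vapi "github.com/hashicorp/vault/api"
)

type VaultClient interface {
	// SetActive activates or deactivates the client; only the leader keeps
	// its client active (see establishLeadership/revokeLeadership above).
	SetActive(active bool)

	// SetConfig updates the Vault configuration at runtime.
	SetConfig(config *config.VaultConfig) error

	// LookupToken looks up a token's secret in Vault.
	LookupToken(ctx context.Context, token string) (*vapi.Secret, error)

	// RevokeTokens revokes the given accessors. When committed is true the
	// revocation is tracked until it succeeds, after which the accessors
	// are purged from state.
	RevokeTokens(ctx context.Context, accessors []*structs.VaultAccessor, committed bool) error

	// Stop shuts down the client.
	Stop()
}
```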
@@ -66,7 +66,7 @@ func NewTestVault(t *testing.T) *TestVault {
t: t,
Addr: bind,
HTTPAddr: http,
RootToken: root,
RootToken: token,
Client: client,
Config: &config.VaultConfig{
Enabled: true,
......
tomb - support for clean goroutine termination in Go.
Copyright (c) 2010-2011 - Gustavo Niemeyer <gustavo@niemeyer.net>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Installation and usage
----------------------
See [gopkg.in/tomb.v2](https://gopkg.in/tomb.v2) for documentation and usage details.
// Copyright (c) 2011 - Gustavo Niemeyer <gustavo@niemeyer.net>
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
// * Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// The tomb package handles clean goroutine tracking and termination.
//
// The zero value of a Tomb is ready to handle the creation of a tracked
// goroutine via its Go method, and then any tracked goroutine may call
// the Go method again to create additional tracked goroutines at
// any point.
//
// If any of the tracked goroutines returns a non-nil error, or the
// Kill or Killf method is called by any goroutine in the system (tracked
// or not), the tomb Err is set, Alive is set to false, and the Dying
// channel is closed to flag that all tracked goroutines are supposed
// to willingly terminate as soon as possible.
//
// Once all tracked goroutines terminate, the Dead channel is closed,
// and Wait unblocks and returns the first non-nil error presented
// to the tomb via a result or an explicit Kill or Killf method call,
// or nil if there were no errors.
//
// It is okay to create further goroutines via the Go method while
// the tomb is in a dying state. The final dead state is only reached
// once all tracked goroutines terminate, at which point calling
// the Go method again will cause a runtime panic.
//
// Tracked functions and methods that are still running while the tomb
// is in dying state may choose to return ErrDying as their error value.
// This preserves the well established non-nil error convention, but is
// understood by the tomb as a clean termination. The Err and Wait
// methods will still return nil if all observed errors were either
// nil or ErrDying.
//
// For background and a detailed example, see the following blog post:
//
// http://blog.labix.org/2011/10/09/death-of-goroutines-under-control
//
package tomb
import (
"errors"
"fmt"
"sync"
)
// A Tomb tracks the lifecycle of one or more goroutines as alive,
// dying or dead, and the reason for their death.
//
// See the package documentation for details.
type Tomb struct {
m sync.Mutex
alive int
dying chan struct{}
dead chan struct{}
reason error
}
var (
ErrStillAlive = errors.New("tomb: still alive")
ErrDying = errors.New("tomb: dying")
)
func (t *Tomb) init() {
t.m.Lock()
if t.dead == nil {
t.dead = make(chan struct{})
t.dying = make(chan struct{})
t.reason = ErrStillAlive
}
t.m.Unlock()
}
// Dead returns the channel that can be used to wait until
// all goroutines have finished running.
func (t *Tomb) Dead() <-chan struct{} {
t.init()
return t.dead
}
// Dying returns the channel that can be used to wait until
// t.Kill is called.
func (t *Tomb) Dying() <-chan struct{} {
t.init()
return t.dying
}
// Wait blocks until all goroutines have finished running, and
// then returns the reason for their death.
func (t *Tomb) Wait() error {
t.init()
<-t.dead
t.m.Lock()
reason := t.reason
t.m.Unlock()
return reason
}
// Go runs f in a new goroutine and tracks its termination.
//
// If f returns a non-nil error, t.Kill is called with that
// error as the death reason parameter.
//
// It is f's responsibility to monitor the tomb and return
// appropriately once it is in a dying state.
//
// It is safe for the f function to call the Go method again
// to create additional tracked goroutines. Once all tracked
// goroutines return, the Dead channel is closed and the
// Wait method unblocks and returns the death reason.
//
// Calling the Go method after all tracked goroutines return
// causes a runtime panic. For that reason, calling the Go
// method a second time out of a tracked goroutine is unsafe.
func (t *Tomb) Go(f func() error) {
t.init()
t.m.Lock()
defer t.m.Unlock()
select {
case <-t.dead:
panic("tomb.Go called after all goroutines terminated")
default:
}
t.alive++
go t.run(f)
}
func (t *Tomb) run(f func() error) {
err := f()
t.m.Lock()
defer t.m.Unlock()
t.alive--
if t.alive == 0 || err != nil {
t.kill(err)
if t.alive == 0 {
close(t.dead)
}
}
}
// Kill puts the tomb in a dying state for the given reason,
// closes the Dying channel, and sets Alive to false.
//
// Although Kill may be called multiple times, only the first
// non-nil error is recorded as the death reason.
//
// If reason is ErrDying, the previous reason isn't replaced
// even if nil. It's a runtime error to call Kill with ErrDying
// if t is not in a dying state.
func (t *Tomb) Kill(reason error) {
t.init()
t.m.Lock()
defer t.m.Unlock()
t.kill(reason)
}
func (t *Tomb) kill(reason error) {
if reason == ErrStillAlive {
panic("tomb: Kill with ErrStillAlive")
}
if reason == ErrDying {
if t.reason == ErrStillAlive {
panic("tomb: Kill with ErrDying while still alive")
}
return
}
if t.reason == ErrStillAlive {
t.reason = reason
close(t.dying)
return
}
if t.reason == nil {
t.reason = reason
return
}
}
// Killf calls the Kill method with an error built providing the received
// parameters to fmt.Errorf. The generated error is also returned.
func (t *Tomb) Killf(f string, a ...interface{}) error {
err := fmt.Errorf(f, a...)
t.Kill(err)
return err
}
// Err returns the death reason, or ErrStillAlive if the tomb
// is not in a dying or dead state.
func (t *Tomb) Err() (reason error) {
t.init()
t.m.Lock()
reason = t.reason
t.m.Unlock()
return
}
// Alive returns true if the tomb is not in a dying or dead state.
func (t *Tomb) Alive() bool {
return t.Err() == ErrStillAlive
}
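tomb.v2 is vendored by this PR, presumably for the Vault client's background goroutine management (the collapsed nomad/vault.go would confirm). A minimal usage sketch based on the package documentation above; the worker body is illustrative, not taken from Nomad:

```go
package main

import (
	"fmt"
	"time"

	"gopkg.in/tomb.v2"
)

func main() {
	var t tomb.Tomb // the zero value is ready to use

	// Track a worker goroutine; it returns cleanly once the tomb is dying.
	t.Go(func() error {
		for {
			select {
			case <-t.Dying():
				return nil // clean termination
			case <-time.After(10 * time.Millisecond):
				// periodic work would go here
			}
		}
	})

	t.Kill(nil)           // put the tomb in a dying state
	fmt.Println(t.Wait()) // blocks until the worker exits; prints <nil>
}
```

Kill(nil) records a nil death reason, so Wait returns nil once every tracked goroutine has exited.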
@@ -860,6 +860,12 @@
"path": "gopkg.in/tomb.v1",
"revision": "dd632973f1e7218eb1089048e0798ec9ae7dceb8",
"revisionTime": "2014-10-24T13:56:13Z"
},
{
"checksumSHA1": "WiyCOMvfzRdymImAJ3ME6aoYUdM=",
"path": "gopkg.in/tomb.v2",
"revision": "14b3d72120e8d10ea6e6b7f87f7175734b1faab8",
"revisionTime": "2014-06-26T14:46:23Z"
}
],
"rootPath": "github.com/hashicorp/nomad"
......