Unverified commit 915c1176, authored by Alex Dadgar and committed by GitHub

Merge pull request #4506 from hashicorp/f-arv2-vault

Basic Vault hook 
parents f7a6132a 6d92c2da
Showing with 535 additions and 34 deletions
......@@ -14,6 +14,7 @@ import (
"github.com/hashicorp/nomad/client/allocrunnerv2/taskrunner"
"github.com/hashicorp/nomad/client/config"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/client/vaultclient"
"github.com/hashicorp/nomad/nomad/structs"
)
......@@ -24,6 +25,9 @@ type allocRunner struct {
clientConfig *config.Config
// vaultClient is the client used to manage Vault tokens
vaultClient vaultclient.VaultClient
// waitCh is closed when the alloc runner has transitioned to a terminal
// state
waitCh chan struct{}
......@@ -55,8 +59,9 @@ type allocRunner struct {
// NewAllocRunner returns a new allocation runner.
func NewAllocRunner(config *Config) *allocRunner {
ar := &allocRunner{
alloc: config.Alloc,
clientConfig: config.ClientConfig,
vaultClient: config.Vault,
alloc: config.Alloc,
tasks: make(map[string]*taskrunner.TaskRunner),
waitCh: make(chan struct{}),
updateCh: make(chan *structs.Allocation),
......@@ -164,6 +169,7 @@ func (ar *allocRunner) runTask(alloc *structs.Allocation, task *structs.Task) er
TaskDir: ar.allocDir.NewTaskDir(task.Name),
Logger: ar.logger,
StateDB: ar.stateDB,
VaultClient: ar.vaultClient,
}
tr, err := taskrunner.NewTaskRunner(config)
if err != nil {
......
package allocrunnerv2
import (
"context"
"testing"
"github.com/hashicorp/nomad/client/allocrunnerv2/config"
"github.com/hashicorp/nomad/client/allocrunnerv2/interfaces"
clientconfig "github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/helper/testlog"
......@@ -15,19 +13,14 @@ import (
func testAllocRunnerFromAlloc(t *testing.T, alloc *structs.Allocation) *allocRunner {
cconf := clientconfig.DefaultConfig()
config := &config.Config{
config := &Config{
ClientConfig: cconf,
Logger: testlog.HCLogger(t).With("unit_test", t.Name()),
Allocation: alloc,
}
ar, err := NewAllocRunner(context.Background(), config)
if err != nil {
t.Fatalf("Failed to create test alloc runner: %v", err)
Alloc: alloc,
}
ar := NewAllocRunner(config)
return ar
}
func testAllocRunner(t *testing.T) *allocRunner {
......
......@@ -4,6 +4,7 @@ import (
"github.com/boltdb/bolt"
log "github.com/hashicorp/go-hclog"
clientconfig "github.com/hashicorp/nomad/client/config"
"github.com/hashicorp/nomad/client/vaultclient"
"github.com/hashicorp/nomad/nomad/structs"
)
......@@ -21,6 +22,9 @@ type Config struct {
// StateDB is used to store and restore state.
StateDB *bolt.DB
// Vault is the Vault client to use to retrieve Vault tokens
Vault vaultclient.VaultClient
// XXX Can have an OnStateTransition hook that we can use to update the
// server
}
......@@ -72,8 +72,7 @@ type TaskPoststopResponse struct{}
type TaskPoststopHook interface {
TaskHook
Postrun() error
//Postrun(context.Context, *TaskPostrunRequest, *TaskPostrunResponse) error
Postrun(context.Context, *TaskPostrunRequest, *TaskPostrunResponse) error
}
type TaskDestroyRequest struct{}
......@@ -86,13 +85,11 @@ type TaskDestroyHook interface {
}
type TaskUpdateRequest struct {
Alloc string
Vault string // Don't need message bus then
VaultToken string
}
type TaskUpdateResponse struct{}
type TaskUpdateHook interface {
TaskHook
Update() error
//Update(context.Context, *TaskUpdateRequest, *TaskUpdateResponse) error
Update(context.Context, *TaskUpdateRequest, *TaskUpdateResponse) error
}
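The new Update signature lets any hook react to runtime changes such as a fresh Vault token. As a point of reference only, a hypothetical hook satisfying TaskUpdateHook might look like the sketch below; it assumes the embedded TaskHook interface only requires Name(), which is not shown in this hunk, and is not part of the commit.
package taskrunner

import (
	"context"

	log "github.com/hashicorp/go-hclog"

	"github.com/hashicorp/nomad/client/allocrunnerv2/interfaces"
)

// tokenLoggerHook is a hypothetical update hook that only records when a new
// Vault token is delivered.
type tokenLoggerHook struct {
	logger log.Logger
}

func (*tokenLoggerHook) Name() string { return "token_logger" }

func (h *tokenLoggerHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequest, resp *interfaces.TaskUpdateResponse) error {
	if req.VaultToken != "" {
		h.logger.Trace("received updated Vault token")
	}
	return nil
}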
package taskrunner
import "os"
// XXX These should probably all return an error and we should have predefined
// error types for the task not currently running
type TaskLifecycle interface {
Restart(source, reason string, failure bool)
Signal(source, reason string, s os.Signal) error
Kill(source, reason string, fail bool)
}
func (tr *TaskRunner) Restart(source, reason string, failure bool) {
// TODO
}
func (tr *TaskRunner) Signal(source, reason string, s os.Signal) error {
// TODO
return nil
}
func (tr *TaskRunner) Kill(source, reason string, fail bool) {
// TODO
}
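Restart, Signal and Kill are still TODO stubs, so code that depends on TaskLifecycle (the vault hook further down, for instance) is easiest to exercise with a test double. A minimal, hypothetical sketch, not part of this commit:
package taskrunner

import "os"

// mockLifecycle is a hypothetical TaskLifecycle double that counts kill
// requests and ignores everything else.
type mockLifecycle struct {
	kills int
}

func (m *mockLifecycle) Restart(source, reason string, failure bool) {}

func (m *mockLifecycle) Signal(source, reason string, s os.Signal) error { return nil }

func (m *mockLifecycle) Kill(source, reason string, fail bool) { m.kills++ }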
......@@ -10,9 +10,6 @@ import (
type LocalState struct {
Hooks map[string]*HookState
// VaultToken is the current Vault token for the task
VaultToken string
// DriverNetwork is the network information returned by the task
// driver's Start method
DriverNetwork *structs.DriverNetwork
......@@ -29,7 +26,6 @@ func (s *LocalState) Copy() *LocalState {
// Create a copy
c := &LocalState{
Hooks: make(map[string]*HookState, len(s.Hooks)),
VaultToken: s.VaultToken,
DriverNetwork: s.DriverNetwork,
}
......
......@@ -19,6 +19,7 @@ import (
"github.com/hashicorp/nomad/client/driver"
"github.com/hashicorp/nomad/client/driver/env"
oldstate "github.com/hashicorp/nomad/client/state"
"github.com/hashicorp/nomad/client/vaultclient"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/ugorji/go/codec"
"golang.org/x/crypto/blake2b"
......@@ -46,7 +47,8 @@ type TaskRunner struct {
// localState captures the node-local state of the task for when the
// Nomad agent restarts
localState *state.LocalState
localState *state.LocalState
localStateLock sync.RWMutex
// stateDB is for persisting localState
stateDB *bolt.DB
......@@ -99,6 +101,14 @@ type TaskRunner struct {
// transitions.
runnerHooks []interfaces.TaskHook
// vaultClient is the client to use to derive and renew Vault tokens
vaultClient vaultclient.VaultClient
// vaultToken is the current Vault token. It should be accessed with the
// getter.
vaultToken string
vaultTokenLock sync.Mutex
// baseLabels are used when emitting tagged metrics. All task runner metrics
// will have these tags, and optionally more.
baseLabels []metrics.Label
......@@ -111,6 +121,9 @@ type Config struct {
TaskDir *allocdir.TaskDir
Logger log.Logger
// VaultClient is the client to use to derive and renew Vault tokens
VaultClient vaultclient.VaultClient
// LocalState is optionally restored task state
LocalState *state.LocalState
......@@ -138,6 +151,7 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) {
taskDir: config.TaskDir,
taskName: config.Task.Name,
envBuilder: envBuilder,
vaultClient: config.VaultClient,
//XXX Make a Copy to avoid races?
state: config.Alloc.TaskStates[config.Task.Name],
localState: config.LocalState,
......@@ -374,7 +388,10 @@ func (tr *TaskRunner) persistLocalState() error {
w := io.MultiWriter(h, &buf)
// Encode as msgpack value
if err := codec.NewEncoder(w, structs.MsgpackHandle).Encode(&tr.localState); err != nil {
tr.localStateLock.Lock()
err = codec.NewEncoder(w, structs.MsgpackHandle).Encode(&tr.localState)
tr.localStateLock.Unlock()
if err != nil {
return fmt.Errorf("failed to serialize snapshot: %v", err)
}
......@@ -509,12 +526,6 @@ func (tr *TaskRunner) Shutdown() {
tr.ctxCancel()
}
func (tr *TaskRunner) Alloc() *structs.Allocation {
tr.allocLock.Lock()
defer tr.allocLock.Unlock()
return tr.alloc
}
// appendTaskEvent updates the task status by appending the new event.
func appendTaskEvent(state *structs.TaskState, event *structs.TaskEvent) {
capacity := 10
......
......@@ -2,8 +2,26 @@ package taskrunner
import "github.com/hashicorp/nomad/nomad/structs"
func (tr *TaskRunner) Alloc() *structs.Allocation {
tr.allocLock.Lock()
defer tr.allocLock.Unlock()
return tr.alloc
}
func (tr *TaskRunner) Task() *structs.Task {
tr.taskLock.RLock()
defer tr.taskLock.RUnlock()
return tr.task
}
func (tr *TaskRunner) getVaultToken() string {
tr.vaultTokenLock.Lock()
defer tr.vaultTokenLock.Unlock()
return tr.vaultToken
}
func (tr *TaskRunner) setVaultToken(token string) {
tr.vaultTokenLock.Lock()
defer tr.vaultTokenLock.Unlock()
tr.vaultToken = token
}
......@@ -15,6 +15,10 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
)
type EventEmitter interface {
SetState(state string, event *structs.TaskEvent)
}
// initHooks initializes the task's hooks.
func (tr *TaskRunner) initHooks() {
hookLogger := tr.logger.Named("task_hook")
......@@ -25,6 +29,20 @@ func (tr *TaskRunner) initHooks() {
newTaskDirHook(tr, hookLogger),
newArtifactHook(tr, hookLogger),
}
// If Vault is enabled, add the hook
if task := tr.Task(); task.Vault != nil {
tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{
vaultStanza: task.Vault,
client: tr.vaultClient,
events: tr,
lifecycle: tr,
updater: tr,
logger: hookLogger,
alloc: tr.Alloc(),
task: tr.taskName,
}))
}
}
// prerun is used to run the runner's prerun hooks.
......@@ -63,13 +81,15 @@ func (tr *TaskRunner) prerun() error {
TaskEnv: tr.envBuilder.Build(),
}
tr.localStateLock.RLock()
origHookState := tr.localState.Hooks[name]
tr.localStateLock.RUnlock()
if origHookState != nil && origHookState.PrerunDone {
tr.logger.Trace("skipping done prerun hook", "name", pre.Name())
continue
}
req.VaultToken = tr.localState.VaultToken
req.VaultToken = tr.getVaultToken()
// Time the prerun hook
var start time.Time
......@@ -86,6 +106,7 @@ func (tr *TaskRunner) prerun() error {
// Store the hook state
{
tr.localStateLock.Lock()
hookState, ok := tr.localState.Hooks[name]
if !ok {
hookState = &state.HookState{}
......@@ -96,6 +117,7 @@ func (tr *TaskRunner) prerun() error {
hookState.Data = resp.HookData
hookState.PrerunDone = resp.Done
}
tr.localStateLock.Unlock()
// Persist local state if the hook state has changed
if !hookState.Equal(origHookState) {
......@@ -195,6 +217,51 @@ func (tr *TaskRunner) shutdown() error {
return nil
}
// updateHooks is used to run the runner's update hooks.
func (tr *TaskRunner) updateHooks() {
if tr.logger.IsTrace() {
start := time.Now()
tr.logger.Trace("running update hooks", "start", start)
defer func() {
end := time.Now()
tr.logger.Trace("finished update hooks", "end", end, "duration", end.Sub(start))
}()
}
for _, hook := range tr.runnerHooks {
upd, ok := hook.(interfaces.TaskUpdateHook)
if !ok {
tr.logger.Trace("skipping non-update hook", "name", hook.Name())
continue
}
name := upd.Name()
// Build the request
req := interfaces.TaskUpdateRequest{
VaultToken: tr.getVaultToken(),
}
// Time the update hook
var start time.Time
if tr.logger.IsTrace() {
start = time.Now()
tr.logger.Trace("running update hook", "name", name, "start", start)
}
// Run the update hook
var resp interfaces.TaskUpdateResponse
if err := upd.Update(tr.ctx, &req, &resp); err != nil {
tr.logger.Error("update hook failed", "name", name, "error", err)
}
if tr.logger.IsTrace() {
end := time.Now()
tr.logger.Trace("finished update hooks", "name", name, "end", end, "duration", end.Sub(start))
}
}
}
type taskDirHook struct {
runner *TaskRunner
logger log.Logger
......@@ -235,10 +302,6 @@ func (h *taskDirHook) Prerun(ctx context.Context, req *interfaces.TaskPrerunRequ
return nil
}
type EventEmitter interface {
SetState(state string, event *structs.TaskEvent)
}
// artifactHook downloads artifacts for a task.
type artifactHook struct {
eventEmitter EventEmitter
......
package taskrunner
import (
"context"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"time"
"github.com/hashicorp/consul-template/signals"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/allocdir"
"github.com/hashicorp/nomad/client/allocrunnerv2/interfaces"
"github.com/hashicorp/nomad/client/vaultclient"
"github.com/hashicorp/nomad/nomad/structs"
)
const (
// vaultBackoffBaseline is the baseline time for exponential backoff when
// attempting to retrieve a Vault token
vaultBackoffBaseline = 5 * time.Second
// vaultBackoffLimit is the limit of the exponential backoff when attempting
// to retrieve a Vault token
vaultBackoffLimit = 3 * time.Minute
// vaultTokenFile is the name of the file holding the Vault token inside the
// task's secret directory
vaultTokenFile = "vault_token"
)
type vaultTokenUpdateHandler interface {
updatedVaultToken(token string)
}
func (tr *TaskRunner) updatedVaultToken(token string) {
// Update the Vault token on the runner
tr.setVaultToken(token)
// Update the task's environment
tr.envBuilder.SetVaultToken(token, tr.task.Vault.Env)
// Update the hooks with the new Vault token
tr.updateHooks()
}
type vaultHookConfig struct {
vaultStanza *structs.Vault
client vaultclient.VaultClient
events EventEmitter
lifecycle TaskLifecycle
updater vaultTokenUpdateHandler
logger log.Logger
alloc *structs.Allocation
task string
}
type vaultHook struct {
// vaultStanza is the vault stanza for the task
vaultStanza *structs.Vault
// eventEmitter is used to emit events to the task
eventEmitter EventEmitter
// lifecycle is used to signal, restart and kill a task
lifecycle TaskLifecycle
// updater is used to update the Vault token
updater vaultTokenUpdateHandler
// client is the Vault client to retrieve and renew the Vault token
client vaultclient.VaultClient
// logger is used to log
logger log.Logger
// ctx and cancel are used to kill the long running token manager
ctx context.Context
cancel context.CancelFunc
// tokenPath is the path in which to read and write the token
tokenPath string
// alloc is the allocation
alloc *structs.Allocation
// taskName is the name of the task
taskName string
// firstRun stores whether it is the first run for the hook
firstRun bool
// future is used to wait on retrieving a Vault token
future *tokenFuture
}
func newVaultHook(config *vaultHookConfig) *vaultHook {
ctx, cancel := context.WithCancel(context.Background())
h := &vaultHook{
vaultStanza: config.vaultStanza,
client: config.client,
eventEmitter: config.events,
lifecycle: config.lifecycle,
updater: config.updater,
alloc: config.alloc,
taskName: config.task,
firstRun: true,
ctx: ctx,
cancel: cancel,
future: newTokenFuture(),
}
h.logger = config.logger.Named(h.Name())
return h
}
func (*vaultHook) Name() string {
return "vault"
}
func (h *vaultHook) Prerun(ctx context.Context, req *interfaces.TaskPrerunRequest, resp *interfaces.TaskPrerunResponse) error {
// If we have already run prerun before, exit early. We do not use the
// PrerunDone value because we want to recover the token on restoration.
first := h.firstRun
h.firstRun = false
if !first {
return nil
}
// Try to recover a token if it was previously written in the secrets
// directory
recoveredToken := ""
h.tokenPath = filepath.Join(req.TaskDir, allocdir.TaskSecrets, vaultTokenFile)
data, err := ioutil.ReadFile(h.tokenPath)
if err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("failed to recover vault token: %v", err)
}
// Token file doesn't exist
} else {
// Store the recovered token
recoveredToken = string(data)
}
// Launch the token manager
go h.run(recoveredToken)
// Block until we get a token
select {
case <-h.future.Wait():
case <-ctx.Done():
return nil
}
h.updater.updatedVaultToken(h.future.Get())
return nil
}
func (h *vaultHook) Poststop(ctx context.Context, req *interfaces.TaskPoststopRequest, resp *interfaces.TaskPoststopResponse) error {
// Shutdown any created manager
h.cancel()
return nil
}
// run should be called in a goroutine and manages the derivation, renewal and
// handling of errors with the Vault token. The optional parameter allows
// setting the initial Vault token. This is useful when the Vault token is
// recovered off disk.
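// In outline: derive a token when none is in hand and write it to the task's
// secrets directory, register it for renewal and publish it through the
// future, then on a renewal failure clear the token, note that the change mode
// must be applied (unless it is noop), and loop around to derive a fresh one.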
func (h *vaultHook) run(token string) {
// Helper for stopping token renewal
stopRenewal := func() {
if err := h.client.StopRenewToken(h.future.Get()); err != nil {
h.logger.Warn("failed to stop token renewal", "error", err)
}
}
// updatedToken lets us store state between loops. If true, a new token
// has been retrieved and we need to apply the Vault change mode
var updatedToken bool
OUTER:
for {
// Check if we should exit
if h.ctx.Err() != nil {
stopRenewal()
return
}
// Clear the token
h.future.Clear()
// Check if there already is a token, which can be the case when
// restoring the TaskRunner
if token == "" {
// Get a token
var exit bool
token, exit = h.deriveVaultToken()
if exit {
// Exit the manager
return
}
// Write the token to disk
if err := h.writeToken(token); err != nil {
errorString := "failed to write Vault token to disk"
h.logger.Error(errorString, "error", err)
h.lifecycle.Kill("vault", errorString, true)
return
}
}
// Start the renewal process
renewCh, err := h.client.RenewToken(token, 30)
// An error returned means the token is not being renewed
if err != nil {
h.logger.Error("failed to start renewal of Vault token", "error", err)
token = ""
goto OUTER
}
// The Vault token is valid now, so set it
h.future.Set(token)
if updatedToken {
switch h.vaultStanza.ChangeMode {
case structs.VaultChangeModeSignal:
s, err := signals.Parse(h.vaultStanza.ChangeSignal)
if err != nil {
h.logger.Error("failed to parse signal", "error", err)
h.lifecycle.Kill("vault", fmt.Sprintf("failed to parse signal: %v", err), true)
return
}
if err := h.lifecycle.Signal("vault", "new Vault token acquired", s); err != nil {
h.logger.Error("failed to send signal", "error", err)
h.lifecycle.Kill("vault", fmt.Sprintf("failed to send signal: %v", err), true)
return
}
case structs.VaultChangeModeRestart:
const noFailure = false
h.lifecycle.Restart("vault", "new Vault token acquired", noFailure)
case structs.VaultChangeModeNoop:
fallthrough
default:
h.logger.Error("invalid Vault change mode", "mode", h.vaultStanza.ChangeMode)
}
// We have handled it
updatedToken = false
// Call the handler
h.updater.updatedVaultToken(token)
}
// Start watching for renewal errors
select {
case err := <-renewCh:
// Clear the token
token = ""
h.logger.Error("failed to renew Vault token", "error", err)
stopRenewal()
// Check if we have to do anything
if h.vaultStanza.ChangeMode != structs.VaultChangeModeNoop {
updatedToken = true
}
case <-h.ctx.Done():
stopRenewal()
return
}
}
}
// deriveVaultToken derives the Vault token using exponential backoffs. It
// returns the Vault token and whether the manager should exit.
func (h *vaultHook) deriveVaultToken() (token string, exit bool) {
attempts := 0
for {
tokens, err := h.client.DeriveToken(h.alloc, []string{h.taskName})
if err == nil {
return tokens[h.taskName], false
}
// Check if this is a server side error
if structs.IsServerSide(err) {
h.logger.Error("failed to derive Vault token", "error", err, "server_side", true)
h.lifecycle.Kill("vault", fmt.Sprintf("server error deriving vault token: %v", err), true)
return "", true
}
// Check if we can't recover from the error
if !structs.IsRecoverable(err) {
h.logger.Error("failed to derive Vault token", "error", err, "recoverable", false)
h.lifecycle.Kill("vault", fmt.Sprintf("failed to derive token: %v", err), true)
return "", true
}
// Handle the retry case
backoff := (1 << (2 * uint64(attempts))) * vaultBackoffBaseline
if backoff > vaultBackoffLimit {
backoff = vaultBackoffLimit
}
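// With the constants above this works out to 5s, 20s, 80s, and then the
// 3 minute cap for every attempt after that (baseline * 4^attempts, capped).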
h.logger.Error("failed to derive Vault token", "error", err, "recoverable", true, "backoff", backoff)
attempts++
// Wait till retrying
select {
case <-h.ctx.Done():
return "", true
case <-time.After(backoff):
}
}
}
// writeToken writes the given token to disk
func (h *vaultHook) writeToken(token string) error {
if err := ioutil.WriteFile(h.tokenPath, []byte(token), 0666); err != nil {
return fmt.Errorf("failed to write vault token: %v", err)
}
return nil
}
// tokenFuture stores the Vault token and allows consumers to block till a valid
// token exists
type tokenFuture struct {
waiting []chan struct{}
token string
set bool
m sync.Mutex
}
// newTokenFuture returns a new token future without any token set
func newTokenFuture() *tokenFuture {
return &tokenFuture{}
}
// Wait returns a channel that can be waited on. When this channel unblocks, a
// valid token will be available via the Get method
func (f *tokenFuture) Wait() <-chan struct{} {
f.m.Lock()
defer f.m.Unlock()
c := make(chan struct{})
if f.set {
close(c)
return c
}
f.waiting = append(f.waiting, c)
return c
}
// Set sets the token value and unblocks any caller of Wait
func (f *tokenFuture) Set(token string) *tokenFuture {
f.m.Lock()
defer f.m.Unlock()
f.set = true
f.token = token
for _, w := range f.waiting {
close(w)
}
f.waiting = nil
return f
}
// Clear clears the set vault token.
func (f *tokenFuture) Clear() *tokenFuture {
f.m.Lock()
defer f.m.Unlock()
f.token = ""
f.set = false
return f
}
// Get returns the set Vault token
func (f *tokenFuture) Get() string {
f.m.Lock()
defer f.m.Unlock()
return f.token
}
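The tokenFuture above is a small latch: Wait hands out a channel that is closed once Set is called, Get then returns the stored token, and Clear re-arms the future for the next Set. A hypothetical test-style sketch of that flow (not part of this commit):
package taskrunner

import "testing"

func TestTokenFuture_setAndWait(t *testing.T) {
	f := newTokenFuture()

	got := make(chan string, 1)
	go func() {
		<-f.Wait() // unblocks once Set is called
		got <- f.Get()
	}()

	f.Set("s.example")
	if token := <-got; token != "s.example" {
		t.Fatalf("unexpected token: %q", token)
	}

	// Clear re-arms the future, so Wait blocks again until the next Set.
	f.Clear()
	select {
	case <-f.Wait():
		t.Fatal("Wait should block after Clear")
	default:
	}
}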
......@@ -1961,6 +1961,7 @@ func (c *Client) addAlloc(alloc *structs.Allocation, migrateToken string) error
Logger: logger,
ClientConfig: c.config,
StateDB: c.stateDB,
Vault: c.vaultClient,
}
c.configLock.RUnlock()
......