Commit a6b07473 authored by Alex Dadgar's avatar Alex Dadgar Committed by GitHub

Merge pull request #1828 from hashicorp/f-vault-options

Vault token renewal errors handled by client
parents 917c7e50 70d0e6e1
Showing with 1279 additions and 773 deletions
......@@ -173,12 +173,13 @@ type Template struct {
ChangeMode string
ChangeSignal string
Splay time.Duration
Once bool
}
type Vault struct {
Policies []string
Env bool
Policies []string
Env bool
ChangeMode string
ChangeSignal string
}
// NewTask creates and initializes a new Task.
......
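The api package's Vault block gains ChangeMode and ChangeSignal alongside Policies and Env. A minimal, hedged sketch of opting a task into the new signal behavior (it assumes api.Task carries a Vault field, as the rest of this PR suggests; the policy name is illustrative):

package main

import (
	"fmt"

	"github.com/hashicorp/nomad/api"
)

func main() {
	// Ask Nomad to signal the task instead of restarting it when the
	// Vault token has to be replaced.
	task := api.NewTask("web", "docker")
	task.Vault = &api.Vault{
		Policies:     []string{"secret-readers"}, // illustrative policy
		Env:          true,
		ChangeMode:   "signal",
		ChangeSignal: "SIGUSR1",
	}
	fmt.Printf("%+v\n", task.Vault)
}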
......@@ -2,7 +2,6 @@ package client
import (
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
......@@ -31,10 +30,6 @@ const (
// watchdogInterval is the interval at which resource constraints for the
// allocation are being checked and enforced.
watchdogInterval = 5 * time.Second
// vaultTokenFile is the name of the file holding the Vault token inside the
// task's secret directory
vaultTokenFile = "vault_token"
)
// AllocStateUpdater is used to update the status of an allocation
......@@ -69,7 +64,6 @@ type AllocRunner struct {
updateCh chan *structs.Allocation
vaultClient vaultclient.VaultClient
vaultTokens map[string]vaultToken
otherAllocDir *allocdir.AllocDir
......@@ -145,9 +139,6 @@ func (r *AllocRunner) RestoreState() error {
return e
}
// Recover the Vault tokens
vaultErr := r.recoverVaultTokens()
// Restore the task runners
var mErr multierror.Error
for name, state := range r.taskStates {
......@@ -156,13 +147,9 @@ func (r *AllocRunner) RestoreState() error {
task := &structs.Task{Name: name}
tr := NewTaskRunner(r.logger, r.config, r.setTaskState, r.ctx, r.Alloc(),
task)
task, r.vaultClient)
r.tasks[name] = tr
if vt, ok := r.vaultTokens[name]; ok {
tr.SetVaultToken(vt.token, vt.renewalCh)
}
// Skip tasks in terminal states.
if state.State == structs.TaskStateDead {
continue
......@@ -177,20 +164,6 @@ func (r *AllocRunner) RestoreState() error {
}
}
// Since this is somewhat of an expected case we do not return an error but
// handle it gracefully.
if vaultErr != nil {
msg := fmt.Sprintf("failed to recover Vault tokens for allocation %q: %v", r.alloc.ID, vaultErr)
r.logger.Printf("[ERR] client: %s", msg)
r.setStatus(structs.AllocClientStatusFailed, msg)
// Destroy the task runners and set the error
r.destroyTaskRunners(structs.NewTaskEvent(structs.TaskVaultRenewalFailed).SetVaultRenewalError(vaultErr))
// Handle cleanup
go r.handleDestroy()
}
return mErr.ErrorOrNil()
}
......@@ -376,13 +349,6 @@ func (r *AllocRunner) setTaskState(taskName, state string, event *structs.TaskEv
r.appendTaskEvent(taskState, event)
if state == structs.TaskStateDead {
// If the task has a Vault token, stop renewing it
if vt, ok := r.vaultTokens[taskName]; ok {
if err := r.vaultClient.StopRenewToken(vt.token); err != nil {
r.logger.Printf("[ERR] client: stopping token renewal for task %q failed: %v", taskName, err)
}
}
// If the task failed, we should kill all the other tasks in the task group.
if taskState.Failed() {
var destroyingTasks []string
......@@ -467,15 +433,6 @@ func (r *AllocRunner) Run() {
return
}
// Request Vault tokens for the tasks that require them
err := r.deriveVaultTokens()
if err != nil {
msg := fmt.Sprintf("failed to derive Vault token for allocation %q: %v", r.alloc.ID, err)
r.logger.Printf("[ERR] client: %s", msg)
r.setStatus(structs.AllocClientStatusFailed, msg)
return
}
// Start the task runners
r.logger.Printf("[DEBUG] client: starting task runners for alloc '%s'", r.alloc.ID)
r.taskLock.Lock()
......@@ -484,15 +441,10 @@ func (r *AllocRunner) Run() {
continue
}
tr := NewTaskRunner(r.logger, r.config, r.setTaskState, r.ctx, r.Alloc(), task.Copy())
tr := NewTaskRunner(r.logger, r.config, r.setTaskState, r.ctx, r.Alloc(), task.Copy(), r.vaultClient)
r.tasks[task.Name] = tr
tr.MarkReceived()
// If the task has a vault token set it before running
if vt, ok := r.vaultTokens[task.Name]; ok {
tr.SetVaultToken(vt.token, vt.renewalCh)
}
go tr.Run()
}
r.taskLock.Unlock()
......@@ -575,149 +527,6 @@ func (r *AllocRunner) destroyTaskRunners(destroyEvent *structs.TaskEvent) {
r.syncStatus()
}
// vaultToken acts as a tuple of the token and renewal channel
type vaultToken struct {
token string
renewalCh <-chan error
}
// deriveVaultTokens derives the required vault tokens and returns a map of the
// tasks to their respective vault token and renewal channel. This must be
// called after the allocation directory is created as the vault tokens are
// written to disk.
func (r *AllocRunner) deriveVaultTokens() error {
required, err := r.tasksRequiringVaultTokens()
if err != nil {
return err
}
if len(required) == 0 {
return nil
}
if r.vaultTokens == nil {
r.vaultTokens = make(map[string]vaultToken, len(required))
}
// Get the tokens
tokens, err := r.vaultClient.DeriveToken(r.Alloc(), required)
if err != nil {
return fmt.Errorf("failed to derive Vault tokens: %v", err)
}
// Persist the tokens to the appropriate secret directories
adir := r.ctx.AllocDir
for task, token := range tokens {
// Has been recovered
if _, ok := r.vaultTokens[task]; ok {
continue
}
secretDir, err := adir.GetSecretDir(task)
if err != nil {
return fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err)
}
// Write the token to the file system
tokenPath := filepath.Join(secretDir, vaultTokenFile)
if err := ioutil.WriteFile(tokenPath, []byte(token), 0777); err != nil {
return fmt.Errorf("failed to save Vault tokens to secret dir for task %q in alloc %q: %v", task, r.alloc.ID, err)
}
// Start renewing the token
renewCh, err := r.vaultClient.RenewToken(token, 10)
if err != nil {
var mErr multierror.Error
errMsg := fmt.Errorf("failed to renew Vault token for task %q in alloc %q: %v", task, r.alloc.ID, err)
multierror.Append(&mErr, errMsg)
// Clean up any token that we have started renewing
for _, token := range r.vaultTokens {
if err := r.vaultClient.StopRenewToken(token.token); err != nil {
multierror.Append(&mErr, err)
}
}
return mErr.ErrorOrNil()
}
r.vaultTokens[task] = vaultToken{token: token, renewalCh: renewCh}
}
return nil
}
// tasksRequiringVaultTokens returns the set of tasks that require a Vault token
func (r *AllocRunner) tasksRequiringVaultTokens() ([]string, error) {
// Get the tasks
tg := r.alloc.Job.LookupTaskGroup(r.alloc.TaskGroup)
if tg == nil {
return nil, fmt.Errorf("Failed to lookup task group in alloc")
}
// Retrieve any required Vault tokens
var required []string
for _, task := range tg.Tasks {
if task.Vault != nil && len(task.Vault.Policies) != 0 {
required = append(required, task.Name)
}
}
return required, nil
}
// recoverVaultTokens reads the Vault tokens for the tasks that have Vault
// tokens off disk. If there is an error, it is returned, otherwise token
// renewal is started.
func (r *AllocRunner) recoverVaultTokens() error {
required, err := r.tasksRequiringVaultTokens()
if err != nil {
return err
}
if len(required) == 0 {
return nil
}
// Read the tokens and start renewing them
adir := r.ctx.AllocDir
renewingTokens := make(map[string]vaultToken, len(required))
for _, task := range required {
secretDir, err := adir.GetSecretDir(task)
if err != nil {
return fmt.Errorf("failed to determine task %s secret dir in alloc %q: %v", task, r.alloc.ID, err)
}
// Read the token from the secret directory
tokenPath := filepath.Join(secretDir, vaultTokenFile)
data, err := ioutil.ReadFile(tokenPath)
if err != nil {
return fmt.Errorf("failed to read token for task %q in alloc %q: %v", task, r.alloc.ID, err)
}
token := string(data)
renewCh, err := r.vaultClient.RenewToken(token, 10)
if err != nil {
var mErr multierror.Error
errMsg := fmt.Errorf("failed to renew Vault token for task %q in alloc %q: %v", task, r.alloc.ID, err)
multierror.Append(&mErr, errMsg)
// Clean up any token that we have started renewing
for _, token := range renewingTokens {
if err := r.vaultClient.StopRenewToken(token.token); err != nil {
multierror.Append(&mErr, err)
}
}
return mErr.ErrorOrNil()
}
renewingTokens[task] = vaultToken{token: token, renewalCh: renewCh}
}
r.vaultTokens = renewingTokens
return nil
}
// checkResources monitors and enforces alloc resource usage. It returns an
// appropriate task event describing why the allocation had to be killed.
func (r *AllocRunner) checkResources() (*structs.TaskEvent, string) {
......
......@@ -620,248 +620,6 @@ func TestAllocRunner_TaskFailed_KillTG(t *testing.T) {
})
}
func TestAllocRunner_SimpleRun_VaultToken(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{"exit_code": "0"}
task.Vault = &structs.Vault{
Policies: []string{"default"},
}
upd, ar := testAllocRunnerFromAlloc(alloc, false)
go ar.Run()
defer ar.Destroy()
testutil.WaitForResult(func() (bool, error) {
if upd.Count == 0 {
return false, fmt.Errorf("No updates")
}
last := upd.Allocs[upd.Count-1]
if last.ClientStatus != structs.AllocClientStatusComplete {
return false, fmt.Errorf("got status %v; want %v", last.ClientStatus, structs.AllocClientStatusComplete)
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v", err)
})
tr, ok := ar.tasks[task.Name]
if !ok {
t.Fatalf("No task runner made")
}
// Check that the task runner was given the token
token := tr.vaultToken
if token == "" || tr.vaultRenewalCh == nil {
t.Fatalf("Vault token not set properly")
}
// Check that it was written to disk
secretDir, err := ar.ctx.AllocDir.GetSecretDir(task.Name)
if err != nil {
t.Fatalf("bad: %v", err)
}
tokenPath := filepath.Join(secretDir, vaultTokenFile)
data, err := ioutil.ReadFile(tokenPath)
if err != nil {
t.Fatalf("token not written to disk: %v", err)
}
if string(data) != token {
t.Fatalf("Bad token written to disk")
}
// Check that we stopped renewing the token
mockVC := ar.vaultClient.(*vaultclient.MockVaultClient)
if len(mockVC.StoppedTokens) != 1 || mockVC.StoppedTokens[0] != token {
t.Fatalf("We didn't stop renewing the token")
}
}
func TestAllocRunner_SaveRestoreState_VaultTokens_Valid(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "10s",
}
task.Vault = &structs.Vault{
Policies: []string{"default"},
}
upd, ar := testAllocRunnerFromAlloc(alloc, false)
go ar.Run()
// Snapshot state
var token string
testutil.WaitForResult(func() (bool, error) {
if len(ar.tasks) != 1 {
return false, fmt.Errorf("Task not started")
}
tr, ok := ar.tasks[task.Name]
if !ok {
return false, fmt.Errorf("Incorrect task runner")
}
if tr.vaultToken == "" {
return false, fmt.Errorf("Bad token")
}
token = tr.vaultToken
return true, nil
}, func(err error) {
t.Fatalf("task never started: %v", err)
})
err := ar.SaveState()
if err != nil {
t.Fatalf("err: %v", err)
}
// Create a new alloc runner
ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update,
&structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient)
err = ar2.RestoreState()
if err != nil {
t.Fatalf("err: %v", err)
}
go ar2.Run()
testutil.WaitForResult(func() (bool, error) {
if len(ar2.tasks) != 1 {
return false, fmt.Errorf("Incorrect number of tasks")
}
tr, ok := ar2.tasks[task.Name]
if !ok {
return false, fmt.Errorf("Incorrect task runner")
}
if tr.vaultToken != token {
return false, fmt.Errorf("Got token %q; want %q", tr.vaultToken, token)
}
if upd.Count == 0 {
return false, nil
}
last := upd.Allocs[upd.Count-1]
return last.ClientStatus == structs.AllocClientStatusRunning, nil
}, func(err error) {
t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates)
})
// Destroy and wait
ar2.Destroy()
start := time.Now()
testutil.WaitForResult(func() (bool, error) {
alloc := ar2.Alloc()
if alloc.ClientStatus != structs.AllocClientStatusComplete {
return false, fmt.Errorf("Bad client status; got %v; want %v", alloc.ClientStatus, structs.AllocClientStatusComplete)
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates)
})
if time.Since(start) > time.Duration(testutil.TestMultiplier()*5)*time.Second {
t.Fatalf("took too long to terminate")
}
}
func TestAllocRunner_SaveRestoreState_VaultTokens_Invalid(t *testing.T) {
alloc := mock.Alloc()
task := alloc.Job.TaskGroups[0].Tasks[0]
task.Driver = "mock_driver"
task.Config = map[string]interface{}{
"exit_code": "0",
"run_for": "10s",
}
task.Vault = &structs.Vault{
Policies: []string{"default"},
}
upd, ar := testAllocRunnerFromAlloc(alloc, false)
go ar.Run()
// Snapshot state
var token string
testutil.WaitForResult(func() (bool, error) {
if len(ar.tasks) != 1 {
return false, fmt.Errorf("Task not started")
}
tr, ok := ar.tasks[task.Name]
if !ok {
return false, fmt.Errorf("Incorrect task runner")
}
if tr.vaultToken == "" {
return false, fmt.Errorf("Bad token")
}
token = tr.vaultToken
return true, nil
}, func(err error) {
t.Fatalf("task never started: %v", err)
})
err := ar.SaveState()
if err != nil {
t.Fatalf("err: %v", err)
}
// Create a new alloc runner
ar2 := NewAllocRunner(ar.logger, ar.config, upd.Update,
&structs.Allocation{ID: ar.alloc.ID}, ar.vaultClient)
// Invalidate the token
mockVC := ar2.vaultClient.(*vaultclient.MockVaultClient)
renewErr := fmt.Errorf("Test disallowing renewal")
mockVC.SetRenewTokenError(token, renewErr)
// Restore and run
err = ar2.RestoreState()
if err != nil {
t.Fatalf("err: %v", err)
}
go ar2.Run()
testutil.WaitForResult(func() (bool, error) {
if upd.Count == 0 {
return false, nil
}
last := upd.Allocs[upd.Count-1]
return last.ClientStatus == structs.AllocClientStatusFailed, nil
}, func(err error) {
t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates)
})
// Destroy and wait
ar2.Destroy()
start := time.Now()
testutil.WaitForResult(func() (bool, error) {
alloc := ar2.Alloc()
if alloc.ClientStatus != structs.AllocClientStatusFailed {
return false, fmt.Errorf("Bad client status; got %v; want %v", alloc.ClientStatus, structs.AllocClientStatusFailed)
}
return true, nil
}, func(err error) {
t.Fatalf("err: %v %#v %#v", err, upd.Allocs[0], ar.alloc.TaskStates)
})
if time.Since(start) > time.Duration(testutil.TestMultiplier()*5)*time.Second {
t.Fatalf("took too long to terminate")
}
}
func TestAllocRunner_MoveAllocDir(t *testing.T) {
// Create an alloc runner
alloc := mock.Alloc()
......
......@@ -141,6 +141,7 @@ func (tm *TaskTemplateManager) Stop() {
// run is the long lived loop that handles errors and templates being rendered
func (tm *TaskTemplateManager) run() {
// Runner is nil if there are no templates
if tm.runner == nil {
// Unblock the start if there is nothing to do
if !tm.allRendered {
......@@ -189,6 +190,10 @@ func (tm *TaskTemplateManager) run() {
break WAIT
}
// TODO: I think we could check every 30 seconds and, if all templates
// would be rendered, start anyway. That serves as the reattach
// mechanism for the case where they have all already been rendered.
}
allRenderedTime = time.Now()
......
......@@ -1021,7 +1021,6 @@ func (h *DockerHandle) Signal(s os.Signal) error {
ID: h.containerID,
Signal: dockerSignal,
}
h.logger.Printf("Sending: %v", dockerSignal)
return h.client.KillContainer(opts)
}
......
......@@ -101,6 +101,19 @@ func (r *RestartTracker) GetState() (string, time.Duration) {
r.lock.Lock()
defer r.lock.Unlock()
// Clear out the existing state
defer func() {
r.startErr = nil
r.waitRes = nil
r.restartTriggered = false
}()
// Hot path if a restart was triggered
if r.restartTriggered {
r.reason = ""
return structs.TaskRestarting, 0
}
// Hot path if no attempts are expected
if r.policy.Attempts == 0 {
r.reason = ReasonNoRestartsAllowed
......@@ -121,25 +134,13 @@ func (r *RestartTracker) GetState() (string, time.Duration) {
r.startTime = now
}
var state string
var dur time.Duration
if r.startErr != nil {
state, dur = r.handleStartError()
return r.handleStartError()
} else if r.waitRes != nil {
state, dur = r.handleWaitResult()
} else if r.restartTriggered {
state, dur = structs.TaskRestarting, 0
r.reason = ""
} else {
state, dur = "", 0
return r.handleWaitResult()
}
// Clear out the existing state
r.startErr = nil
r.waitRes = nil
r.restartTriggered = false
return state, dur
return "", 0
}
// handleStartError returns the new state and potential wait duration for
......
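The GetState refactor leans on Go's defer semantics: deferred functions run after the return values are fixed, so a single defer can clear startErr, waitRes, and restartTriggered on every exit path while each hot path simply returns early. A standalone sketch of the pattern (the type and field names are illustrative, not Nomad's):

package main

import "fmt"

type tracker struct {
	startErr error
}

// state derives a label from the current inputs and clears them on every
// exit path; the deferred func runs after the return value is chosen.
func (t *tracker) state() string {
	defer func() { t.startErr = nil }()
	if t.startErr != nil {
		return "restarting"
	}
	return "idle"
}

func main() {
	t := &tracker{startErr: fmt.Errorf("boom")}
	fmt.Println(t.state()) // restarting
	fmt.Println(t.state()) // idle: the previous call cleared the input
}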
......@@ -350,19 +350,21 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
c.lock.Lock()
defer c.lock.Unlock()
if !c.config.IsEnabled() {
return fmt.Errorf("vault client not enabled")
}
if !c.running {
return fmt.Errorf("vault client is not running")
}
if req == nil {
return fmt.Errorf("nil renewal request")
}
if req.errCh == nil {
return fmt.Errorf("renewal request error channel nil")
}
if !c.config.IsEnabled() {
close(req.errCh)
return fmt.Errorf("vault client not enabled")
}
if !c.running {
close(req.errCh)
return fmt.Errorf("vault client is not running")
}
if req.id == "" {
close(req.errCh)
return fmt.Errorf("missing id in renewal request")
......
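Moving the nil-check on req.errCh ahead of the enabled/running checks is what lets those early exits safely close the channel, and closing is what releases a goroutine blocked waiting for renewal errors. A standalone sketch of that unblocking behavior:

package main

import "fmt"

func main() {
	errCh := make(chan error)
	done := make(chan struct{})

	// A waiter blocks on the renewal error channel.
	go func() {
		err, ok := <-errCh // yields (nil, false) once errCh is closed
		fmt.Println(err, ok)
		close(done)
	}()

	// Closing the channel, as renew() now does on its early exits,
	// releases the waiter immediately.
	close(errCh)
	<-done
}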
......@@ -21,12 +21,21 @@ type MockVaultClient struct {
// DeriveTokenErrors maps an allocation ID and tasks to an error when the
// token is derived
DeriveTokenErrors map[string]map[string]error
// DeriveTokenFn allows the caller to control the DeriveToken function. If
// not set an error is returned if found in DeriveTokenErrors and otherwise
// a token is generated and returned
DeriveTokenFn func(a *structs.Allocation, tasks []string) (map[string]string, error)
}
// NewMockVaultClient returns a MockVaultClient for testing
func NewMockVaultClient() *MockVaultClient { return &MockVaultClient{} }
func (vc *MockVaultClient) DeriveToken(a *structs.Allocation, tasks []string) (map[string]string, error) {
if vc.DeriveTokenFn != nil {
return vc.DeriveTokenFn(a, tasks)
}
tokens := make(map[string]string, len(tasks))
for _, task := range tasks {
if tasks, ok := vc.DeriveTokenErrors[a.ID]; ok {
......
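A hedged sketch of a test driving the new DeriveTokenFn hook to force derivation to fail for an allocation (the helper name and error text are illustrative; the function signature matches the field added above):

package client

import (
	"fmt"

	"github.com/hashicorp/nomad/client/vaultclient"
	"github.com/hashicorp/nomad/nomad/structs"
)

// newFailingVaultClient returns a mock whose DeriveToken always errors,
// exercising the client's failure handling.
func newFailingVaultClient() *vaultclient.MockVaultClient {
	vc := vaultclient.NewMockVaultClient()
	vc.DeriveTokenFn = func(a *structs.Allocation, tasks []string) (map[string]string, error) {
		return nil, fmt.Errorf("no token for alloc %q", a.ID)
	}
	return vc
}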
......@@ -209,8 +209,8 @@ func parseJob(result *structs.Job, list *ast.ObjectList) error {
// If we have a vault block, then parse that
if o := listVal.Filter("vault"); len(o.Items) > 0 {
var jobVault structs.Vault
if err := parseVault(&jobVault, o); err != nil {
jobVault := structs.DefaultVaultBlock()
if err := parseVault(jobVault, o); err != nil {
return multierror.Prefix(err, "vault ->")
}
......@@ -218,7 +218,7 @@ func parseJob(result *structs.Job, list *ast.ObjectList) error {
for _, tg := range result.TaskGroups {
for _, task := range tg.Tasks {
if task.Vault == nil {
task.Vault = &jobVault
task.Vault = jobVault
}
}
}
......@@ -335,15 +335,15 @@ func parseGroups(result *structs.Job, list *ast.ObjectList) error {
// If we have a vault block, then parse that
if o := listVal.Filter("vault"); len(o.Items) > 0 {
var tgVault structs.Vault
if err := parseVault(&tgVault, o); err != nil {
tgVault := structs.DefaultVaultBlock()
if err := parseVault(tgVault, o); err != nil {
return multierror.Prefix(err, fmt.Sprintf("'%s', vault ->", n))
}
// Go through the tasks and if they don't have a Vault block, set it
for _, task := range g.Tasks {
if task.Vault == nil {
task.Vault = &tgVault
task.Vault = tgVault
}
}
}
......@@ -717,12 +717,12 @@ func parseTasks(jobName string, taskGroupName string, result *[]*structs.Task, l
// If we have a vault block, then parse that
if o := listVal.Filter("vault"); len(o.Items) > 0 {
var v structs.Vault
if err := parseVault(&v, o); err != nil {
v := structs.DefaultVaultBlock()
if err := parseVault(v, o); err != nil {
return multierror.Prefix(err, fmt.Sprintf("'%s', vault ->", n))
}
t.Vault = &v
t.Vault = v
}
*result = append(*result, &t)
......@@ -1177,6 +1177,8 @@ func parseVault(result *structs.Vault, list *ast.ObjectList) error {
valid := []string{
"policies",
"env",
"change_mode",
"change_signal",
}
if err := checkHCLKeys(listVal, valid); err != nil {
return multierror.Prefix(err, "vault ->")
......@@ -1187,11 +1189,6 @@ func parseVault(result *structs.Vault, list *ast.ObjectList) error {
return err
}
// Default the env bool
if _, ok := m["env"]; !ok {
m["env"] = true
}
if err := mapstructure.WeakDecode(m, result); err != nil {
return err
}
......
......@@ -159,8 +159,9 @@ func TestParse(t *testing.T) {
},
},
Vault: &structs.Vault{
Policies: []string{"foo", "bar"},
Env: true,
Policies: []string{"foo", "bar"},
Env: true,
ChangeMode: structs.VaultChangeModeRestart,
},
Templates: []*structs.Template{
{
......@@ -199,6 +200,12 @@ func TestParse(t *testing.T) {
},
},
LogConfig: structs.DefaultLogConfig(),
Vault: &structs.Vault{
Policies: []string{"foo", "bar"},
Env: false,
ChangeMode: structs.VaultChangeModeSignal,
ChangeSignal: "SIGUSR1",
},
},
},
},
......@@ -475,16 +482,18 @@ func TestParse(t *testing.T) {
Name: "redis",
LogConfig: structs.DefaultLogConfig(),
Vault: &structs.Vault{
Policies: []string{"group"},
Env: true,
Policies: []string{"group"},
Env: true,
ChangeMode: structs.VaultChangeModeRestart,
},
},
&structs.Task{
Name: "redis2",
LogConfig: structs.DefaultLogConfig(),
Vault: &structs.Vault{
Policies: []string{"task"},
Env: false,
Policies: []string{"task"},
Env: false,
ChangeMode: structs.VaultChangeModeRestart,
},
},
},
......@@ -498,8 +507,9 @@ func TestParse(t *testing.T) {
Name: "redis",
LogConfig: structs.DefaultLogConfig(),
Vault: &structs.Vault{
Policies: []string{"job"},
Env: true,
Policies: []string{"job"},
Env: true,
ChangeMode: structs.VaultChangeModeRestart,
},
},
},
......@@ -526,6 +536,10 @@ func TestParse(t *testing.T) {
}
if !reflect.DeepEqual(actual, tc.Result) {
diff, err := actual.Diff(tc.Result, true)
if err == nil {
t.Logf("file %s diff:\n%#v\n", tc.File, diff)
}
t.Fatalf("file: %s\n\n%#v\n\n%#v", tc.File, actual, tc.Result)
}
}
......
......@@ -165,6 +165,13 @@ job "binstore-storagelocker" {
attribute = "kernel.arch"
value = "amd64"
}
vault {
policies = ["foo", "bar"]
env = false
change_mode = "signal"
change_signal = "SIGUSR1"
}
}
constraint {
......
......@@ -373,7 +373,10 @@ func TestJobEndpoint_Register_Vault_Disabled(t *testing.T) {
// Create the register request with a job asking for a vault policy
job := mock.Job()
job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{Policies: []string{"foo"}}
job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{
Policies: []string{"foo"},
ChangeMode: structs.VaultChangeModeRestart,
}
req := &structs.JobRegisterRequest{
Job: job,
WriteRequest: structs.WriteRequest{Region: "global"},
......@@ -405,7 +408,10 @@ func TestJobEndpoint_Register_Vault_AllowUnauthenticated(t *testing.T) {
// Create the register request with a job asking for a vault policy
job := mock.Job()
job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{Policies: []string{"foo"}}
job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{
Policies: []string{"foo"},
ChangeMode: structs.VaultChangeModeRestart,
}
req := &structs.JobRegisterRequest{
Job: job,
WriteRequest: structs.WriteRequest{Region: "global"},
......@@ -451,7 +457,10 @@ func TestJobEndpoint_Register_Vault_NoToken(t *testing.T) {
// Create the register request with a job asking for a vault policy but
// don't send a Vault token
job := mock.Job()
job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{Policies: []string{"foo"}}
job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{
Policies: []string{"foo"},
ChangeMode: structs.VaultChangeModeRestart,
}
req := &structs.JobRegisterRequest{
Job: job,
WriteRequest: structs.WriteRequest{Region: "global"},
......@@ -506,7 +515,10 @@ func TestJobEndpoint_Register_Vault_Policies(t *testing.T) {
// send the bad Vault token
job := mock.Job()
job.VaultToken = badToken
job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{Policies: []string{policy}}
job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{
Policies: []string{policy},
ChangeMode: structs.VaultChangeModeRestart,
}
req := &structs.JobRegisterRequest{
Job: job,
WriteRequest: structs.WriteRequest{Region: "global"},
......@@ -565,7 +577,10 @@ func TestJobEndpoint_Register_Vault_Policies(t *testing.T) {
// send the root Vault token
job2 := mock.Job()
job2.VaultToken = rootToken
job2.TaskGroups[0].Tasks[0].Vault = &structs.Vault{Policies: []string{policy}}
job2.TaskGroups[0].Tasks[0].Vault = &structs.Vault{
Policies: []string{policy},
ChangeMode: structs.VaultChangeModeRestart,
}
req = &structs.JobRegisterRequest{
Job: job2,
WriteRequest: structs.WriteRequest{Region: "global"},
......
......@@ -3032,8 +3032,10 @@ func TestTaskDiff(t *testing.T) {
Old: &Task{},
New: &Task{
Vault: &Vault{
Policies: []string{"foo", "bar"},
Env: true,
Policies: []string{"foo", "bar"},
Env: true,
ChangeMode: "signal",
ChangeSignal: "SIGUSR1",
},
},
Expected: &TaskDiff{
......@@ -3043,6 +3045,18 @@ func TestTaskDiff(t *testing.T) {
Type: DiffTypeAdded,
Name: "Vault",
Fields: []*FieldDiff{
{
Type: DiffTypeAdded,
Name: "ChangeMode",
Old: "",
New: "signal",
},
{
Type: DiffTypeAdded,
Name: "ChangeSignal",
Old: "",
New: "SIGUSR1",
},
{
Type: DiffTypeAdded,
Name: "Env",
......@@ -3078,8 +3092,10 @@ func TestTaskDiff(t *testing.T) {
// Vault deleted
Old: &Task{
Vault: &Vault{
Policies: []string{"foo", "bar"},
Env: true,
Policies: []string{"foo", "bar"},
Env: true,
ChangeMode: "signal",
ChangeSignal: "SIGUSR1",
},
},
New: &Task{},
......@@ -3090,6 +3106,18 @@ func TestTaskDiff(t *testing.T) {
Type: DiffTypeDeleted,
Name: "Vault",
Fields: []*FieldDiff{
{
Type: DiffTypeDeleted,
Name: "ChangeMode",
Old: "signal",
New: "",
},
{
Type: DiffTypeDeleted,
Name: "ChangeSignal",
Old: "SIGUSR1",
New: "",
},
{
Type: DiffTypeDeleted,
Name: "Env",
......@@ -3125,14 +3153,18 @@ func TestTaskDiff(t *testing.T) {
// Vault edited
Old: &Task{
Vault: &Vault{
Policies: []string{"foo", "bar"},
Env: true,
Policies: []string{"foo", "bar"},
Env: true,
ChangeMode: "signal",
ChangeSignal: "SIGUSR1",
},
},
New: &Task{
Vault: &Vault{
Policies: []string{"bar", "baz"},
Env: false,
Policies: []string{"bar", "baz"},
Env: false,
ChangeMode: "restart",
ChangeSignal: "foo",
},
},
Expected: &TaskDiff{
......@@ -3142,6 +3174,18 @@ func TestTaskDiff(t *testing.T) {
Type: DiffTypeEdited,
Name: "Vault",
Fields: []*FieldDiff{
{
Type: DiffTypeEdited,
Name: "ChangeMode",
Old: "signal",
New: "restart",
},
{
Type: DiffTypeEdited,
Name: "ChangeSignal",
Old: "SIGUSR1",
New: "foo",
},
{
Type: DiffTypeEdited,
Name: "Env",
......@@ -3174,18 +3218,22 @@ func TestTaskDiff(t *testing.T) {
},
},
{
// LogConfig edited with context
// Vault edited with context
Contextual: true,
Old: &Task{
Vault: &Vault{
Policies: []string{"foo", "bar"},
Env: true,
Policies: []string{"foo", "bar"},
Env: true,
ChangeMode: "signal",
ChangeSignal: "SIGUSR1",
},
},
New: &Task{
Vault: &Vault{
Policies: []string{"bar", "baz"},
Env: true,
Policies: []string{"bar", "baz"},
Env: true,
ChangeMode: "signal",
ChangeSignal: "SIGUSR1",
},
},
Expected: &TaskDiff{
......@@ -3195,6 +3243,18 @@ func TestTaskDiff(t *testing.T) {
Type: DiffTypeEdited,
Name: "Vault",
Fields: []*FieldDiff{
{
Type: DiffTypeNone,
Name: "ChangeMode",
Old: "signal",
New: "signal",
},
{
Type: DiffTypeNone,
Name: "ChangeSignal",
Old: "SIGUSR1",
New: "SIGUSR1",
},
{
Type: DiffTypeNone,
Name: "Env",
......
......@@ -2818,6 +2818,17 @@ func (d *EphemeralDisk) Copy() *EphemeralDisk {
return ld
}
const (
// VaultChangeModeNoop takes no action when a new token is retrieved.
VaultChangeModeNoop = "noop"
// VaultChangeModeSignal signals the task when a new token is retrieved.
VaultChangeModeSignal = "signal"
// VaultChangeModeRestart restarts the task when a new token is retrieved.
VaultChangeModeRestart = "restart"
)
// Vault stores the set of permissions a task needs access to from Vault.
type Vault struct {
// Policies is the set of policies that the task needs access to
......@@ -2826,6 +2837,21 @@ type Vault struct {
// Env marks whether the Vault Token should be exposed as an environment
// variable
Env bool
// ChangeMode is used to configure the task's behavior when the Vault
// token changes because the original token could not be renewed in time.
ChangeMode string `mapstructure:"change_mode"`
// ChangeSignal is the signal sent to the task when a new token is
// retrieved. This is only valid when using the signal change mode.
ChangeSignal string `mapstructure:"change_signal"`
}
func DefaultVaultBlock() *Vault {
return &Vault{
Env: true,
ChangeMode: VaultChangeModeRestart,
}
}
// Copy returns a copy of this Vault block.
......@@ -2849,6 +2875,16 @@ func (v *Vault) Validate() error {
return fmt.Errorf("Policy list can not be empty")
}
switch v.ChangeMode {
case VaultChangeModeSignal:
if v.ChangeSignal == "" {
return fmt.Errorf("Signal must be specified when using change mode %q", VaultChangeModeSignal)
}
case VaultChangeModeNoop, VaultChangeModeRestart:
default:
return fmt.Errorf("Unknown change mode %q", v.ChangeMode)
}
return nil
}
......
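A short sketch of the block's lifecycle under the new defaults and validation rules, using only the constructor, constants, and Validate shown above (the policy name is illustrative):

package main

import (
	"fmt"

	"github.com/hashicorp/nomad/nomad/structs"
)

func main() {
	// DefaultVaultBlock gives Env: true and ChangeMode: "restart".
	v := structs.DefaultVaultBlock()
	v.Policies = []string{"secret-readers"}

	// Switching to signal mode without naming a signal fails validation.
	v.ChangeMode = structs.VaultChangeModeSignal
	fmt.Println(v.Validate()) // Signal must be specified ...

	v.ChangeSignal = "SIGUSR1"
	fmt.Println(v.Validate()) // <nil>
}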
......@@ -1307,3 +1307,21 @@ func TestAllocation_Terminated(t *testing.T) {
}
}
}
func TestVault_Validate(t *testing.T) {
v := &Vault{
Env: true,
ChangeMode: VaultChangeModeNoop,
}
if err := v.Validate(); err == nil || !strings.Contains(err.Error(), "Policy list") {
t.Fatalf("Expected policy list empty error")
}
v.Policies = []string{"foo"}
v.ChangeMode = VaultChangeModeSignal
if err := v.Validate(); err == nil || !strings.Contains(err.Error(), "Signal must") {
t.Fatalf("Expected signal empty error")
}
}
......@@ -367,6 +367,9 @@ func tasksUpdated(a, b *structs.TaskGroup) bool {
if !reflect.DeepEqual(at.Vault, bt.Vault) {
return true
}
if !reflect.DeepEqual(at.Templates, bt.Templates) {
return true
}
// Inspect the network to see if the dynamic ports are different
if len(at.Resources.Networks) != len(bt.Resources.Networks) {
......
......@@ -139,22 +139,24 @@ func (c *Config) Copy() *Config {
if c.Vault.SSL != nil {
config.Vault.SSL = &SSLConfig{
Enabled: c.Vault.SSL.Enabled,
Verify: c.Vault.SSL.Verify,
Cert: c.Vault.SSL.Cert,
Key: c.Vault.SSL.Key,
CaCert: c.Vault.SSL.CaCert,
Enabled: c.Vault.SSL.Enabled,
Verify: c.Vault.SSL.Verify,
Cert: c.Vault.SSL.Cert,
Key: c.Vault.SSL.Key,
CaCert: c.Vault.SSL.CaCert,
ServerName: c.Vault.SSL.ServerName,
}
}
}
if c.SSL != nil {
config.SSL = &SSLConfig{
Enabled: c.SSL.Enabled,
Verify: c.SSL.Verify,
Cert: c.SSL.Cert,
Key: c.SSL.Key,
CaCert: c.SSL.CaCert,
Enabled: c.SSL.Enabled,
Verify: c.SSL.Verify,
Cert: c.SSL.Cert,
Key: c.SSL.Key,
CaCert: c.SSL.CaCert,
ServerName: c.SSL.ServerName,
}
}
......@@ -284,6 +286,9 @@ func (c *Config) Merge(config *Config) {
if config.WasSet("vault.ssl.enabled") {
c.Vault.SSL.Enabled = config.Vault.SSL.Enabled
}
if config.WasSet("vault.ssl.server_name") {
c.Vault.SSL.ServerName = config.Vault.SSL.ServerName
}
}
}
......@@ -327,6 +332,9 @@ func (c *Config) Merge(config *Config) {
if config.WasSet("ssl.enabled") {
c.SSL.Enabled = config.SSL.Enabled
}
if config.WasSet("ssl.server_name") {
c.SSL.ServerName = config.SSL.ServerName
}
}
if config.WasSet("syslog") {
......@@ -447,27 +455,20 @@ func (c *Config) Set(key string) {
}
}
// ParseConfig reads the configuration file at the given path and returns a new
// Config struct with the data populated.
func ParseConfig(path string) (*Config, error) {
// Parse parses the given string contents as a config
func Parse(s string) (*Config, error) {
var errs *multierror.Error
// Read the contents of the file
contents, err := ioutil.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("error reading config at %q: %s", path, err)
}
// Parse the file (could be HCL or JSON)
var shadow interface{}
if err := hcl.Decode(&shadow, string(contents)); err != nil {
return nil, fmt.Errorf("error decoding config at %q: %s", path, err)
if err := hcl.Decode(&shadow, s); err != nil {
return nil, fmt.Errorf("error decoding config: %s", err)
}
// Convert to a map and flatten the keys we want to flatten
parsed, ok := shadow.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("error converting config at %q", path)
return nil, fmt.Errorf("error converting config")
}
flattenKeys(parsed, []string{
"auth",
......@@ -514,9 +515,6 @@ func ParseConfig(path string) (*Config, error) {
return nil, errs.ErrorOrNil()
}
// Store a reference to the path where this config was read from
config.Path = path
// Explicitly check for the nil signal and set the value back to nil
if config.ReloadSignal == signals.SIGNIL {
config.ReloadSignal = nil
......@@ -573,9 +571,30 @@ func ParseConfig(path string) (*Config, error) {
return config, errs.ErrorOrNil()
}
// ConfigFromPath iterates and merges all configuration files in a given
// Must returns a config object that must compile. If there are any errors, this
// function will panic. This is most useful in testing or constants.
func Must(s string) *Config {
c, err := Parse(s)
if err != nil {
panic(err)
}
return c
}
// FromFile reads the configuration file at the given path and returns a new
// Config struct with the data populated.
func FromFile(path string) (*Config, error) {
c, err := ioutil.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("error reading config at %q: %s", path, err)
}
return Parse(string(c))
}
// FromPath iterates and merges all configuration files in a given
// directory, returning the resulting config.
func ConfigFromPath(path string) (*Config, error) {
func FromPath(path string) (*Config, error) {
// Ensure the given filepath exists
if _, err := os.Stat(path); os.IsNotExist(err) {
return nil, fmt.Errorf("config: missing file/folder: %s", path)
......@@ -611,7 +630,7 @@ func ConfigFromPath(path string) (*Config, error) {
}
// Parse and merge the config
newConfig, err := ParseConfig(path)
newConfig, err := FromFile(path)
if err != nil {
return err
}
......@@ -626,7 +645,7 @@ func ConfigFromPath(path string) (*Config, error) {
return config, nil
} else if stat.Mode().IsRegular() {
return ParseConfig(path)
return FromFile(path)
}
return nil, fmt.Errorf("config: unknown filetype: %q", stat.Mode().String())
......@@ -710,6 +729,10 @@ func DefaultConfig() *Config {
config.Vault.SSL.Verify = false
}
if v := os.Getenv("VAULT_TLS_SERVER_NAME"); v != "" {
config.Vault.SSL.ServerName = v
}
return config
}
......@@ -773,11 +796,12 @@ type DeduplicateConfig struct {
// SSLConfig is the configuration for SSL.
type SSLConfig struct {
Enabled bool `mapstructure:"enabled"`
Verify bool `mapstructure:"verify"`
Cert string `mapstructure:"cert"`
Key string `mapstructure:"key"`
CaCert string `mapstructure:"ca_cert"`
Enabled bool `mapstructure:"enabled"`
Verify bool `mapstructure:"verify"`
Cert string `mapstructure:"cert"`
Key string `mapstructure:"key"`
CaCert string `mapstructure:"ca_cert"`
ServerName string `mapstructure:"server_name"`
}
// SyslogConfig is the configuration for syslog.
......
package config
import (
"io/ioutil"
"os"
"testing"
)
func TestConfig(contents string, t *testing.T) *Config {
f, err := ioutil.TempFile(os.TempDir(), "")
if err != nil {
t.Fatal(err)
}
_, err = f.Write([]byte(contents))
if err != nil {
t.Fatal(err)
}
config, err := FromFile(f.Name())
if err != nil {
t.Fatal(err)
}
return config
}