diff --git a/client/consul_template_test.go b/client/consul_template_test.go
index 88ee17b0e50d6478ae43bbcdda99035b322b04ba..f8368520407dbcaac72a9b08e98b1b91281e4db0 100644
--- a/client/consul_template_test.go
+++ b/client/consul_template_test.go
@@ -57,7 +57,7 @@ func NewMockTaskHooks() *MockTaskHooks {
 		EmitEventCh: make(chan struct{}, 1),
 	}
 }
-func (m *MockTaskHooks) Restart(source, reason string) {
+func (m *MockTaskHooks) Restart(source, reason string, failure bool) {
 	m.Restarts++
 	select {
 	case m.RestartCh <- struct{}{}:
diff --git a/client/consul_test.go b/client/consul_test.go
index 10d1ebe10d044ad736a417a8896f426c62295622..8703cdd215a98b351e8b8e4f89f98c32e11aa2dc 100644
--- a/client/consul_test.go
+++ b/client/consul_test.go
@@ -60,7 +60,7 @@ func newMockConsulServiceClient() *mockConsulServiceClient {
 	return &m
 }
 
-func (m *mockConsulServiceClient) UpdateTask(allocID string, old, new *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error {
+func (m *mockConsulServiceClient) UpdateTask(allocID string, old, new *structs.Task, restarter consul.TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 	m.logger.Printf("[TEST] mock_consul: UpdateTask(%q, %v, %v, %T, %x)", allocID, old, new, exec, net.Hash())
@@ -68,7 +68,7 @@ func (m *mockConsulServiceClient) UpdateTask(allocID string, old, new *structs.T
 	return nil
 }
 
-func (m *mockConsulServiceClient) RegisterTask(allocID string, task *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error {
+func (m *mockConsulServiceClient) RegisterTask(allocID string, task *structs.Task, restarter consul.TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 	m.logger.Printf("[TEST] mock_consul: RegisterTask(%q, %q, %T, %x)", allocID, task.Name, exec, net.Hash())
diff --git a/client/task_runner.go b/client/task_runner.go
index cd7f40b3908fdb44903d51b479f2e2af8af66849..cc1bf5aa61eed47587cb7780264ca327ae2a712d 100644
--- a/client/task_runner.go
+++ b/client/task_runner.go
@@ -75,9 +75,9 @@ type taskRestartEvent struct {
 	failure bool
 }
 
-func newTaskRestartEvent(source, reason string, failure bool) *taskRestartEvent {
+func newTaskRestartEvent(reason string, failure bool) *taskRestartEvent {
 	return &taskRestartEvent{
-		taskEvent: structs.NewTaskEvent(source).SetRestartReason(reason),
+		taskEvent: structs.NewTaskEvent(structs.TaskRestartSignal).SetRestartReason(reason),
 		failure:   failure,
 	}
 }
@@ -1708,7 +1708,7 @@ func (r *TaskRunner) handleDestroy(handle driver.DriverHandle) (destroyed bool,
 // Restart will restart the task.
 func (r *TaskRunner) Restart(source, reason string, failure bool) {
 	reasonStr := fmt.Sprintf("%s: %s", source, reason)
-	event := newTaskRestartEvent(source, reasonStr, failure)
+	event := newTaskRestartEvent(reasonStr, failure)
 
 	select {
 	case r.restartCh <- event:
diff --git a/client/task_runner_test.go b/client/task_runner_test.go
index 6894115e3376fdedcd84f0b8ac54514d232c584a..f532e77df457c3191203df958e35952e0e65ffb0 100644
--- a/client/task_runner_test.go
+++ b/client/task_runner_test.go
@@ -21,6 +21,7 @@ import (
 	"github.com/hashicorp/nomad/client/driver/env"
 	cstructs "github.com/hashicorp/nomad/client/structs"
 	"github.com/hashicorp/nomad/client/vaultclient"
+	"github.com/hashicorp/nomad/command/agent/consul"
 	"github.com/hashicorp/nomad/nomad/mock"
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/testutil"
@@ -56,10 +57,21 @@ func (m *MockTaskStateUpdater) Update(name, state string, event *structs.TaskEve
 	}
 }
 
+// String for debugging purposes.
+func (m *MockTaskStateUpdater) String() string {
+	s := fmt.Sprintf("Updates:\n state=%q\n failed=%t\n events=\n", m.state, m.failed)
+	for _, e := range m.events {
+		s += fmt.Sprintf(" %#v\n", e)
+	}
+	return s
+}
+
 type taskRunnerTestCtx struct {
 	upd      *MockTaskStateUpdater
 	tr       *TaskRunner
 	allocDir *allocdir.AllocDir
+	vault    *vaultclient.MockVaultClient
+	consul   *mockConsulServiceClient
 }
 
 // Cleanup calls Destroy on the task runner and alloc dir
@@ -130,7 +142,13 @@ func testTaskRunnerFromAlloc(t *testing.T, restarts bool, alloc *structs.Allocat
 	if !restarts {
 		tr.restartTracker = noRestartsTracker()
 	}
-	return &taskRunnerTestCtx{upd, tr, allocDir}
+	return &taskRunnerTestCtx{
+		upd:      upd,
+		tr:       tr,
+		allocDir: allocDir,
+		vault:    vclient,
+		consul:   cclient,
+	}
 }
 
 // testWaitForTaskToStart waits for the task to start or fails the test
@@ -657,7 +675,7 @@ func TestTaskRunner_RestartTask(t *testing.T) {
 	// Wait for it to start
 	go func() {
 		testWaitForTaskToStart(t, ctx)
-		ctx.tr.Restart("test", "restart")
+		ctx.tr.Restart("test", "restart", false)
 
 		// Wait for it to restart then kill
 		go func() {
@@ -1251,8 +1269,7 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
 	})
 
 	// Error the token renewal
-	vc := ctx.tr.vaultClient.(*vaultclient.MockVaultClient)
-	renewalCh, ok := vc.RenewTokens[token]
+	renewalCh, ok := ctx.vault.RenewTokens[token]
 	if !ok {
 		t.Fatalf("no renewal channel")
 	}
@@ -1279,13 +1296,12 @@ func TestTaskRunner_Template_NewVaultToken(t *testing.T) {
 	})
 
 	// Check the token was revoked
-	m := ctx.tr.vaultClient.(*vaultclient.MockVaultClient)
 	testutil.WaitForResult(func() (bool, error) {
-		if len(m.StoppedTokens) != 1 {
-			return false, fmt.Errorf("Expected a stopped token: %v", m.StoppedTokens)
+		if len(ctx.vault.StoppedTokens) != 1 {
+			return false, fmt.Errorf("Expected a stopped token: %v", ctx.vault.StoppedTokens)
 		}
 
-		if a := m.StoppedTokens[0]; a != token {
+		if a := ctx.vault.StoppedTokens[0]; a != token {
 			return false, fmt.Errorf("got stopped token %q; want %q", a, token)
 		}
 		return true, nil
@@ -1317,8 +1333,7 @@ func TestTaskRunner_VaultManager_Restart(t *testing.T) {
 	testWaitForTaskToStart(t, ctx)
 
 	// Error the token renewal
-	vc := ctx.tr.vaultClient.(*vaultclient.MockVaultClient)
-	renewalCh, ok := vc.RenewTokens[ctx.tr.vaultFuture.Get()]
+	renewalCh, ok := ctx.vault.RenewTokens[ctx.tr.vaultFuture.Get()]
 	if !ok {
 		t.Fatalf("no renewal channel")
 	}
@@ -1394,8 +1409,7 @@ func TestTaskRunner_VaultManager_Signal(t *testing.T) {
 	testWaitForTaskToStart(t, ctx)
 
 	// Error the token renewal
-	vc := ctx.tr.vaultClient.(*vaultclient.MockVaultClient)
-	renewalCh, ok := vc.RenewTokens[ctx.tr.vaultFuture.Get()]
+	renewalCh, ok := ctx.vault.RenewTokens[ctx.tr.vaultFuture.Get()]
 	if !ok {
 		t.Fatalf("no renewal channel")
 	}
@@ -1726,20 +1740,19 @@ func TestTaskRunner_ShutdownDelay(t *testing.T) {
 
 	// Service should get removed quickly; loop until RemoveTask is called
 	found := false
-	mockConsul := ctx.tr.consul.(*mockConsulServiceClient)
 	deadline := destroyed.Add(task.ShutdownDelay)
 	for time.Now().Before(deadline) {
 		time.Sleep(5 * time.Millisecond)
 
-		mockConsul.mu.Lock()
-		n := len(mockConsul.ops)
+		ctx.consul.mu.Lock()
+		n := len(ctx.consul.ops)
 		if n < 2 {
-			mockConsul.mu.Unlock()
+			ctx.consul.mu.Unlock()
 			continue
 		}
 
-		lastOp := mockConsul.ops[n-1].op
-		mockConsul.mu.Unlock()
+		lastOp := ctx.consul.ops[n-1].op
+		ctx.consul.mu.Unlock()
 
 		if lastOp == "remove" {
 			found = true
@@ -1762,3 +1775,97 @@ func TestTaskRunner_ShutdownDelay(t *testing.T) {
 		t.Fatalf("task exited before shutdown delay")
 	}
 }
+
+// TestTaskRunner_CheckWatcher_Restart asserts that when enabled an unhealthy
+// Consul check will cause a task to restart following restart policy rules.
+func TestTaskRunner_CheckWatcher_Restart(t *testing.T) {
+	t.Parallel()
+
+	alloc := mock.Alloc()
+
+	// Make the restart policy fail within this test
+	tg := alloc.Job.TaskGroups[0]
+	tg.RestartPolicy.Attempts = 2
+	tg.RestartPolicy.Interval = 1 * time.Minute
+	tg.RestartPolicy.Delay = 10 * time.Millisecond
+	tg.RestartPolicy.Mode = structs.RestartPolicyModeFail
+
+	task := tg.Tasks[0]
+	task.Driver = "mock_driver"
+	task.Config = map[string]interface{}{
+		"exit_code": "0",
+		"run_for":   "100s",
+	}
+
+	// Make the task register a check that fails
+	task.Services[0].Checks[0] = &structs.ServiceCheck{
+		Name:     "test-restarts",
+		Type:     structs.ServiceCheckTCP,
+		Interval: 50 * time.Millisecond,
+		CheckRestart: &structs.CheckRestart{
+			Limit: 2,
+			Grace: 100 * time.Millisecond,
+		},
+	}
+
+	ctx := testTaskRunnerFromAlloc(t, true, alloc)
+
+	// Replace the mock Consul ServiceClient with the real ServiceClient
+	// backed by a mock consul whose checks are always unhealthy.
+	consulAgent := consul.NewMockAgent()
+	consulAgent.SetStatus("critical")
+	consulClient := consul.NewServiceClient(consulAgent, true, ctx.tr.logger)
+	go consulClient.Run()
+	defer consulClient.Shutdown()
+
+	ctx.tr.consul = consulClient
+	ctx.consul = nil // prevent accidental use of old mock
+
+	ctx.tr.MarkReceived()
+	go ctx.tr.Run()
+	defer ctx.Cleanup()
+
+	select {
+	case <-ctx.tr.WaitCh():
+	case <-time.After(time.Duration(testutil.TestMultiplier()*15) * time.Second):
+		t.Fatalf("timeout")
+	}
+
+	expected := []string{
+		"Received",
+		"Task Setup",
+		"Started",
+		"Restart Signaled",
+		"Killing",
+		"Killed",
+		"Restarting",
+		"Started",
+		"Restart Signaled",
+		"Killing",
+		"Killed",
+		"Restarting",
+		"Started",
+		"Restart Signaled",
+		"Killing",
+		"Killed",
+		"Not Restarting",
+	}
+
+	if n := len(ctx.upd.events); n != len(expected) {
+		t.Fatalf("should have %d ctx.updates found %d: %s", len(expected), n, ctx.upd)
+	}
+
+	if ctx.upd.state != structs.TaskStateDead {
+		t.Fatalf("TaskState %v; want %v", ctx.upd.state, structs.TaskStateDead)
+	}
+
+	if !ctx.upd.failed {
+		t.Fatalf("expected failed")
+	}
+
+	for i, actual := range ctx.upd.events {
+		if actual.Type != expected[i] {
+			t.Errorf("%.2d - Expected %q but found %q", i, expected[i], actual.Type)
+		}
+	}
+}
diff --git a/client/task_runner_unix_test.go b/client/task_runner_unix_test.go
index bed7c956d790e0c05e5a686795e72c51f400d4ab..b7c2aa4412bf5bd9800c5bd1cb3bb792dabc81a5 100644
--- a/client/task_runner_unix_test.go
+++ b/client/task_runner_unix_test.go
@@ -53,7 +53,7 @@ func TestTaskRunner_RestartSignalTask_NotRunning(t *testing.T) {
 	}
 
 	// Send a restart
-	ctx.tr.Restart("test", "don't panic")
+	ctx.tr.Restart("test", "don't panic", false)
 
 	if len(ctx.upd.events) != 2 {
 		t.Fatalf("should have 2 ctx.updates: %#v", ctx.upd.events)
diff --git a/command/agent/consul/catalog_testing.go b/command/agent/consul/catalog_testing.go
index f0dd0326ce0ffb8c47b5496efa1528928220b90c..6b28940f1144eb4cfe7fc246f3f8e241572f0c58 100644
--- a/command/agent/consul/catalog_testing.go
+++ b/command/agent/consul/catalog_testing.go
@@ -1,7 +1,9 @@
 package consul
 
 import (
+	"fmt"
 	"log"
+	"sync"
 
 	"github.com/hashicorp/consul/api"
 )
@@ -25,3 +27,119 @@ func (m *MockCatalog) Service(service, tag string, q *api.QueryOptions) ([]*api.
 	m.logger.Printf("[DEBUG] mock_consul: Service(%q, %q, %#v) -> (nil, nil, nil)", service, tag, q)
 	return nil, nil, nil
 }
+
+// MockAgent is a fake in-memory Consul backend for ServiceClient.
+type MockAgent struct {
+	// maps of what services and checks have been registered
+	services map[string]*api.AgentServiceRegistration
+	checks   map[string]*api.AgentCheckRegistration
+	mu       sync.Mutex
+
+	// when UpdateTTL is called the check ID will have its counter inc'd
+	checkTTLs map[string]int
+
+	// What check status to return from Checks()
+	checkStatus string
+}
+
+// NewMockAgent returns a MockAgent that reports all checks as passing.
+func NewMockAgent() *MockAgent {
+	return &MockAgent{
+		services:    make(map[string]*api.AgentServiceRegistration),
+		checks:      make(map[string]*api.AgentCheckRegistration),
+		checkTTLs:   make(map[string]int),
+		checkStatus: api.HealthPassing,
+	}
+}
+
+// SetStatus sets the status that Checks() should return. Returns the old status value.
+func (c *MockAgent) SetStatus(s string) string {
+	c.mu.Lock()
+	old := c.checkStatus
+	c.checkStatus = s
+	c.mu.Unlock()
+	return old
+}
+
+func (c *MockAgent) Services() (map[string]*api.AgentService, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	r := make(map[string]*api.AgentService, len(c.services))
+	for k, v := range c.services {
+		r[k] = &api.AgentService{
+			ID:                v.ID,
+			Service:           v.Name,
+			Tags:              make([]string, len(v.Tags)),
+			Port:              v.Port,
+			Address:           v.Address,
+			EnableTagOverride: v.EnableTagOverride,
+		}
+		copy(r[k].Tags, v.Tags)
+	}
+	return r, nil
+}
+
+func (c *MockAgent) Checks() (map[string]*api.AgentCheck, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	r := make(map[string]*api.AgentCheck, len(c.checks))
+	for k, v := range c.checks {
+		r[k] = &api.AgentCheck{
+			CheckID:     v.ID,
+			Name:        v.Name,
+			Status:      c.checkStatus,
+			Notes:       v.Notes,
+			ServiceID:   v.ServiceID,
+			ServiceName: c.services[v.ServiceID].Name,
+		}
+	}
+	return r, nil
+}
+
+func (c *MockAgent) CheckRegister(check *api.AgentCheckRegistration) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.checks[check.ID] = check
+
+	// Be nice and make checks reachable-by-service
+	scheck := check.AgentServiceCheck
+	c.services[check.ServiceID].Checks = append(c.services[check.ServiceID].Checks, &scheck)
+	return nil
+}
+
+func (c *MockAgent) CheckDeregister(checkID string) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	delete(c.checks, checkID)
+	delete(c.checkTTLs, checkID)
+	return nil
+}
+
+func (c *MockAgent) ServiceRegister(service *api.AgentServiceRegistration) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.services[service.ID] = service
+	return nil
+}
+
+func (c *MockAgent) ServiceDeregister(serviceID string) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	delete(c.services, serviceID)
+	return nil
+}
+
+func (c *MockAgent) UpdateTTL(id string, output string, status string) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	check, ok := c.checks[id]
+	if !ok {
+		return fmt.Errorf("unknown check id: %q", id)
+	}
+	// Flip initial status to passing
+	check.Status = "passing"
+	c.checkTTLs[id]++
+	return nil
+}
diff --git a/command/agent/consul/check_watcher.go b/command/agent/consul/check_watcher.go
index af19bb0198462b6c51cb923c036bd56efb8634dc..ee28c87fb37d1643427d83c4b97b24fdd9bf2a40 100644
--- a/command/agent/consul/check_watcher.go
+++ b/command/agent/consul/check_watcher.go
@@ -62,17 +62,25 @@ type checkRestart struct {
 // timestamp is passed in so all check updates have the same view of time (and
 // to ease testing).
 func (c *checkRestart) update(now time.Time, status string) {
+	healthy := func() {
+		if !c.unhealthyStart.IsZero() {
+			c.logger.Printf("[DEBUG] consul.health: alloc %q task %q check %q became healthy; canceling restart",
+				c.allocID, c.taskName, c.checkName)
+			c.unhealthyStart = time.Time{}
+		}
+		return
+	}
 	switch status {
 	case api.HealthCritical:
 	case api.HealthWarning:
 		if c.ignoreWarnings {
 			// Warnings are ignored, reset state and exit
-			c.unhealthyStart = time.Time{}
+			healthy()
 			return
 		}
 	default:
 		// All other statuses are ok, reset state and exit
-		c.unhealthyStart = time.Time{}
+		healthy()
 		return
 	}
@@ -83,8 +91,10 @@ func (c *checkRestart) update(now time.Time, status string) {
 
 	if c.unhealthyStart.IsZero() {
 		// First failure, set restart deadline
-		c.logger.Printf("[DEBUG] consul.health: alloc %q task %q check %q became unhealthy. Restarting in %s if not healthy",
-			c.allocID, c.taskName, c.checkName, c.timeLimit)
+		if c.timeLimit != 0 {
+			c.logger.Printf("[DEBUG] consul.health: alloc %q task %q check %q became unhealthy. Restarting in %s if not healthy",
+				c.allocID, c.taskName, c.checkName, c.timeLimit)
+		}
 		c.unhealthyStart = now
 	}
 
@@ -150,12 +160,6 @@ func (w *checkWatcher) Run(ctx context.Context) {
 	// timer for check polling
 	checkTimer := time.NewTimer(0)
 	defer checkTimer.Stop() // ensure timer is never leaked
-	resetTimer := func(d time.Duration) {
-		if !checkTimer.Stop() {
-			<-checkTimer.C
-		}
-		checkTimer.Reset(d)
-	}
 
 	// Main watch loop
 	for {
@@ -169,9 +173,13 @@ func (w *checkWatcher) Run(ctx context.Context) {
 				w.logger.Printf("[DEBUG] consul.health: told to stop watching an unwatched check: %q", c.checkID)
 			} else {
 				checks[c.checkID] = c
+				w.logger.Printf("[DEBUG] consul.health: watching alloc %q task %q check %q", c.allocID, c.taskName, c.checkName)
 
-				// First check should be after grace period
-				resetTimer(c.grace)
+				// Begin polling
+				if !checkTimer.Stop() {
+					<-checkTimer.C
+				}
+				checkTimer.Reset(w.pollFreq)
 			}
 		case <-ctx.Done():
 			return
diff --git a/command/agent/consul/check_watcher_test.go b/command/agent/consul/check_watcher_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..7e4947f51b952eedb2cf9b6fcae52b4b3680a508
--- /dev/null
+++ b/command/agent/consul/check_watcher_test.go
@@ -0,0 +1,252 @@
+package consul
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/hashicorp/consul/api"
+	"github.com/hashicorp/nomad/nomad/structs"
+)
+
+// checkResponse is a response returned by the fakeChecksAPI after the given
+// time.
+type checkResponse struct {
+	at     time.Time
+	id     string
+	status string
+}
+
+// fakeChecksAPI implements the Checks() method for testing Consul checks.
+type fakeChecksAPI struct {
+	// responses is a map of check ids to their status at a particular
+	// time. checkResponses must be in chronological order.
+	responses map[string][]checkResponse
+}
+
+func newFakeChecksAPI() *fakeChecksAPI {
+	return &fakeChecksAPI{responses: make(map[string][]checkResponse)}
+}
+
+// add a new check status to Consul at the given time.
+func (c *fakeChecksAPI) add(id, status string, at time.Time) {
+	c.responses[id] = append(c.responses[id], checkResponse{at, id, status})
+}
+
+func (c *fakeChecksAPI) Checks() (map[string]*api.AgentCheck, error) {
+	now := time.Now()
+	result := make(map[string]*api.AgentCheck, len(c.responses))
+
+	// Use the latest response for each check
+	for k, vs := range c.responses {
+		for _, v := range vs {
+			if v.at.After(now) {
+				break
+			}
+			result[k] = &api.AgentCheck{
+				CheckID: k,
+				Name:    k,
+				Status:  v.status,
+			}
+		}
+	}
+
+	return result, nil
+}
+
+// testWatcherSetup sets up a fakeChecksAPI and a real checkWatcher with a test
+// logger and faster poll frequency.
+func testWatcherSetup() (*fakeChecksAPI, *checkWatcher) {
+	fakeAPI := newFakeChecksAPI()
+	cw := newCheckWatcher(testLogger(), fakeAPI)
+	cw.pollFreq = 10 * time.Millisecond
+	return fakeAPI, cw
+}
+
+func testCheck() *structs.ServiceCheck {
+	return &structs.ServiceCheck{
+		Name:     "testcheck",
+		Interval: 100 * time.Millisecond,
+		Timeout:  100 * time.Millisecond,
+		CheckRestart: &structs.CheckRestart{
+			Limit:          3,
+			Grace:          100 * time.Millisecond,
+			IgnoreWarnings: false,
+		},
+	}
+}
+
+// TestCheckWatcher_Skip asserts unwatched checks are ignored.
+func TestCheckWatcher_Skip(t *testing.T) {
+	t.Parallel()
+
+	// Create a check with restarting disabled
+	check := testCheck()
+	check.CheckRestart = nil
+
+	cw := newCheckWatcher(testLogger(), newFakeChecksAPI())
+	restarter1 := newFakeCheckRestarter()
+	cw.Watch("testalloc1", "testtask1", "testcheck1", check, restarter1)
+
+	// Check should have been dropped as it's not watched
+	if n := len(cw.watchCh); n != 0 {
+		t.Fatalf("expected 0 checks to be enqueued for watching but found %d", n)
+	}
+}
+
+// TestCheckWatcher_Healthy asserts healthy tasks are not restarted.
+func TestCheckWatcher_Healthy(t *testing.T) {
+	t.Parallel()
+
+	fakeAPI, cw := testWatcherSetup()
+
+	check1 := testCheck()
+	restarter1 := newFakeCheckRestarter()
+	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
+
+	check2 := testCheck()
+	check2.CheckRestart.Limit = 1
+	check2.CheckRestart.Grace = 0
+	restarter2 := newFakeCheckRestarter()
+	cw.Watch("testalloc2", "testtask2", "testcheck2", check2, restarter2)
+
+	// Make both checks healthy from the beginning
+	fakeAPI.add("testcheck1", "passing", time.Time{})
+	fakeAPI.add("testcheck2", "passing", time.Time{})
+
+	// Run for 1 second
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+	cw.Run(ctx)
+
+	// Assert Restart was never called
+	if n := len(restarter1.restarts); n > 0 {
+		t.Errorf("expected check 1 to not be restarted but found %d", n)
+	}
+	if n := len(restarter2.restarts); n > 0 {
+		t.Errorf("expected check 2 to not be restarted but found %d", n)
+	}
+}
+
+// TestCheckWatcher_Unhealthy asserts unhealthy tasks are restarted.
+func TestCheckWatcher_Unhealthy(t *testing.T) {
+	t.Parallel()
+
+	fakeAPI, cw := testWatcherSetup()
+
+	check1 := testCheck()
+	restarter1 := newFakeCheckRestarter()
+	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
+
+	check2 := testCheck()
+	check2.CheckRestart.Limit = 1
+	check2.CheckRestart.Grace = 0
+	restarter2 := newFakeCheckRestarter()
+	restarter2.restartDelay = 600 * time.Millisecond
+	cw.Watch("testalloc2", "testtask2", "testcheck2", check2, restarter2)
+
+	// Check 1 always passes, check 2 always fails
+	fakeAPI.add("testcheck1", "passing", time.Time{})
+	fakeAPI.add("testcheck2", "critical", time.Time{})
+
+	// Run for 1 second
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+	cw.Run(ctx)
+
+	// Ensure restart was never called on check 1
+	if n := len(restarter1.restarts); n > 0 {
+		t.Errorf("expected check 1 to not be restarted but found %d", n)
+	}
+
+	// Ensure restart was called twice on check 2
+	if n := len(restarter2.restarts); n != 2 {
+		t.Errorf("expected check 2 to be restarted 2 times but found %d:\n%s", n, restarter2)
+	}
+}
+
+// TestCheckWatcher_HealthyWarning asserts checks in warning with
+// ignore_warnings=true do not restart tasks.
+func TestCheckWatcher_HealthyWarning(t *testing.T) {
+	t.Parallel()
+
+	fakeAPI, cw := testWatcherSetup()
+
+	check1 := testCheck()
+	check1.CheckRestart.Limit = 1
+	check1.CheckRestart.Grace = 0
+	check1.CheckRestart.IgnoreWarnings = true
+	restarter1 := newFakeCheckRestarter()
+	restarter1.restartDelay = 1100 * time.Millisecond
+	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
+
+	// Check is always in warning but that's ok
+	fakeAPI.add("testcheck1", "warning", time.Time{})
+
+	// Run for 1 second
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+	cw.Run(ctx)
+
+	// Ensure restart was never called on check 1
+	if n := len(restarter1.restarts); n > 0 {
+		t.Errorf("expected check 1 to not be restarted but found %d", n)
+	}
+}
+
+// TestCheckWatcher_Flapping asserts checks that flap from healthy to unhealthy
+// before the unhealthy limit is reached do not restart tasks.
+func TestCheckWatcher_Flapping(t *testing.T) {
+	t.Parallel()
+
+	fakeAPI, cw := testWatcherSetup()
+
+	check1 := testCheck()
+	check1.CheckRestart.Grace = 0
+	restarter1 := newFakeCheckRestarter()
+	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
+
+	// Check flaps and is never failing for the full 200ms needed to restart
+	now := time.Now()
+	fakeAPI.add("testcheck1", "passing", now)
+	fakeAPI.add("testcheck1", "critical", now.Add(100*time.Millisecond))
+	fakeAPI.add("testcheck1", "passing", now.Add(250*time.Millisecond))
+	fakeAPI.add("testcheck1", "critical", now.Add(300*time.Millisecond))
+	fakeAPI.add("testcheck1", "passing", now.Add(450*time.Millisecond))
+
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+	cw.Run(ctx)
+
+	// Ensure restart was never called on check 1
+	if n := len(restarter1.restarts); n > 0 {
+		t.Errorf("expected check 1 to not be restarted but found %d\n%s", n, restarter1)
+	}
+}
+
+// TestCheckWatcher_Unwatch asserts unwatching checks prevents restarts.
+func TestCheckWatcher_Unwatch(t *testing.T) {
+	t.Parallel()
+
+	fakeAPI, cw := testWatcherSetup()
+
+	// Unwatch immediately
+	check1 := testCheck()
+	check1.CheckRestart.Limit = 1
+	check1.CheckRestart.Grace = 100 * time.Millisecond
+	restarter1 := newFakeCheckRestarter()
+	cw.Watch("testalloc1", "testtask1", "testcheck1", check1, restarter1)
+	cw.Unwatch("testcheck1")
+
+	// Always failing
+	fakeAPI.add("testcheck1", "critical", time.Time{})
+
+	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
+	defer cancel()
+	cw.Run(ctx)
+
+	// Ensure restart was never called on check 1
+	if n := len(restarter1.restarts); n > 0 {
+		t.Errorf("expected check 1 to not be restarted but found %d\n%s", n, restarter1)
+	}
+}
diff --git a/command/agent/consul/unit_test.go b/command/agent/consul/unit_test.go
index 2a83d2989e9bd25d1164a2a4c99c05dda381b916..3e2837c75a797868c51f391653f0fa217bb9061c 100644
--- a/command/agent/consul/unit_test.go
+++ b/command/agent/consul/unit_test.go
@@ -8,7 +8,6 @@ import (
 	"os"
 	"reflect"
 	"strings"
-	"sync"
 	"testing"
 	"time"
 
@@ -54,12 +53,62 @@ func testTask() *structs.Task {
 	}
 }
 
+// checkRestartRecord is used by a testFakeCtx to record when restarts occur
+// due to a watched check.
+type checkRestartRecord struct {
+	timestamp time.Time
+	source    string
+	reason    string
+	failure   bool
+}
+
+// fakeCheckRestarter is a test implementation of TaskRestarter.
+type fakeCheckRestarter struct {
+
+	// restartDelay is returned by RestartDelay to control the behavior of
+	// the checkWatcher
+	restartDelay time.Duration
+
+	// restarts is a slice of all of the restarts triggered by the checkWatcher
+	restarts []checkRestartRecord
+}
+
+func newFakeCheckRestarter() *fakeCheckRestarter {
+	return &fakeCheckRestarter{}
+}
+
+// RestartDelay implements part of the TaskRestarter interface needed for check
+// watching and is normally fulfilled by a task runner.
+//
+// The return value is determined by the restartDelay field.
+func (c *fakeCheckRestarter) RestartDelay() time.Duration {
+	return c.restartDelay
+}
+
+// Restart implements part of the TaskRestarter interface needed for check
+// watching and is normally fulfilled by a TaskRunner.
+//
+// Restarts are recorded in the restarts field.
+func (c *fakeCheckRestarter) Restart(source, reason string, failure bool) {
+	c.restarts = append(c.restarts, checkRestartRecord{time.Now(), source, reason, failure})
+}
+
+// String for debugging
+func (c *fakeCheckRestarter) String() string {
+	s := ""
+	for _, r := range c.restarts {
+		s += fmt.Sprintf("%s - %s: %s (failure: %t)\n", r.timestamp, r.source, r.reason, r.failure)
+	}
+	return s
+}
+
 // testFakeCtx contains a fake Consul AgentAPI and implements the Exec
 // interface to allow testing without running Consul.
 type testFakeCtx struct {
 	ServiceClient *ServiceClient
-	FakeConsul    *fakeConsul
+	FakeConsul    *MockAgent
 	Task          *structs.Task
+	Restarter     *fakeCheckRestarter
 
 	// Ticked whenever a script is called
 	execs chan int
@@ -99,126 +148,21 @@ func (t *testFakeCtx) syncOnce() error {
 // setupFake creates a testFakeCtx with a ServiceClient backed by a fakeConsul.
 // A test Task is also provided.
 func setupFake() *testFakeCtx {
-	fc := newFakeConsul()
+	fc := NewMockAgent()
 	return &testFakeCtx{
 		ServiceClient: NewServiceClient(fc, true, testLogger()),
 		FakeConsul:    fc,
 		Task:          testTask(),
+		Restarter:     newFakeCheckRestarter(),
 		execs:         make(chan int, 100),
 	}
 }
 
-// fakeConsul is a fake in-memory Consul backend for ServiceClient.
-type fakeConsul struct {
-	// maps of what services and checks have been registered
-	services map[string]*api.AgentServiceRegistration
-	checks   map[string]*api.AgentCheckRegistration
-	mu       sync.Mutex
-
-	// when UpdateTTL is called the check ID will have its counter inc'd
-	checkTTLs map[string]int
-
-	// What check status to return from Checks()
-	checkStatus string
-}
-
-func newFakeConsul() *fakeConsul {
-	return &fakeConsul{
-		services:    make(map[string]*api.AgentServiceRegistration),
-		checks:      make(map[string]*api.AgentCheckRegistration),
-		checkTTLs:   make(map[string]int),
-		checkStatus: api.HealthPassing,
-	}
-}
-
-func (c *fakeConsul) Services() (map[string]*api.AgentService, error) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-
-	r := make(map[string]*api.AgentService, len(c.services))
-	for k, v := range c.services {
-		r[k] = &api.AgentService{
-			ID:                v.ID,
-			Service:           v.Name,
-			Tags:              make([]string, len(v.Tags)),
-			Port:              v.Port,
-			Address:           v.Address,
-			EnableTagOverride: v.EnableTagOverride,
-		}
-		copy(r[k].Tags, v.Tags)
-	}
-	return r, nil
-}
-
-func (c *fakeConsul) Checks() (map[string]*api.AgentCheck, error) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-
-	r := make(map[string]*api.AgentCheck, len(c.checks))
-	for k, v := range c.checks {
-		r[k] = &api.AgentCheck{
-			CheckID:     v.ID,
-			Name:        v.Name,
-			Status:      c.checkStatus,
-			Notes:       v.Notes,
-			ServiceID:   v.ServiceID,
-			ServiceName: c.services[v.ServiceID].Name,
-		}
-	}
-	return r, nil
-}
-
-func (c *fakeConsul) CheckRegister(check *api.AgentCheckRegistration) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	c.checks[check.ID] = check
-
-	// Be nice and make checks reachable-by-service
-	scheck := check.AgentServiceCheck
-	c.services[check.ServiceID].Checks = append(c.services[check.ServiceID].Checks, &scheck)
-	return nil
-}
-
-func (c *fakeConsul) CheckDeregister(checkID string) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	delete(c.checks, checkID)
-	delete(c.checkTTLs, checkID)
-	return nil
-}
-
-func (c *fakeConsul) ServiceRegister(service *api.AgentServiceRegistration) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	c.services[service.ID] = service
-	return nil
-}
-
-func (c *fakeConsul) ServiceDeregister(serviceID string) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	delete(c.services, serviceID)
-	return nil
-}
-
-func (c *fakeConsul) UpdateTTL(id string, output string, status string) error {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	check, ok := c.checks[id]
-	if !ok {
-		return fmt.Errorf("unknown check id: %q", id)
-	}
-	// Flip initial status to passing
-	check.Status = "passing"
-	c.checkTTLs[id]++
-	return nil
-}
-
 func TestConsul_ChangeTags(t *testing.T) {
 	ctx := setupFake()
 
 	allocID := "allocid"
-	if err := ctx.ServiceClient.RegisterTask(allocID, ctx.Task, nil, nil); err != nil {
+	if err := ctx.ServiceClient.RegisterTask(allocID, ctx.Task, ctx.Restarter, nil, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -260,7 +204,7 @@ func TestConsul_ChangeTags(t *testing.T) {
 	origTask := ctx.Task
 	ctx.Task = testTask()
 	ctx.Task.Services[0].Tags[0] = "newtag"
-	if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, nil, nil); err != nil {
+	if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, nil, nil, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
 	if err := ctx.syncOnce(); err != nil {
@@ -342,7 +286,7 @@ func TestConsul_ChangePorts(t *testing.T) {
 		},
 	}
 
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, nil); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -430,7 +374,7 @@ func TestConsul_ChangePorts(t *testing.T) {
 			// Removed PortLabel; should default to service's (y)
 		},
 	}
-	if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, ctx, nil); err != nil {
+	if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, nil, ctx, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
 	if err := ctx.syncOnce(); err != nil {
@@ -509,7 +453,7 @@ func TestConsul_ChangeChecks(t *testing.T) {
 	}
 
 	allocID := "allocid"
-	if err := ctx.ServiceClient.RegisterTask(allocID, ctx.Task, ctx, nil); err != nil {
+	if err := ctx.ServiceClient.RegisterTask(allocID, ctx.Task, ctx.Restarter, ctx, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -576,7 +520,7 @@ func TestConsul_ChangeChecks(t *testing.T) {
 			PortLabel: "x",
 		},
 	}
-	if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, ctx, nil); err != nil {
+	if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, nil, ctx, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
 	if err := ctx.syncOnce(); err != nil {
@@ -650,7 +594,7 @@ func TestConsul_ChangeChecks(t *testing.T) {
 func TestConsul_RegServices(t *testing.T) {
 	ctx := setupFake()
 
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, nil, nil); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, nil, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -676,7 +620,7 @@ func TestConsul_RegServices(t *testing.T) {
 	// Make a change which will register a new service
 	ctx.Task.Services[0].Name = "taskname-service2"
 	ctx.Task.Services[0].Tags[0] = "tag3"
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, nil, nil); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, nil, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -750,7 +694,7 @@ func TestConsul_ShutdownOK(t *testing.T) {
 	go ctx.ServiceClient.Run()
 
 	// Register a task and agent
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, nil); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -823,7 +767,7 @@ func TestConsul_ShutdownSlow(t *testing.T) {
 	go ctx.ServiceClient.Run()
 
 	// Register a task and agent
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, nil); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -894,7 +838,7 @@ func TestConsul_ShutdownBlocked(t *testing.T) {
 	go ctx.ServiceClient.Run()
 
 	// Register a task and agent
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, nil); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -951,7 +895,7 @@ func TestConsul_NoTLSSkipVerifySupport(t *testing.T) {
 		},
 	}
 
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, nil, nil); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, nil, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -991,7 +935,7 @@ func TestConsul_CancelScript(t *testing.T) {
 		},
 	}
 
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, nil); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -1028,7 +972,7 @@ func TestConsul_CancelScript(t *testing.T) {
 		},
 	}
 
-	if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, ctx, nil); err != nil {
+	if err := ctx.ServiceClient.UpdateTask("allocid", origTask, ctx.Task, ctx.Restarter, ctx, nil); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -1115,7 +1059,7 @@ func TestConsul_DriverNetwork_AutoUse(t *testing.T) {
 		AutoAdvertise: true,
 	}
 
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, net); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, net); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -1218,7 +1162,7 @@ func TestConsul_DriverNetwork_NoAutoUse(t *testing.T) {
 		AutoAdvertise: false,
 	}
 
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, net); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, net); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -1304,7 +1248,7 @@ func TestConsul_DriverNetwork_Change(t *testing.T) {
 	}
 
 	// Initial service should advertise host port x
-	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx, net); err != nil {
+	if err := ctx.ServiceClient.RegisterTask("allocid", ctx.Task, ctx.Restarter, ctx, net); err != nil {
 		t.Fatalf("unexpected error registering task: %v", err)
 	}
@@ -1314,7 +1258,7 @@ func TestConsul_DriverNetwork_Change(t *testing.T) {
 
 	orig := ctx.Task.Copy()
 	ctx.Task.Services[0].AddressMode = structs.AddressModeHost
 
-	if err := ctx.ServiceClient.UpdateTask("allocid", orig, ctx.Task, ctx, net); err != nil {
+	if err := ctx.ServiceClient.UpdateTask("allocid", orig, ctx.Task, ctx.Restarter, ctx, net); err != nil {
 		t.Fatalf("unexpected error updating task: %v", err)
 	}
@@ -1324,7 +1268,7 @@ func TestConsul_DriverNetwork_Change(t *testing.T) {
 
 	orig = ctx.Task.Copy()
 	ctx.Task.Services[0].AddressMode = structs.AddressModeDriver
 
-	if err := ctx.ServiceClient.UpdateTask("allocid", orig, ctx.Task, ctx, net); err != nil {
+	if err := ctx.ServiceClient.UpdateTask("allocid", orig, ctx.Task, ctx.Restarter, ctx, net); err != nil {
 		t.Fatalf("unexpected error updating task: %v", err)
 	}
diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go
index 91c3a2b76e6661a86b8091cbe601df8fa78f22dc..5335a71056cd8b00603b98cb155550f9c9ff37fa 100644
--- a/nomad/structs/structs.go
+++ b/nomad/structs/structs.go
@@ -3842,7 +3842,7 @@ type TaskEvent struct {
 }
 
 func (te *TaskEvent) GoString() string {
-	return fmt.Sprintf("%v at %v", te.Type, te.Time)
+	return fmt.Sprintf("%v - %v", te.Time, te.Type)
 }
 
 // SetMessage sets the message of TaskEvent
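
Reviewer note: a minimal usage sketch of the checkWatcher wiring this diff introduces, reconstructed only from the constructors and methods shown above; `logger`, `checksAPI`, and `taskRunner` are illustrative stand-ins (checksAPI is assumed to be any value with the Checks() method, such as the MockAgent added here, and the restarter must satisfy the consul.TaskRestarter interface with Restart and RestartDelay):

	// Sketch only; not part of the patch.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	cw := newCheckWatcher(logger, checksAPI)
	go cw.Run(ctx) // poll all watched checks every pollFreq

	// Register a check for watching; Watch drops checks without a
	// CheckRestart stanza (see TestCheckWatcher_Skip).
	cw.Watch(allocID, taskName, checkID, serviceCheck, taskRunner)

	// Stop watching when the task stops or the check is removed.
	cw.Unwatch(checkID)

Note that polling happens on a single timer (pollFreq) rather than per-check grace timers, which is why the resetTimer helper was removed from Run above.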