Commit 9d42f4d0 authored by Alex Dadgar's avatar Alex Dadgar
Browse files

Plugin client's handle plugin dying

This PR plumbs the plugins done ctx through the base and driver plugin
clients (device already had it). Further, it adds generic handling of
gRPC stream errors.
parent 0200000d
Showing with 165 additions and 110 deletions
+165 -110
package exec
import (
"context"
"fmt"
"os"
"path/filepath"
......@@ -20,7 +21,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (
......
package java
import (
"context"
"fmt"
"os"
"os/exec"
......@@ -23,7 +24,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (
......
......@@ -16,7 +16,6 @@ import (
"github.com/hashicorp/nomad/plugins/drivers"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
netctx "golang.org/x/net/context"
)
const (
......@@ -232,7 +231,7 @@ func (d *Driver) Capabilities() (*drivers.Capabilities, error) {
return capabilities, nil
}
func (d *Driver) Fingerprint(ctx netctx.Context) (<-chan *drivers.Fingerprint, error) {
func (d *Driver) Fingerprint(ctx context.Context) (<-chan *drivers.Fingerprint, error) {
ch := make(chan *drivers.Fingerprint)
go d.handleFingerprint(ctx, ch)
return ch, nil
......@@ -365,7 +364,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru
}
func (d *Driver) WaitTask(ctx netctx.Context, taskID string) (<-chan *drivers.ExitResult, error) {
func (d *Driver) WaitTask(ctx context.Context, taskID string) (<-chan *drivers.ExitResult, error) {
handle, ok := d.tasks.Get(taskID)
if !ok {
return nil, drivers.ErrTaskNotFound
......@@ -430,7 +429,7 @@ func (d *Driver) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) {
return nil, nil
}
func (d *Driver) TaskEvents(ctx netctx.Context) (<-chan *drivers.TaskEvent, error) {
func (d *Driver) TaskEvents(ctx context.Context) (<-chan *drivers.TaskEvent, error) {
return d.eventer.TaskEvents(ctx)
}
......
package qemu
import (
"context"
"errors"
"fmt"
"net"
......@@ -25,7 +26,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (
......
package rawexec
import (
"context"
"fmt"
"os"
"path/filepath"
......@@ -22,7 +23,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (
......
......@@ -4,6 +4,7 @@ package rkt
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
......@@ -36,7 +37,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
rktv1 "github.com/rkt/rkt/api/v1"
"golang.org/x/net/context"
)
const (
......
......@@ -3,17 +3,16 @@
package rkt
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"testing"
"time"
"os"
"bytes"
"github.com/hashicorp/hcl2/hcl"
ctestutil "github.com/hashicorp/nomad/client/testutil"
"github.com/hashicorp/nomad/helper/testlog"
......@@ -26,7 +25,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
)
var _ drivers.DriverPlugin = (*Driver)(nil)
......
package eventer
import (
"context"
"sync"
"time"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/plugins/drivers"
"golang.org/x/net/context"
)
var (
......
......@@ -12,10 +12,13 @@ import (
// gRPC to communicate to the remote plugin.
type BasePluginClient struct {
Client proto.BasePluginClient
// DoneCtx is closed when the plugin exits
DoneCtx context.Context
}
func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
presp, err := b.Client.PluginInfo(context.Background(), &proto.PluginInfoRequest{})
presp, err := b.Client.PluginInfo(b.DoneCtx, &proto.PluginInfoRequest{})
if err != nil {
return nil, err
}
......@@ -41,7 +44,7 @@ func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
}
func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {
presp, err := b.Client.ConfigSchema(context.Background(), &proto.ConfigSchemaRequest{})
presp, err := b.Client.ConfigSchema(b.DoneCtx, &proto.ConfigSchemaRequest{})
if err != nil {
return nil, err
}
......@@ -51,7 +54,7 @@ func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {
func (b *BasePluginClient) SetConfig(data []byte, config *ClientAgentConfig) error {
// Send the config
_, err := b.Client.SetConfig(context.Background(), &proto.SetConfigRequest{
_, err := b.Client.SetConfig(b.DoneCtx, &proto.SetConfigRequest{
MsgpackConfig: data,
NomadConfig: config.toProto(),
})
......
......@@ -51,7 +51,10 @@ func (p *PluginBase) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error
}
func (p *PluginBase) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
return &BasePluginClient{Client: proto.NewBasePluginClient(c)}, nil
return &BasePluginClient{
Client: proto.NewBasePluginClient(c),
DoneCtx: ctx,
}, nil
}
// MsgpackHandle is a shared handle for encoding/decoding of structs
......
......@@ -9,9 +9,7 @@ import (
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/device/proto"
netctx "golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/nomad/plugins/shared"
)
// devicePluginClient implements the client side of a remote device plugin, using
......@@ -49,28 +47,33 @@ func (d *devicePluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerpri
// the gRPC stream to a channel. Exits either when context is cancelled or the
// stream has an error.
func (d *devicePluginClient) handleFingerprint(
ctx netctx.Context,
ctx context.Context,
stream proto.DevicePlugin_FingerprintClient,
out chan *FingerprintResponse) {
defer close(out)
for {
resp, err := stream.Recv()
if err != nil {
if err != io.EOF {
out <- &FingerprintResponse{
Error: d.handleStreamErr(err, ctx),
Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
close(out)
return
}
// Send the response
out <- &FingerprintResponse{
f := &FingerprintResponse{
Devices: convertProtoDeviceGroups(resp.GetDeviceGroup()),
}
select {
case <-ctx.Done():
return
case out <- f:
}
}
}
......@@ -116,69 +119,32 @@ func (d *devicePluginClient) Stats(ctx context.Context, interval time.Duration)
// the gRPC stream to a channel. Exits either when context is cancelled or the
// stream has an error.
func (d *devicePluginClient) handleStats(
ctx netctx.Context,
ctx context.Context,
stream proto.DevicePlugin_StatsClient,
out chan *StatsResponse) {
defer close(out)
for {
resp, err := stream.Recv()
if err != nil {
if err != io.EOF {
out <- &StatsResponse{
Error: d.handleStreamErr(err, ctx),
Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
close(out)
return
}
// Send the response
out <- &StatsResponse{
s := &StatsResponse{
Groups: convertProtoDeviceGroupsStats(resp.GetGroups()),
}
}
}
// handleStreamErr is used to handle a non io.EOF error in a stream. It handles
// detecting if the plugin has shutdown
func (d *devicePluginClient) handleStreamErr(err error, ctx context.Context) error {
if err == nil {
return nil
}
// Determine if the error is because the plugin shutdown
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
// Potentially wait a little before returning an error so we can detect
// the exit
select {
case <-d.doneCtx.Done():
err = base.ErrPluginShutdown
case <-ctx.Done():
err = ctx.Err()
// There is no guarantee that the select will choose the
// doneCtx first so we have to double check
select {
case <-d.doneCtx.Done():
err = base.ErrPluginShutdown
default:
}
case <-time.After(3 * time.Second):
// Its okay to wait a while since the connection isn't available and
// on local host it is likely shutting down. It is not expected for
// this to ever reach even close to 3 seconds.
return
case out <- s:
}
// It is an error we don't know how to handle, so return it
return err
}
// Context was cancelled
if errStatus := status.FromContextError(ctx.Err()); errStatus.Code() == codes.Canceled {
return context.Canceled
}
return err
}
......@@ -31,7 +31,8 @@ func (p *PluginDevice) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker
doneCtx: ctx,
client: proto.NewDevicePluginClient(c),
BasePluginClient: &base.BasePluginClient{
Client: bproto.NewBasePluginClient(c),
Client: bproto.NewBasePluginClient(c),
DoneCtx: ctx,
},
}, nil
}
......
package drivers
import (
"context"
"errors"
"fmt"
"io"
"time"
"github.com/LK4D4/joincontext"
"github.com/golang/protobuf/ptypes"
hclog "github.com/hashicorp/go-hclog"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/drivers/proto"
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"golang.org/x/net/context"
)
var _ DriverPlugin = &driverPluginClient{}
......@@ -22,12 +23,15 @@ type driverPluginClient struct {
client proto.DriverClient
logger hclog.Logger
// doneCtx is closed when the plugin exits
doneCtx context.Context
}
func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) {
req := &proto.TaskConfigSchemaRequest{}
resp, err := d.client.TaskConfigSchema(context.Background(), req)
resp, err := d.client.TaskConfigSchema(d.doneCtx, req)
if err != nil {
return nil, err
}
......@@ -38,7 +42,7 @@ func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) {
func (d *driverPluginClient) Capabilities() (*Capabilities, error) {
req := &proto.CapabilitiesRequest{}
resp, err := d.client.Capabilities(context.Background(), req)
resp, err := d.client.Capabilities(d.doneCtx, req)
if err != nil {
return nil, err
}
......@@ -67,12 +71,15 @@ func (d *driverPluginClient) Capabilities() (*Capabilities, error) {
func (d *driverPluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerprint, error) {
req := &proto.FingerprintRequest{}
// Join the passed context and the shutdown context
ctx, _ = joincontext.Join(ctx, d.doneCtx)
stream, err := d.client.Fingerprint(ctx, req)
if err != nil {
return nil, err
}
ch := make(chan *Fingerprint)
ch := make(chan *Fingerprint, 1)
go d.handleFingerprint(ctx, ch, stream)
return ch, nil
......@@ -82,17 +89,18 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin
defer close(ch)
for {
pb, err := stream.Recv()
if err == io.EOF {
return
}
if err != nil {
select {
case <-ctx.Done():
case ch <- &Fingerprint{Err: fmt.Errorf("error from RPC stream: %v", err)}:
if err != io.EOF {
d.logger.Error("error receiving stream from Fingerprint driver RPC", "error", err)
ch <- &Fingerprint{
Err: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
return
}
f := &Fingerprint{
Attributes: pb.Attributes,
Health: healthStateFromProto(pb.Health),
......@@ -112,7 +120,7 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin
func (d *driverPluginClient) RecoverTask(h *TaskHandle) error {
req := &proto.RecoverTaskRequest{Handle: taskHandleToProto(h)}
_, err := d.client.RecoverTask(context.Background(), req)
_, err := d.client.RecoverTask(d.doneCtx, req)
return err
}
......@@ -124,7 +132,7 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr
Task: taskConfigToProto(c),
}
resp, err := d.client.StartTask(context.Background(), req)
resp, err := d.client.StartTask(d.doneCtx, req)
if err != nil {
return nil, nil, err
}
......@@ -150,6 +158,10 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr
// the same task without issue.
func (d *driverPluginClient) WaitTask(ctx context.Context, id string) (<-chan *ExitResult, error) {
ch := make(chan *ExitResult)
// Join the passed context and the shutdown context
ctx, _ = joincontext.Join(ctx, d.doneCtx)
go d.handleWaitTask(ctx, id, ch)
return ch, nil
}
......@@ -186,7 +198,7 @@ func (d *driverPluginClient) StopTask(taskID string, timeout time.Duration, sign
Signal: signal,
}
_, err := d.client.StopTask(context.Background(), req)
_, err := d.client.StopTask(d.doneCtx, req)
return err
}
......@@ -199,7 +211,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error {
Force: force,
}
_, err := d.client.DestroyTask(context.Background(), req)
_, err := d.client.DestroyTask(d.doneCtx, req)
return err
}
......@@ -207,7 +219,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error {
func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) {
req := &proto.InspectTaskRequest{TaskId: taskID}
resp, err := d.client.InspectTask(context.Background(), req)
resp, err := d.client.InspectTask(d.doneCtx, req)
if err != nil {
return nil, err
}
......@@ -238,7 +250,7 @@ func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) {
func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) {
req := &proto.TaskStatsRequest{TaskId: taskID}
resp, err := d.client.TaskStats(context.Background(), req)
resp, err := d.client.TaskStats(d.doneCtx, req)
if err != nil {
return nil, err
}
......@@ -255,28 +267,36 @@ func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsa
// tasks such as lifecycle events, terminal errors, etc.
func (d *driverPluginClient) TaskEvents(ctx context.Context) (<-chan *TaskEvent, error) {
req := &proto.TaskEventsRequest{}
// Join the passed context and the shutdown context
ctx, _ = joincontext.Join(ctx, d.doneCtx)
stream, err := d.client.TaskEvents(ctx, req)
if err != nil {
return nil, err
}
ch := make(chan *TaskEvent)
go d.handleTaskEvents(ch, stream)
ch := make(chan *TaskEvent, 1)
go d.handleTaskEvents(ctx, ch, stream)
return ch, nil
}
func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
func (d *driverPluginClient) handleTaskEvents(ctx context.Context, ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
defer close(ch)
for {
ev, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err)
ch <- &TaskEvent{Err: err}
break
if err != io.EOF {
d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err)
ch <- &TaskEvent{
Err: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
return
}
timestamp, _ := ptypes.Timestamp(ev.Timestamp)
event := &TaskEvent{
TaskID: ev.TaskId,
......@@ -284,7 +304,11 @@ func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.D
Message: ev.Message,
Timestamp: timestamp,
}
ch <- event
select {
case <-ctx.Done():
return
case ch <- event:
}
}
}
......@@ -294,7 +318,7 @@ func (d *driverPluginClient) SignalTask(taskID string, signal string) error {
TaskId: taskID,
Signal: signal,
}
_, err := d.client.SignalTask(context.Background(), req)
_, err := d.client.SignalTask(d.doneCtx, req)
return err
}
......@@ -309,7 +333,7 @@ func (d *driverPluginClient) ExecTask(taskID string, cmd []string, timeout time.
Timeout: ptypes.DurationProto(timeout),
}
resp, err := d.client.ExecTask(context.Background(), req)
resp, err := d.client.ExecTask(d.doneCtx, req)
if err != nil {
return nil, err
}
......
package drivers
import (
"context"
"fmt"
"path/filepath"
"sort"
......@@ -14,7 +15,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/msgpack"
"golang.org/x/net/context"
)
// DriverPlugin is the interface with drivers will implement. It is also
......
......@@ -38,9 +38,11 @@ func (p *PluginDriver) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) err
func (p *PluginDriver) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
return &driverPluginClient{
BasePluginClient: &base.BasePluginClient{
Client: baseproto.NewBasePluginClient(c),
DoneCtx: ctx,
Client: baseproto.NewBasePluginClient(c),
},
client: proto.NewDriverClient(c),
logger: p.logger,
client: proto.NewDriverClient(c),
logger: p.logger,
doneCtx: ctx,
}, nil
}
......@@ -2,6 +2,7 @@ package drivers
import (
"bytes"
"context"
"sync"
"testing"
"time"
......@@ -10,7 +11,6 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
"github.com/stretchr/testify/require"
"github.com/ugorji/go/codec"
"golang.org/x/net/context"
)
type testDriverState struct {
......
......@@ -4,13 +4,12 @@ import (
"fmt"
"io"
"golang.org/x/net/context"
"github.com/golang/protobuf/ptypes"
hclog "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/plugins/drivers/proto"
context "golang.org/x/net/context"
)
type driverPluginServer struct {
......
package drivers
import (
"context"
"fmt"
"io/ioutil"
"path/filepath"
"runtime"
"time"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
hclog "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/nomad/client/allocdir"
......@@ -21,6 +18,8 @@ import (
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
)
type DriverHarness struct {
......
package shared
import (
"context"
"time"
"github.com/hashicorp/nomad/plugins/base"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// HandleStreamErr is used to handle a non io.EOF error in a stream. It handles
// detecting if the plugin has shutdown via the passeed pluginCtx. The
// parameters are:
// - err: the error returned from the streaming RPC
// - reqCtx: the context passed to the streaming request
// - pluginCtx: the plugins done ctx used to detect the plugin dying
//
// The return values are:
// - base.ErrPluginShutdown if the error is because the plugin shutdown
// - context.Canceled if the reqCtx is canceled
// - The original error
func HandleStreamErr(err error, reqCtx, pluginCtx context.Context) error {
if err == nil {
return nil
}
// Determine if the error is because the plugin shutdown
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
// Potentially wait a little before returning an error so we can detect
// the exit
select {
case <-pluginCtx.Done():
err = base.ErrPluginShutdown
case <-reqCtx.Done():
err = reqCtx.Err()
// There is no guarantee that the select will choose the
// doneCtx first so we have to double check
select {
case <-pluginCtx.Done():
err = base.ErrPluginShutdown
default:
}
case <-time.After(3 * time.Second):
// Its okay to wait a while since the connection isn't available and
// on local host it is likely shutting down. It is not expected for
// this to ever reach even close to 3 seconds.
}
// It is an error we don't know how to handle, so return it
return err
}
// Context was cancelled
if errStatus := status.FromContextError(reqCtx.Err()); errStatus.Code() == codes.Canceled {
return context.Canceled
}
return err
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment