Unverified commit d6e27a7e, authored by hc-github-team-nomad-core and committed by GitHub
Browse files

Merge pull request #12635 from hashicorp/backport/csi-plugin-client-restarts/shortly-open-lamprey

This pull request was automerged via backport-assistant
No related merge requests found
Showing with 201 additions and 146 deletions
+201 -146
```release-note:bug
csi: Fixed a bug where the plugin instance manager would not retry the initial gRPC connection to plugins
```
```release-note:bug
csi: Fixed a bug where the plugin supervisor would not restart the task if it failed to connect to the plugin
```
...@@ -38,6 +38,7 @@ type csiPluginSupervisorHook struct { ...@@ -38,6 +38,7 @@ type csiPluginSupervisorHook struct {
// eventEmitter is used to emit events to the task // eventEmitter is used to emit events to the task
eventEmitter ti.EventEmitter eventEmitter ti.EventEmitter
lifecycle ti.TaskLifecycle
shutdownCtx context.Context shutdownCtx context.Context
shutdownCancelFn context.CancelFunc shutdownCancelFn context.CancelFunc
...@@ -54,6 +55,7 @@ type csiPluginSupervisorHookConfig struct { ...@@ -54,6 +55,7 @@ type csiPluginSupervisorHookConfig struct {
clientStateDirPath string clientStateDirPath string
events ti.EventEmitter events ti.EventEmitter
runner *TaskRunner runner *TaskRunner
lifecycle ti.TaskLifecycle
capabilities *drivers.Capabilities capabilities *drivers.Capabilities
logger hclog.Logger logger hclog.Logger
} }
...@@ -90,6 +92,7 @@ func newCSIPluginSupervisorHook(config *csiPluginSupervisorHookConfig) *csiPlugi ...@@ -90,6 +92,7 @@ func newCSIPluginSupervisorHook(config *csiPluginSupervisorHookConfig) *csiPlugi
hook := &csiPluginSupervisorHook{ hook := &csiPluginSupervisorHook{
alloc: config.runner.Alloc(), alloc: config.runner.Alloc(),
runner: config.runner, runner: config.runner,
lifecycle: config.lifecycle,
logger: config.logger, logger: config.logger,
task: task, task: task,
mountPoint: pluginRoot, mountPoint: pluginRoot,
...@@ -201,27 +204,41 @@ func (h *csiPluginSupervisorHook) ensureSupervisorLoop(ctx context.Context) { ...@@ -201,27 +204,41 @@ func (h *csiPluginSupervisorHook) ensureSupervisorLoop(ctx context.Context) {
}() }()
socketPath := filepath.Join(h.mountPoint, structs.CSISocketName) socketPath := filepath.Join(h.mountPoint, structs.CSISocketName)
client := csi.NewClient(socketPath, h.logger.Named("csi_client").With(
"plugin.name", h.task.CSIPluginConfig.ID,
"plugin.type", h.task.CSIPluginConfig.Type))
defer client.Close()
t := time.NewTimer(0) t := time.NewTimer(0)
// We're in Poststart at this point, so if we can't connect within
// this deadline, assume it's broken so we can restart the task
startCtx, startCancelFn := context.WithTimeout(ctx, 30*time.Second)
defer startCancelFn()
var err error
var pluginHealthy bool
// Step 1: Wait for the plugin to initially become available. // Step 1: Wait for the plugin to initially become available.
WAITFORREADY: WAITFORREADY:
for { for {
select { select {
case <-ctx.Done(): case <-startCtx.Done():
h.kill(ctx, fmt.Errorf("CSI plugin failed probe: %v", err))
return return
case <-t.C: case <-t.C:
pluginHealthy, err := h.supervisorLoopOnce(ctx, socketPath) pluginHealthy, err = h.supervisorLoopOnce(startCtx, client)
if err != nil || !pluginHealthy { if err != nil || !pluginHealthy {
h.logger.Debug("CSI Plugin not ready", "error", err) h.logger.Debug("CSI plugin not ready", "error", err)
// Use only a short delay here to optimize for quickly
// Plugin is not yet returning healthy, because we want to optimise for // bringing up a plugin
// quickly bringing a plugin online, we use a short timeout here.
// TODO(dani): Test with more plugins and adjust.
t.Reset(5 * time.Second) t.Reset(5 * time.Second)
continue continue
} }
// Mark the plugin as healthy in a task event // Mark the plugin as healthy in a task event
h.logger.Debug("CSI plugin is ready")
h.previousHealthState = pluginHealthy h.previousHealthState = pluginHealthy
event := structs.NewTaskEvent(structs.TaskPluginHealthy) event := structs.NewTaskEvent(structs.TaskPluginHealthy)
event.SetMessage(fmt.Sprintf("plugin: %s", h.task.CSIPluginConfig.ID)) event.SetMessage(fmt.Sprintf("plugin: %s", h.task.CSIPluginConfig.ID))
...@@ -232,15 +249,14 @@ WAITFORREADY: ...@@ -232,15 +249,14 @@ WAITFORREADY:
} }
// Step 2: Register the plugin with the catalog. // Step 2: Register the plugin with the catalog.
deregisterPluginFn, err := h.registerPlugin(socketPath) deregisterPluginFn, err := h.registerPlugin(client, socketPath)
if err != nil { if err != nil {
h.logger.Error("CSI Plugin registration failed", "error", err) h.kill(ctx, fmt.Errorf("CSI plugin failed to register: %v", err))
event := structs.NewTaskEvent(structs.TaskPluginUnhealthy) return
event.SetMessage(fmt.Sprintf("failed to register plugin: %s, reason: %v", h.task.CSIPluginConfig.ID, err))
h.eventEmitter.EmitEvent(event)
} }
// Step 3: Start the lightweight supervisor loop. // Step 3: Start the lightweight supervisor loop. At this point, failures
// don't cause the task to restart
t.Reset(0) t.Reset(0)
for { for {
select { select {
...@@ -249,9 +265,9 @@ WAITFORREADY: ...@@ -249,9 +265,9 @@ WAITFORREADY:
deregisterPluginFn() deregisterPluginFn()
return return
case <-t.C: case <-t.C:
pluginHealthy, err := h.supervisorLoopOnce(ctx, socketPath) pluginHealthy, err := h.supervisorLoopOnce(ctx, client)
if err != nil { if err != nil {
h.logger.Error("CSI Plugin fingerprinting failed", "error", err) h.logger.Error("CSI plugin fingerprinting failed", "error", err)
} }
// The plugin has transitioned to a healthy state. Emit an event. // The plugin has transitioned to a healthy state. Emit an event.
...@@ -265,7 +281,7 @@ WAITFORREADY: ...@@ -265,7 +281,7 @@ WAITFORREADY:
if h.previousHealthState && !pluginHealthy { if h.previousHealthState && !pluginHealthy {
event := structs.NewTaskEvent(structs.TaskPluginUnhealthy) event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
if err != nil { if err != nil {
event.SetMessage(fmt.Sprintf("error: %v", err)) event.SetMessage(fmt.Sprintf("Error: %v", err))
} else { } else {
event.SetMessage("Unknown Reason") event.SetMessage("Unknown Reason")
} }
...@@ -281,16 +297,9 @@ WAITFORREADY: ...@@ -281,16 +297,9 @@ WAITFORREADY:
} }
} }
func (h *csiPluginSupervisorHook) registerPlugin(socketPath string) (func(), error) { func (h *csiPluginSupervisorHook) registerPlugin(client csi.CSIPlugin, socketPath string) (func(), error) {
// At this point we know the plugin is ready and we can fingerprint it // At this point we know the plugin is ready and we can fingerprint it
// to get its vendor name and version // to get its vendor name and version
client, err := csi.NewClient(socketPath, h.logger.Named("csi_client").With("plugin.name", h.task.CSIPluginConfig.ID, "plugin.type", h.task.CSIPluginConfig.Type))
if err != nil {
return nil, fmt.Errorf("failed to create csi client: %v", err)
}
defer client.Close()
info, err := client.PluginInfo() info, err := client.PluginInfo()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to probe plugin: %v", err) return nil, fmt.Errorf("failed to probe plugin: %v", err)
...@@ -354,21 +363,13 @@ func (h *csiPluginSupervisorHook) registerPlugin(socketPath string) (func(), err ...@@ -354,21 +363,13 @@ func (h *csiPluginSupervisorHook) registerPlugin(socketPath string) (func(), err
}, nil }, nil
} }
func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, socketPath string) (bool, error) { func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, client csi.CSIPlugin) (bool, error) {
_, err := os.Stat(socketPath) probeCtx, probeCancelFn := context.WithTimeout(ctx, 5*time.Second)
if err != nil { defer probeCancelFn()
return false, fmt.Errorf("failed to stat socket: %v", err)
}
client, err := csi.NewClient(socketPath, h.logger.Named("csi_client").With("plugin.name", h.task.CSIPluginConfig.ID, "plugin.type", h.task.CSIPluginConfig.Type)) healthy, err := client.PluginProbe(probeCtx)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to create csi client: %v", err) return false, err
}
defer client.Close()
healthy, err := client.PluginProbe(ctx)
if err != nil {
return false, fmt.Errorf("failed to probe plugin: %v", err)
} }
return healthy, nil return healthy, nil
...@@ -387,6 +388,21 @@ func (h *csiPluginSupervisorHook) Stop(_ context.Context, req *interfaces.TaskSt ...@@ -387,6 +388,21 @@ func (h *csiPluginSupervisorHook) Stop(_ context.Context, req *interfaces.TaskSt
return nil return nil
} }
// kill records why the plugin is being treated as failed and then asks the
// task lifecycle to kill the task.
//
// It does three things, in order:
//  1. logs the reason at error level,
//  2. emits a TaskPluginUnhealthy task event whose message carries the reason,
//  3. calls h.lifecycle.Kill with a TaskKilling event that has SetFailsTask
//     applied and a fixed display message.
//
// A failure returned by Kill is only logged; nothing is returned to the
// caller. reason must be non-nil because reason.Error() is called
// unconditionally.
func (h *csiPluginSupervisorHook) kill(ctx context.Context, reason error) {
	h.logger.Error("killing task because plugin failed", "error", reason)

	// Surface the failure reason on the task's event stream first.
	unhealthy := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
	unhealthy.SetMessage(fmt.Sprintf("Error: %v", reason.Error()))
	h.eventEmitter.EmitEvent(unhealthy)

	// Build the kill event separately so the Kill call itself stays short.
	killEvent := structs.NewTaskEvent(structs.TaskKilling).
		SetFailsTask().
		SetDisplayMessage("CSI plugin did not become healthy before timeout")

	killErr := h.lifecycle.Kill(ctx, killEvent)
	if killErr == nil {
		return
	}
	h.logger.Error("failed to kill task", "kill_reason", reason, "error", killErr)
}
func ensureMountpointInserted(mounts []*drivers.MountConfig, mount *drivers.MountConfig) []*drivers.MountConfig { func ensureMountpointInserted(mounts []*drivers.MountConfig, mount *drivers.MountConfig) []*drivers.MountConfig {
for _, mnt := range mounts { for _, mnt := range mounts {
if mnt.IsEqual(mount) { if mnt.IsEqual(mount) {
......
...@@ -76,6 +76,7 @@ func (tr *TaskRunner) initHooks() { ...@@ -76,6 +76,7 @@ func (tr *TaskRunner) initHooks() {
clientStateDirPath: tr.clientConfig.StateDir, clientStateDirPath: tr.clientConfig.StateDir,
events: tr, events: tr,
runner: tr, runner: tr,
lifecycle: tr,
capabilities: tr.driverCapabilities, capabilities: tr.driverCapabilities,
logger: hookLogger, logger: hookLogger,
})) }))
......
...@@ -390,11 +390,11 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie ...@@ -390,11 +390,11 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie
c.dynamicRegistry = c.dynamicRegistry =
dynamicplugins.NewRegistry(c.stateDB, map[string]dynamicplugins.PluginDispenser{ dynamicplugins.NewRegistry(c.stateDB, map[string]dynamicplugins.PluginDispenser{
dynamicplugins.PluginTypeCSIController: func(info *dynamicplugins.PluginInfo) (interface{}, error) { dynamicplugins.PluginTypeCSIController: func(info *dynamicplugins.PluginInfo) (interface{}, error) {
return csi.NewClient(info.ConnectionInfo.SocketPath, logger.Named("csi_client").With("plugin.name", info.Name, "plugin.type", "controller")) return csi.NewClient(info.ConnectionInfo.SocketPath, logger.Named("csi_client").With("plugin.name", info.Name, "plugin.type", "controller")), nil
}, },
dynamicplugins.PluginTypeCSINode: func(info *dynamicplugins.PluginInfo) (interface{}, error) { dynamicplugins.PluginTypeCSINode: func(info *dynamicplugins.PluginInfo) (interface{}, error) {
return csi.NewClient(info.ConnectionInfo.SocketPath, logger.Named("csi_client").With("plugin.name", info.Name, "plugin.type", "client")) return csi.NewClient(info.ConnectionInfo.SocketPath, logger.Named("csi_client").With("plugin.name", info.Name, "plugin.type", "client")), nil
}, // TODO(tgross): refactor these dispenser constructors into csimanager to tidy it up },
}) })
// Setup the clients RPC server // Setup the clients RPC server
......
...@@ -73,12 +73,7 @@ func newInstanceManager(logger hclog.Logger, eventer TriggerNodeEvent, updater U ...@@ -73,12 +73,7 @@ func newInstanceManager(logger hclog.Logger, eventer TriggerNodeEvent, updater U
} }
func (i *instanceManager) run() { func (i *instanceManager) run() {
c, err := csi.NewClient(i.info.ConnectionInfo.SocketPath, i.logger) c := csi.NewClient(i.info.ConnectionInfo.SocketPath, i.logger)
if err != nil {
i.logger.Error("failed to setup instance manager client", "error", err)
close(i.shutdownCh)
return
}
i.client = c i.client = c
i.fp.client = c i.fp.client = c
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"math" "math"
"net" "net"
"os"
"time" "time"
csipbv1 "github.com/container-storage-interface/spec/lib/go/csi" csipbv1 "github.com/container-storage-interface/spec/lib/go/csi"
...@@ -88,6 +89,7 @@ type CSINodeClient interface { ...@@ -88,6 +89,7 @@ type CSINodeClient interface {
} }
type client struct { type client struct {
addr string
conn *grpc.ClientConn conn *grpc.ClientConn
identityClient csipbv1.IdentityClient identityClient csipbv1.IdentityClient
controllerClient CSIControllerClient controllerClient CSIControllerClient
...@@ -102,30 +104,59 @@ func (c *client) Close() error { ...@@ -102,30 +104,59 @@ func (c *client) Close() error {
return nil return nil
} }
func NewClient(addr string, logger hclog.Logger) (CSIPlugin, error) { func NewClient(addr string, logger hclog.Logger) CSIPlugin {
if addr == "" { return &client{
return nil, fmt.Errorf("address is empty") addr: addr,
logger: logger,
} }
}
conn, err := newGrpcConn(addr, logger) func (c *client) ensureConnected(ctx context.Context) error {
if err != nil { if c == nil {
return nil, err return fmt.Errorf("client not initialized")
}
if c.conn != nil {
return nil
}
if c.addr == "" {
return fmt.Errorf("address is empty")
}
var conn *grpc.ClientConn
var err error
t := time.NewTimer(0)
for {
select {
case <-ctx.Done():
return fmt.Errorf("timeout while connecting to gRPC socket: %v", err)
case <-t.C:
_, err = os.Stat(c.addr)
if err != nil {
err = fmt.Errorf("failed to stat socket: %v", err)
t.Reset(5 * time.Second)
continue
}
conn, err = newGrpcConn(c.addr, c.logger)
if err != nil {
err = fmt.Errorf("failed to create gRPC connection: %v", err)
t.Reset(time.Second * 5)
continue
}
c.conn = conn
c.identityClient = csipbv1.NewIdentityClient(conn)
c.controllerClient = csipbv1.NewControllerClient(conn)
c.nodeClient = csipbv1.NewNodeClient(conn)
return nil
}
} }
return &client{
conn: conn,
identityClient: csipbv1.NewIdentityClient(conn),
controllerClient: csipbv1.NewControllerClient(conn),
nodeClient: csipbv1.NewNodeClient(conn),
logger: logger,
}, nil
} }
func newGrpcConn(addr string, logger hclog.Logger) (*grpc.ClientConn, error) { func newGrpcConn(addr string, logger hclog.Logger) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) // after DialContext returns w/ initial connection, closing this
// context is a no-op
connectCtx, cancel := context.WithTimeout(context.Background(), time.Second*1)
defer cancel() defer cancel()
conn, err := grpc.DialContext( conn, err := grpc.DialContext(
ctx, connectCtx,
addr, addr,
grpc.WithBlock(), grpc.WithBlock(),
grpc.WithInsecure(), grpc.WithInsecure(),
...@@ -146,10 +177,14 @@ func newGrpcConn(addr string, logger hclog.Logger) (*grpc.ClientConn, error) { ...@@ -146,10 +177,14 @@ func newGrpcConn(addr string, logger hclog.Logger) (*grpc.ClientConn, error) {
// PluginInfo describes the type and version of a plugin as required by the nomad // PluginInfo describes the type and version of a plugin as required by the nomad
// base.BasePlugin interface. // base.BasePlugin interface.
func (c *client) PluginInfo() (*base.PluginInfoResponse, error) { func (c *client) PluginInfo() (*base.PluginInfoResponse, error) {
// note: no grpc retries needed here, as this is called in
// fingerprinting and will get retried by the caller.
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
if err := c.ensureConnected(ctx); err != nil {
return nil, err
}
// note: no grpc retries needed here, as this is called in
// fingerprinting and will get retried by the caller.
name, version, err := c.PluginGetInfo(ctx) name, version, err := c.PluginGetInfo(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -176,6 +211,10 @@ func (c *client) SetConfig(_ *base.Config) error { ...@@ -176,6 +211,10 @@ func (c *client) SetConfig(_ *base.Config) error {
} }
func (c *client) PluginProbe(ctx context.Context) (bool, error) { func (c *client) PluginProbe(ctx context.Context) (bool, error) {
if err := c.ensureConnected(ctx); err != nil {
return false, err
}
// note: no grpc retries should be done here // note: no grpc retries should be done here
req, err := c.identityClient.Probe(ctx, &csipbv1.ProbeRequest{}) req, err := c.identityClient.Probe(ctx, &csipbv1.ProbeRequest{})
if err != nil { if err != nil {
...@@ -198,11 +237,8 @@ func (c *client) PluginProbe(ctx context.Context) (bool, error) { ...@@ -198,11 +237,8 @@ func (c *client) PluginProbe(ctx context.Context) (bool, error) {
} }
func (c *client) PluginGetInfo(ctx context.Context) (string, string, error) { func (c *client) PluginGetInfo(ctx context.Context) (string, string, error) {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return "", "", fmt.Errorf("Client not initialized") return "", "", err
}
if c.identityClient == nil {
return "", "", fmt.Errorf("Client not initialized")
} }
resp, err := c.identityClient.GetPluginInfo(ctx, &csipbv1.GetPluginInfoRequest{}) resp, err := c.identityClient.GetPluginInfo(ctx, &csipbv1.GetPluginInfoRequest{})
...@@ -220,11 +256,8 @@ func (c *client) PluginGetInfo(ctx context.Context) (string, string, error) { ...@@ -220,11 +256,8 @@ func (c *client) PluginGetInfo(ctx context.Context) (string, string, error) {
} }
func (c *client) PluginGetCapabilities(ctx context.Context) (*PluginCapabilitySet, error) { func (c *client) PluginGetCapabilities(ctx context.Context) (*PluginCapabilitySet, error) {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return nil, fmt.Errorf("Client not initialized") return nil, err
}
if c.identityClient == nil {
return nil, fmt.Errorf("Client not initialized")
} }
// note: no grpc retries needed here, as this is called in // note: no grpc retries needed here, as this is called in
...@@ -243,11 +276,8 @@ func (c *client) PluginGetCapabilities(ctx context.Context) (*PluginCapabilitySe ...@@ -243,11 +276,8 @@ func (c *client) PluginGetCapabilities(ctx context.Context) (*PluginCapabilitySe
// //
func (c *client) ControllerGetCapabilities(ctx context.Context) (*ControllerCapabilitySet, error) { func (c *client) ControllerGetCapabilities(ctx context.Context) (*ControllerCapabilitySet, error) {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return nil, fmt.Errorf("Client not initialized") return nil, err
}
if c.controllerClient == nil {
return nil, fmt.Errorf("controllerClient not initialized")
} }
// note: no grpc retries needed here, as this is called in // note: no grpc retries needed here, as this is called in
...@@ -262,11 +292,8 @@ func (c *client) ControllerGetCapabilities(ctx context.Context) (*ControllerCapa ...@@ -262,11 +292,8 @@ func (c *client) ControllerGetCapabilities(ctx context.Context) (*ControllerCapa
} }
func (c *client) ControllerPublishVolume(ctx context.Context, req *ControllerPublishVolumeRequest, opts ...grpc.CallOption) (*ControllerPublishVolumeResponse, error) { func (c *client) ControllerPublishVolume(ctx context.Context, req *ControllerPublishVolumeRequest, opts ...grpc.CallOption) (*ControllerPublishVolumeResponse, error) {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return nil, fmt.Errorf("Client not initialized") return nil, err
}
if c.controllerClient == nil {
return nil, fmt.Errorf("controllerClient not initialized")
} }
err := req.Validate() err := req.Validate()
...@@ -304,11 +331,8 @@ func (c *client) ControllerPublishVolume(ctx context.Context, req *ControllerPub ...@@ -304,11 +331,8 @@ func (c *client) ControllerPublishVolume(ctx context.Context, req *ControllerPub
} }
func (c *client) ControllerUnpublishVolume(ctx context.Context, req *ControllerUnpublishVolumeRequest, opts ...grpc.CallOption) (*ControllerUnpublishVolumeResponse, error) { func (c *client) ControllerUnpublishVolume(ctx context.Context, req *ControllerUnpublishVolumeRequest, opts ...grpc.CallOption) (*ControllerUnpublishVolumeResponse, error) {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return nil, fmt.Errorf("Client not initialized") return nil, err
}
if c.controllerClient == nil {
return nil, fmt.Errorf("controllerClient not initialized")
} }
err := req.Validate() err := req.Validate()
if err != nil { if err != nil {
...@@ -337,13 +361,9 @@ func (c *client) ControllerUnpublishVolume(ctx context.Context, req *ControllerU ...@@ -337,13 +361,9 @@ func (c *client) ControllerUnpublishVolume(ctx context.Context, req *ControllerU
} }
func (c *client) ControllerValidateCapabilities(ctx context.Context, req *ControllerValidateVolumeRequest, opts ...grpc.CallOption) error { func (c *client) ControllerValidateCapabilities(ctx context.Context, req *ControllerValidateVolumeRequest, opts ...grpc.CallOption) error {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return fmt.Errorf("Client not initialized") return err
}
if c.controllerClient == nil {
return fmt.Errorf("controllerClient not initialized")
} }
if req.ExternalID == "" { if req.ExternalID == "" {
return fmt.Errorf("missing volume ID") return fmt.Errorf("missing volume ID")
} }
...@@ -390,6 +410,10 @@ func (c *client) ControllerValidateCapabilities(ctx context.Context, req *Contro ...@@ -390,6 +410,10 @@ func (c *client) ControllerValidateCapabilities(ctx context.Context, req *Contro
} }
func (c *client) ControllerCreateVolume(ctx context.Context, req *ControllerCreateVolumeRequest, opts ...grpc.CallOption) (*ControllerCreateVolumeResponse, error) { func (c *client) ControllerCreateVolume(ctx context.Context, req *ControllerCreateVolumeRequest, opts ...grpc.CallOption) (*ControllerCreateVolumeResponse, error) {
if err := c.ensureConnected(ctx); err != nil {
return nil, err
}
err := req.Validate() err := req.Validate()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -433,6 +457,10 @@ func (c *client) ControllerCreateVolume(ctx context.Context, req *ControllerCrea ...@@ -433,6 +457,10 @@ func (c *client) ControllerCreateVolume(ctx context.Context, req *ControllerCrea
} }
func (c *client) ControllerListVolumes(ctx context.Context, req *ControllerListVolumesRequest, opts ...grpc.CallOption) (*ControllerListVolumesResponse, error) { func (c *client) ControllerListVolumes(ctx context.Context, req *ControllerListVolumesRequest, opts ...grpc.CallOption) (*ControllerListVolumesResponse, error) {
if err := c.ensureConnected(ctx); err != nil {
return nil, err
}
err := req.Validate() err := req.Validate()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -455,6 +483,10 @@ func (c *client) ControllerListVolumes(ctx context.Context, req *ControllerListV ...@@ -455,6 +483,10 @@ func (c *client) ControllerListVolumes(ctx context.Context, req *ControllerListV
} }
func (c *client) ControllerDeleteVolume(ctx context.Context, req *ControllerDeleteVolumeRequest, opts ...grpc.CallOption) error { func (c *client) ControllerDeleteVolume(ctx context.Context, req *ControllerDeleteVolumeRequest, opts ...grpc.CallOption) error {
if err := c.ensureConnected(ctx); err != nil {
return err
}
err := req.Validate() err := req.Validate()
if err != nil { if err != nil {
return err return err
...@@ -552,6 +584,10 @@ NEXT_CAP: ...@@ -552,6 +584,10 @@ NEXT_CAP:
} }
func (c *client) ControllerCreateSnapshot(ctx context.Context, req *ControllerCreateSnapshotRequest, opts ...grpc.CallOption) (*ControllerCreateSnapshotResponse, error) { func (c *client) ControllerCreateSnapshot(ctx context.Context, req *ControllerCreateSnapshotRequest, opts ...grpc.CallOption) (*ControllerCreateSnapshotResponse, error) {
if err := c.ensureConnected(ctx); err != nil {
return nil, err
}
err := req.Validate() err := req.Validate()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -596,6 +632,10 @@ func (c *client) ControllerCreateSnapshot(ctx context.Context, req *ControllerCr ...@@ -596,6 +632,10 @@ func (c *client) ControllerCreateSnapshot(ctx context.Context, req *ControllerCr
} }
func (c *client) ControllerDeleteSnapshot(ctx context.Context, req *ControllerDeleteSnapshotRequest, opts ...grpc.CallOption) error { func (c *client) ControllerDeleteSnapshot(ctx context.Context, req *ControllerDeleteSnapshotRequest, opts ...grpc.CallOption) error {
if err := c.ensureConnected(ctx); err != nil {
return err
}
err := req.Validate() err := req.Validate()
if err != nil { if err != nil {
return err return err
...@@ -626,6 +666,10 @@ func (c *client) ControllerDeleteSnapshot(ctx context.Context, req *ControllerDe ...@@ -626,6 +666,10 @@ func (c *client) ControllerDeleteSnapshot(ctx context.Context, req *ControllerDe
} }
func (c *client) ControllerListSnapshots(ctx context.Context, req *ControllerListSnapshotsRequest, opts ...grpc.CallOption) (*ControllerListSnapshotsResponse, error) { func (c *client) ControllerListSnapshots(ctx context.Context, req *ControllerListSnapshotsRequest, opts ...grpc.CallOption) (*ControllerListSnapshotsResponse, error) {
if err := c.ensureConnected(ctx); err != nil {
return nil, err
}
err := req.Validate() err := req.Validate()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -657,11 +701,8 @@ func (c *client) ControllerListSnapshots(ctx context.Context, req *ControllerLis ...@@ -657,11 +701,8 @@ func (c *client) ControllerListSnapshots(ctx context.Context, req *ControllerLis
// //
func (c *client) NodeGetCapabilities(ctx context.Context) (*NodeCapabilitySet, error) { func (c *client) NodeGetCapabilities(ctx context.Context) (*NodeCapabilitySet, error) {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return nil, fmt.Errorf("Client not initialized") return nil, err
}
if c.nodeClient == nil {
return nil, fmt.Errorf("Client not initialized")
} }
// note: no grpc retries needed here, as this is called in // note: no grpc retries needed here, as this is called in
...@@ -675,11 +716,8 @@ func (c *client) NodeGetCapabilities(ctx context.Context) (*NodeCapabilitySet, e ...@@ -675,11 +716,8 @@ func (c *client) NodeGetCapabilities(ctx context.Context) (*NodeCapabilitySet, e
} }
func (c *client) NodeGetInfo(ctx context.Context) (*NodeGetInfoResponse, error) { func (c *client) NodeGetInfo(ctx context.Context) (*NodeGetInfoResponse, error) {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return nil, fmt.Errorf("Client not initialized") return nil, err
}
if c.nodeClient == nil {
return nil, fmt.Errorf("Client not initialized")
} }
result := &NodeGetInfoResponse{} result := &NodeGetInfoResponse{}
...@@ -706,11 +744,8 @@ func (c *client) NodeGetInfo(ctx context.Context) (*NodeGetInfoResponse, error) ...@@ -706,11 +744,8 @@ func (c *client) NodeGetInfo(ctx context.Context) (*NodeGetInfoResponse, error)
} }
func (c *client) NodeStageVolume(ctx context.Context, req *NodeStageVolumeRequest, opts ...grpc.CallOption) error { func (c *client) NodeStageVolume(ctx context.Context, req *NodeStageVolumeRequest, opts ...grpc.CallOption) error {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return fmt.Errorf("Client not initialized") return err
}
if c.nodeClient == nil {
return fmt.Errorf("Client not initialized")
} }
err := req.Validate() err := req.Validate()
if err != nil { if err != nil {
...@@ -741,11 +776,8 @@ func (c *client) NodeStageVolume(ctx context.Context, req *NodeStageVolumeReques ...@@ -741,11 +776,8 @@ func (c *client) NodeStageVolume(ctx context.Context, req *NodeStageVolumeReques
} }
func (c *client) NodeUnstageVolume(ctx context.Context, volumeID string, stagingTargetPath string, opts ...grpc.CallOption) error { func (c *client) NodeUnstageVolume(ctx context.Context, volumeID string, stagingTargetPath string, opts ...grpc.CallOption) error {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return fmt.Errorf("Client not initialized") return err
}
if c.nodeClient == nil {
return fmt.Errorf("Client not initialized")
} }
// These errors should not be returned during production use but exist as aids // These errors should not be returned during production use but exist as aids
// during Nomad development // during Nomad development
...@@ -779,13 +811,9 @@ func (c *client) NodeUnstageVolume(ctx context.Context, volumeID string, staging ...@@ -779,13 +811,9 @@ func (c *client) NodeUnstageVolume(ctx context.Context, volumeID string, staging
} }
func (c *client) NodePublishVolume(ctx context.Context, req *NodePublishVolumeRequest, opts ...grpc.CallOption) error { func (c *client) NodePublishVolume(ctx context.Context, req *NodePublishVolumeRequest, opts ...grpc.CallOption) error {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return fmt.Errorf("Client not initialized") return err
}
if c.nodeClient == nil {
return fmt.Errorf("Client not initialized")
} }
if err := req.Validate(); err != nil { if err := req.Validate(); err != nil {
return fmt.Errorf("validation error: %v", err) return fmt.Errorf("validation error: %v", err)
} }
...@@ -813,13 +841,9 @@ func (c *client) NodePublishVolume(ctx context.Context, req *NodePublishVolumeRe ...@@ -813,13 +841,9 @@ func (c *client) NodePublishVolume(ctx context.Context, req *NodePublishVolumeRe
} }
func (c *client) NodeUnpublishVolume(ctx context.Context, volumeID, targetPath string, opts ...grpc.CallOption) error { func (c *client) NodeUnpublishVolume(ctx context.Context, volumeID, targetPath string, opts ...grpc.CallOption) error {
if c == nil { if err := c.ensureConnected(ctx); err != nil {
return fmt.Errorf("Client not initialized") return err
}
if c.nodeClient == nil {
return fmt.Errorf("Client not initialized")
} }
// These errors should not be returned during production use but exist as aids // These errors should not be returned during production use but exist as aids
// during Nomad development // during Nomad development
if volumeID == "" { if volumeID == "" {
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"path/filepath"
"testing" "testing"
csipbv1 "github.com/container-storage-interface/spec/lib/go/csi" csipbv1 "github.com/container-storage-interface/spec/lib/go/csi"
...@@ -11,15 +12,26 @@ import ( ...@@ -11,15 +12,26 @@ import (
"github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs"
fake "github.com/hashicorp/nomad/plugins/csi/testing" fake "github.com/hashicorp/nomad/plugins/csi/testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
func newTestClient() (*fake.IdentityClient, *fake.ControllerClient, *fake.NodeClient, CSIPlugin) { func newTestClient(t *testing.T) (*fake.IdentityClient, *fake.ControllerClient, *fake.NodeClient, CSIPlugin) {
ic := fake.NewIdentityClient() ic := fake.NewIdentityClient()
cc := fake.NewControllerClient() cc := fake.NewControllerClient()
nc := fake.NewNodeClient() nc := fake.NewNodeClient()
// the dial is non-blocking (no grpc.WithBlock option), so it won't
// connect to the socket unless an RPC is invoked
conn, err := grpc.DialContext(context.Background(),
filepath.Join(t.TempDir(), "csi.sock"), grpc.WithInsecure())
if err != nil {
t.Errorf("failed: %v", err)
}
client := &client{ client := &client{
conn: conn,
identityClient: ic, identityClient: ic,
controllerClient: cc, controllerClient: cc,
nodeClient: nc, nodeClient: nc,
...@@ -69,7 +81,7 @@ func TestClient_RPC_PluginProbe(t *testing.T) { ...@@ -69,7 +81,7 @@ func TestClient_RPC_PluginProbe(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
ic, _, _, client := newTestClient() ic, _, _, client := newTestClient(t)
defer client.Close() defer client.Close()
ic.NextErr = tc.ResponseErr ic.NextErr = tc.ResponseErr
...@@ -121,7 +133,7 @@ func TestClient_RPC_PluginInfo(t *testing.T) { ...@@ -121,7 +133,7 @@ func TestClient_RPC_PluginInfo(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
ic, _, _, client := newTestClient() ic, _, _, client := newTestClient(t)
defer client.Close() defer client.Close()
ic.NextErr = tc.ResponseErr ic.NextErr = tc.ResponseErr
...@@ -186,7 +198,7 @@ func TestClient_RPC_PluginGetCapabilities(t *testing.T) { ...@@ -186,7 +198,7 @@ func TestClient_RPC_PluginGetCapabilities(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
ic, _, _, client := newTestClient() ic, _, _, client := newTestClient(t)
defer client.Close() defer client.Close()
ic.NextErr = tc.ResponseErr ic.NextErr = tc.ResponseErr
...@@ -284,7 +296,7 @@ func TestClient_RPC_ControllerGetCapabilities(t *testing.T) { ...@@ -284,7 +296,7 @@ func TestClient_RPC_ControllerGetCapabilities(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, cc, _, client := newTestClient() _, cc, _, client := newTestClient(t)
defer client.Close() defer client.Close()
cc.NextErr = tc.ResponseErr cc.NextErr = tc.ResponseErr
...@@ -342,7 +354,7 @@ func TestClient_RPC_NodeGetCapabilities(t *testing.T) { ...@@ -342,7 +354,7 @@ func TestClient_RPC_NodeGetCapabilities(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, _, nc, client := newTestClient() _, _, nc, client := newTestClient(t)
defer client.Close() defer client.Close()
nc.NextErr = tc.ResponseErr nc.NextErr = tc.ResponseErr
...@@ -407,7 +419,7 @@ func TestClient_RPC_ControllerPublishVolume(t *testing.T) { ...@@ -407,7 +419,7 @@ func TestClient_RPC_ControllerPublishVolume(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, cc, _, client := newTestClient() _, cc, _, client := newTestClient(t)
defer client.Close() defer client.Close()
cc.NextErr = tc.ResponseErr cc.NextErr = tc.ResponseErr
...@@ -453,7 +465,7 @@ func TestClient_RPC_ControllerUnpublishVolume(t *testing.T) { ...@@ -453,7 +465,7 @@ func TestClient_RPC_ControllerUnpublishVolume(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, cc, _, client := newTestClient() _, cc, _, client := newTestClient(t)
defer client.Close() defer client.Close()
cc.NextErr = tc.ResponseErr cc.NextErr = tc.ResponseErr
...@@ -661,7 +673,7 @@ func TestClient_RPC_ControllerValidateVolume(t *testing.T) { ...@@ -661,7 +673,7 @@ func TestClient_RPC_ControllerValidateVolume(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, cc, _, client := newTestClient() _, cc, _, client := newTestClient(t)
defer client.Close() defer client.Close()
requestedCaps := []*VolumeCapability{{ requestedCaps := []*VolumeCapability{{
...@@ -758,7 +770,7 @@ func TestClient_RPC_ControllerCreateVolume(t *testing.T) { ...@@ -758,7 +770,7 @@ func TestClient_RPC_ControllerCreateVolume(t *testing.T) {
} }
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, cc, _, client := newTestClient() _, cc, _, client := newTestClient(t)
defer client.Close() defer client.Close()
req := &ControllerCreateVolumeRequest{ req := &ControllerCreateVolumeRequest{
...@@ -828,7 +840,7 @@ func TestClient_RPC_ControllerDeleteVolume(t *testing.T) { ...@@ -828,7 +840,7 @@ func TestClient_RPC_ControllerDeleteVolume(t *testing.T) {
} }
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, cc, _, client := newTestClient() _, cc, _, client := newTestClient(t)
defer client.Close() defer client.Close()
cc.NextErr = tc.ResponseErr cc.NextErr = tc.ResponseErr
...@@ -871,7 +883,7 @@ func TestClient_RPC_ControllerListVolume(t *testing.T) { ...@@ -871,7 +883,7 @@ func TestClient_RPC_ControllerListVolume(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, cc, _, client := newTestClient() _, cc, _, client := newTestClient(t)
defer client.Close() defer client.Close()
cc.NextErr = tc.ResponseErr cc.NextErr = tc.ResponseErr
...@@ -979,7 +991,7 @@ func TestClient_RPC_ControllerCreateSnapshot(t *testing.T) { ...@@ -979,7 +991,7 @@ func TestClient_RPC_ControllerCreateSnapshot(t *testing.T) {
} }
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, cc, _, client := newTestClient() _, cc, _, client := newTestClient(t)
defer client.Close() defer client.Close()
cc.NextErr = tc.ResponseErr cc.NextErr = tc.ResponseErr
...@@ -1025,7 +1037,7 @@ func TestClient_RPC_ControllerDeleteSnapshot(t *testing.T) { ...@@ -1025,7 +1037,7 @@ func TestClient_RPC_ControllerDeleteSnapshot(t *testing.T) {
} }
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, cc, _, client := newTestClient() _, cc, _, client := newTestClient(t)
defer client.Close() defer client.Close()
cc.NextErr = tc.ResponseErr cc.NextErr = tc.ResponseErr
...@@ -1068,7 +1080,7 @@ func TestClient_RPC_ControllerListSnapshots(t *testing.T) { ...@@ -1068,7 +1080,7 @@ func TestClient_RPC_ControllerListSnapshots(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, cc, _, client := newTestClient() _, cc, _, client := newTestClient(t)
defer client.Close() defer client.Close()
cc.NextErr = tc.ResponseErr cc.NextErr = tc.ResponseErr
...@@ -1124,7 +1136,7 @@ func TestClient_RPC_NodeStageVolume(t *testing.T) { ...@@ -1124,7 +1136,7 @@ func TestClient_RPC_NodeStageVolume(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, _, nc, client := newTestClient() _, _, nc, client := newTestClient(t)
defer client.Close() defer client.Close()
nc.NextErr = tc.ResponseErr nc.NextErr = tc.ResponseErr
...@@ -1165,7 +1177,7 @@ func TestClient_RPC_NodeUnstageVolume(t *testing.T) { ...@@ -1165,7 +1177,7 @@ func TestClient_RPC_NodeUnstageVolume(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, _, nc, client := newTestClient() _, _, nc, client := newTestClient(t)
defer client.Close() defer client.Close()
nc.NextErr = tc.ResponseErr nc.NextErr = tc.ResponseErr
...@@ -1221,7 +1233,7 @@ func TestClient_RPC_NodePublishVolume(t *testing.T) { ...@@ -1221,7 +1233,7 @@ func TestClient_RPC_NodePublishVolume(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, _, nc, client := newTestClient() _, _, nc, client := newTestClient(t)
defer client.Close() defer client.Close()
nc.NextErr = tc.ResponseErr nc.NextErr = tc.ResponseErr
...@@ -1274,7 +1286,7 @@ func TestClient_RPC_NodeUnpublishVolume(t *testing.T) { ...@@ -1274,7 +1286,7 @@ func TestClient_RPC_NodeUnpublishVolume(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
_, _, nc, client := newTestClient() _, _, nc, client := newTestClient(t)
defer client.Close() defer client.Close()
nc.NextErr = tc.ResponseErr nc.NextErr = tc.ResponseErr
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment