diff --git a/.changelog/12057.txt b/.changelog/12057.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8e74a8fe045be14c4471961125f7b9b7185ddca6
--- /dev/null
+++ b/.changelog/12057.txt
@@ -0,0 +1,7 @@
+```release-note:bug
+csi: Fixed a bug where the plugin instance manager would not retry the initial gRPC connection to plugins
+```
+
+```release-note:bug
+csi: Fixed a bug where the plugin supervisor would not restart the task if it failed to connect to the plugin
+```
diff --git a/client/allocrunner/taskrunner/plugin_supervisor_hook.go b/client/allocrunner/taskrunner/plugin_supervisor_hook.go
index 9adb8d53bf509c487c2c7f22fb1cf23490f7287a..679fb2f73967e66390ba461163911468c5ee22f3 100644
--- a/client/allocrunner/taskrunner/plugin_supervisor_hook.go
+++ b/client/allocrunner/taskrunner/plugin_supervisor_hook.go
@@ -38,6 +38,7 @@ type csiPluginSupervisorHook struct {
 
 	// eventEmitter is used to emit events to the task
 	eventEmitter ti.EventEmitter
+	lifecycle    ti.TaskLifecycle
 
 	shutdownCtx      context.Context
 	shutdownCancelFn context.CancelFunc
@@ -54,6 +55,7 @@ type csiPluginSupervisorHookConfig struct {
 	clientStateDirPath string
 	events             ti.EventEmitter
 	runner             *TaskRunner
+	lifecycle          ti.TaskLifecycle
 	capabilities       *drivers.Capabilities
 	logger             hclog.Logger
 }
@@ -90,6 +92,7 @@ func newCSIPluginSupervisorHook(config *csiPluginSupervisorHookConfig) *csiPlugi
 	hook := &csiPluginSupervisorHook{
 		alloc:            config.runner.Alloc(),
 		runner:           config.runner,
+		lifecycle:        config.lifecycle,
 		logger:           config.logger,
 		task:             task,
 		mountPoint:       pluginRoot,
@@ -201,27 +204,41 @@ func (h *csiPluginSupervisorHook) ensureSupervisorLoop(ctx context.Context) {
 	}()
 
 	socketPath := filepath.Join(h.mountPoint, structs.CSISocketName)
+
+	client := csi.NewClient(socketPath, h.logger.Named("csi_client").With(
+		"plugin.name", h.task.CSIPluginConfig.ID,
+		"plugin.type", h.task.CSIPluginConfig.Type))
+	defer client.Close()
+
 	t := time.NewTimer(0)
 
+	// We're in Poststart at this point, so if we can't connect within
+	// this deadline, assume it's broken so we can restart the task
+	startCtx, startCancelFn := context.WithTimeout(ctx, 30*time.Second)
+	defer startCancelFn()
+
+	var err error
+	var pluginHealthy bool
+
 	// Step 1: Wait for the plugin to initially become available.
 WAITFORREADY:
 	for {
 		select {
-		case <-ctx.Done():
+		case <-startCtx.Done():
+			h.kill(ctx, fmt.Errorf("CSI plugin failed probe: %v", err))
 			return
 		case <-t.C:
-			pluginHealthy, err := h.supervisorLoopOnce(ctx, socketPath)
+			pluginHealthy, err = h.supervisorLoopOnce(startCtx, client)
 			if err != nil || !pluginHealthy {
-				h.logger.Debug("CSI Plugin not ready", "error", err)
-
-				// Plugin is not yet returning healthy, because we want to optimise for
-				// quickly bringing a plugin online, we use a short timeout here.
-				// TODO(dani): Test with more plugins and adjust.
+				h.logger.Debug("CSI plugin not ready", "error", err)
+				// Use only a short delay here to optimize for quickly
+				// bringing up a plugin
 				t.Reset(5 * time.Second)
 				continue
 			}
 
 			// Mark the plugin as healthy in a task event
+			h.logger.Debug("CSI plugin is ready")
 			h.previousHealthState = pluginHealthy
 			event := structs.NewTaskEvent(structs.TaskPluginHealthy)
 			event.SetMessage(fmt.Sprintf("plugin: %s", h.task.CSIPluginConfig.ID))
@@ -232,15 +249,14 @@ WAITFORREADY:
 	}
 
 	// Step 2: Register the plugin with the catalog.
-	deregisterPluginFn, err := h.registerPlugin(socketPath)
+	deregisterPluginFn, err := h.registerPlugin(client, socketPath)
 	if err != nil {
-		h.logger.Error("CSI Plugin registration failed", "error", err)
-		event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
-		event.SetMessage(fmt.Sprintf("failed to register plugin: %s, reason: %v", h.task.CSIPluginConfig.ID, err))
-		h.eventEmitter.EmitEvent(event)
+		h.kill(ctx, fmt.Errorf("CSI plugin failed to register: %v", err))
+		return
 	}
 
-	// Step 3: Start the lightweight supervisor loop.
+	// Step 3: Start the lightweight supervisor loop. At this point, failures
+	// don't cause the task to restart
 	t.Reset(0)
 	for {
 		select {
@@ -249,9 +265,9 @@ WAITFORREADY:
 			deregisterPluginFn()
 			return
 		case <-t.C:
-			pluginHealthy, err := h.supervisorLoopOnce(ctx, socketPath)
+			pluginHealthy, err := h.supervisorLoopOnce(ctx, client)
 			if err != nil {
-				h.logger.Error("CSI Plugin fingerprinting failed", "error", err)
+				h.logger.Error("CSI plugin fingerprinting failed", "error", err)
 			}
 
 			// The plugin has transitioned to a healthy state. Emit an event.
@@ -265,7 +281,7 @@ WAITFORREADY:
 			if h.previousHealthState && !pluginHealthy {
 				event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
 				if err != nil {
-					event.SetMessage(fmt.Sprintf("error: %v", err))
+					event.SetMessage(fmt.Sprintf("Error: %v", err))
 				} else {
 					event.SetMessage("Unknown Reason")
 				}
@@ -281,16 +297,9 @@ WAITFORREADY:
 	}
 }
 
-func (h *csiPluginSupervisorHook) registerPlugin(socketPath string) (func(), error) {
-
+func (h *csiPluginSupervisorHook) registerPlugin(client csi.CSIPlugin, socketPath string) (func(), error) {
 	// At this point we know the plugin is ready and we can fingerprint it
 	// to get its vendor name and version
-	client, err := csi.NewClient(socketPath, h.logger.Named("csi_client").With("plugin.name", h.task.CSIPluginConfig.ID, "plugin.type", h.task.CSIPluginConfig.Type))
-	if err != nil {
-		return nil, fmt.Errorf("failed to create csi client: %v", err)
-	}
-	defer client.Close()
-
 	info, err := client.PluginInfo()
 	if err != nil {
 		return nil, fmt.Errorf("failed to probe plugin: %v", err)
@@ -354,21 +363,13 @@ func (h *csiPluginSupervisorHook) registerPlugin(socketPath string) (func(), err
 	}, nil
 }
 
-func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, socketPath string) (bool, error) {
-	_, err := os.Stat(socketPath)
-	if err != nil {
-		return false, fmt.Errorf("failed to stat socket: %v", err)
-	}
+func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, client csi.CSIPlugin) (bool, error) {
+	probeCtx, probeCancelFn := context.WithTimeout(ctx, 5*time.Second)
+	defer probeCancelFn()
 
-	client, err := csi.NewClient(socketPath, h.logger.Named("csi_client").With("plugin.name", h.task.CSIPluginConfig.ID, "plugin.type", h.task.CSIPluginConfig.Type))
+	healthy, err := client.PluginProbe(probeCtx)
 	if err != nil {
-		return false, fmt.Errorf("failed to create csi client: %v", err)
-	}
-	defer client.Close()
-
-	healthy, err := client.PluginProbe(ctx)
-	if err != nil {
-		return false, fmt.Errorf("failed to probe plugin: %v", err)
+		return false, err
 	}
 
 	return healthy, nil
@@ -387,6 +388,21 @@ func (h *csiPluginSupervisorHook) Stop(_ context.Context, req *interfaces.TaskSt
 	return nil
 }
 
+// kill emits a TaskPluginUnhealthy event for reason, then kills the task with
+// a TaskKilling event that fails it; an error from Kill itself is only logged.
+func (h *csiPluginSupervisorHook) kill(ctx context.Context, reason error) {
+	h.logger.Error("killing task because plugin failed", "error", reason)
+	event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
+	event.SetMessage(fmt.Sprintf("Error: %v", reason))
+	h.eventEmitter.EmitEvent(event)
+	// kill is also used when registration fails, so report the actual reason
+	killEvent := structs.NewTaskEvent(structs.TaskKilling).SetFailsTask().
+		SetDisplayMessage(fmt.Sprintf("Killing task: %v", reason))
+	if err := h.lifecycle.Kill(ctx, killEvent); err != nil {
+		h.logger.Error("failed to kill task", "kill_reason", reason, "error", err)
+	}
+}
+
 func ensureMountpointInserted(mounts []*drivers.MountConfig, mount *drivers.MountConfig) []*drivers.MountConfig {
 	for _, mnt := range mounts {
 		if mnt.IsEqual(mount) {
diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go
index b3f44a8b9855e02e14592500c18d18e79f86d6cc..62ff26c4b62566f8b7b9d33246610384b01d85d4 100644
--- a/client/allocrunner/taskrunner/task_runner_hooks.go
+++ b/client/allocrunner/taskrunner/task_runner_hooks.go
@@ -76,6 +76,7 @@ func (tr *TaskRunner) initHooks() {
 				clientStateDirPath: tr.clientConfig.StateDir,
 				events:             tr,
 				runner:             tr,
+				lifecycle:          tr,
 				capabilities:       tr.driverCapabilities,
 				logger:             hookLogger,
 			}))
diff --git a/client/client.go b/client/client.go
index 34428464aeaf58b544d8d3aa204d96dc49bef382..dd21b66c0c29034017e4c55eb2f24195bff048c7 100644
--- a/client/client.go
+++ b/client/client.go
@@ -390,11 +390,11 @@ func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxie
 	c.dynamicRegistry =
 		dynamicplugins.NewRegistry(c.stateDB, map[string]dynamicplugins.PluginDispenser{
 			dynamicplugins.PluginTypeCSIController: func(info *dynamicplugins.PluginInfo) (interface{}, error) {
-				return csi.NewClient(info.ConnectionInfo.SocketPath, logger.Named("csi_client").With("plugin.name", info.Name, "plugin.type", "controller"))
+				return csi.NewClient(info.ConnectionInfo.SocketPath, logger.Named("csi_client").With("plugin.name", info.Name, "plugin.type", "controller")), nil
 			},
 			dynamicplugins.PluginTypeCSINode: func(info *dynamicplugins.PluginInfo) (interface{}, error) {
-				return csi.NewClient(info.ConnectionInfo.SocketPath, logger.Named("csi_client").With("plugin.name", info.Name, "plugin.type", "client"))
-			}, // TODO(tgross): refactor these dispenser constructors into csimanager to tidy it up
+				return csi.NewClient(info.ConnectionInfo.SocketPath, logger.Named("csi_client").With("plugin.name", info.Name, "plugin.type", "client")), nil
+			},
 		})
 
 	// Setup the clients RPC server
diff --git a/client/pluginmanager/csimanager/instance.go b/client/pluginmanager/csimanager/instance.go
index 062b7397278381327f8afaa0a04a51d6bb2de048..3839113dcd4e374bc7a7cf6f57aef1fab5883aed 100644
--- a/client/pluginmanager/csimanager/instance.go
+++ b/client/pluginmanager/csimanager/instance.go
@@ -73,12 +73,7 @@ func newInstanceManager(logger hclog.Logger, eventer TriggerNodeEvent, updater U
 }
 
 func (i *instanceManager) run() {
-	c, err := csi.NewClient(i.info.ConnectionInfo.SocketPath, i.logger)
-	if err != nil {
-		i.logger.Error("failed to setup instance manager client", "error", err)
-		close(i.shutdownCh)
-		return
-	}
+	c := csi.NewClient(i.info.ConnectionInfo.SocketPath, i.logger)
 	i.client = c
 	i.fp.client = c
 
diff --git a/plugins/csi/client.go b/plugins/csi/client.go
index 89237dbbf5a366f91e22ed03697a8f76dcb09b99..6d34d8f55aa235d2fcd2b634caa4b9c2e2a6c17a 100644
--- a/plugins/csi/client.go
+++ b/plugins/csi/client.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"math"
 	"net"
+	"os"
 	"time"
 
 	csipbv1 "github.com/container-storage-interface/spec/lib/go/csi"
@@ -88,6 +89,7 @@ type CSINodeClient interface {
 }
 
 type client struct {
+	addr             string
 	conn             *grpc.ClientConn
 	identityClient   csipbv1.IdentityClient
 	controllerClient CSIControllerClient
@@ -102,30 +104,59 @@ func (c *client) Close() error {
 	return nil
 }
 
-func NewClient(addr string, logger hclog.Logger) (CSIPlugin, error) {
-	if addr == "" {
-		return nil, fmt.Errorf("address is empty")
+// NewClient returns a CSIPlugin for addr without dialing the socket yet.
+func NewClient(addr string, logger hclog.Logger) CSIPlugin {
+	return &client{
+		logger: logger, addr: addr,
 	}
+}
 
-	conn, err := newGrpcConn(addr, logger)
-	if err != nil {
-		return nil, err
+// ensureConnected dials the socket on first use, retrying until ctx is done.
+func (c *client) ensureConnected(ctx context.Context) error {
+	if c == nil {
+		return fmt.Errorf("client not initialized")
+	}
+	if c.conn != nil {
+		return nil
+	}
+	if c.addr == "" {
+		return fmt.Errorf("address is empty")
+	}
+	var err error
+	t := time.NewTimer(0)
+	defer t.Stop() // release a pending retry timer when we return early
+	for {
+		select {
+		case <-ctx.Done():
+			return fmt.Errorf("timeout while connecting to gRPC socket: %v", err)
+		case <-t.C:
+			if _, err = os.Stat(c.addr); err != nil {
+				err = fmt.Errorf("failed to stat socket: %v", err)
+				t.Reset(5 * time.Second)
+				continue
+			}
+			conn, dialErr := newGrpcConn(c.addr, c.logger)
+			if dialErr != nil {
+				err = fmt.Errorf("failed to create gRPC connection: %v", dialErr)
+				t.Reset(5 * time.Second)
+				continue
+			}
+			c.conn = conn
+			c.identityClient = csipbv1.NewIdentityClient(conn)
+			c.controllerClient = csipbv1.NewControllerClient(conn)
+			c.nodeClient = csipbv1.NewNodeClient(conn)
+			return nil
+		}
 	}
-
-	return &client{
-		conn:             conn,
-		identityClient:   csipbv1.NewIdentityClient(conn),
-		controllerClient: csipbv1.NewControllerClient(conn),
-		nodeClient:       csipbv1.NewNodeClient(conn),
-		logger:           logger,
-	}, nil
 }
 
 func newGrpcConn(addr string, logger hclog.Logger) (*grpc.ClientConn, error) {
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
+	// once DialContext has returned with the initial connection,
+	// cancelling this context is a no-op
+	connectCtx, cancel := context.WithTimeout(context.Background(), time.Second*1)
 	defer cancel()
 	conn, err := grpc.DialContext(
-		ctx,
+		connectCtx,
 		addr,
 		grpc.WithBlock(),
 		grpc.WithInsecure(),
@@ -146,10 +177,14 @@ func newGrpcConn(addr string, logger hclog.Logger) (*grpc.ClientConn, error) {
 // PluginInfo describes the type and version of a plugin as required by the nomad
 // base.BasePlugin interface.
 func (c *client) PluginInfo() (*base.PluginInfoResponse, error) {
-	// note: no grpc retries needed here, as this is called in
-	// fingerprinting and will get retried by the caller.
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
 	defer cancel()
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
+	}
+
+	// note: no grpc retries needed here, as this is called in
+	// fingerprinting and will get retried by the caller.
 	name, version, err := c.PluginGetInfo(ctx)
 	if err != nil {
 		return nil, err
@@ -176,6 +211,10 @@ func (c *client) SetConfig(_ *base.Config) error {
 }
 
 func (c *client) PluginProbe(ctx context.Context) (bool, error) {
+	if err := c.ensureConnected(ctx); err != nil {
+		return false, err
+	}
+
 	// note: no grpc retries should be done here
 	req, err := c.identityClient.Probe(ctx, &csipbv1.ProbeRequest{})
 	if err != nil {
@@ -198,11 +237,8 @@ func (c *client) PluginProbe(ctx context.Context) (bool, error) {
 }
 
 func (c *client) PluginGetInfo(ctx context.Context) (string, string, error) {
-	if c == nil {
-		return "", "", fmt.Errorf("Client not initialized")
-	}
-	if c.identityClient == nil {
-		return "", "", fmt.Errorf("Client not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return "", "", err
 	}
 
 	resp, err := c.identityClient.GetPluginInfo(ctx, &csipbv1.GetPluginInfoRequest{})
@@ -220,11 +256,8 @@ func (c *client) PluginGetInfo(ctx context.Context) (string, string, error) {
 }
 
 func (c *client) PluginGetCapabilities(ctx context.Context) (*PluginCapabilitySet, error) {
-	if c == nil {
-		return nil, fmt.Errorf("Client not initialized")
-	}
-	if c.identityClient == nil {
-		return nil, fmt.Errorf("Client not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
 	}
 
 	// note: no grpc retries needed here, as this is called in
@@ -243,11 +276,8 @@ func (c *client) PluginGetCapabilities(ctx context.Context) (*PluginCapabilitySe
 //
 
 func (c *client) ControllerGetCapabilities(ctx context.Context) (*ControllerCapabilitySet, error) {
-	if c == nil {
-		return nil, fmt.Errorf("Client not initialized")
-	}
-	if c.controllerClient == nil {
-		return nil, fmt.Errorf("controllerClient not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
 	}
 
 	// note: no grpc retries needed here, as this is called in
@@ -262,11 +292,8 @@ func (c *client) ControllerGetCapabilities(ctx context.Context) (*ControllerCapa
 }
 
 func (c *client) ControllerPublishVolume(ctx context.Context, req *ControllerPublishVolumeRequest, opts ...grpc.CallOption) (*ControllerPublishVolumeResponse, error) {
-	if c == nil {
-		return nil, fmt.Errorf("Client not initialized")
-	}
-	if c.controllerClient == nil {
-		return nil, fmt.Errorf("controllerClient not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
 	}
 
 	err := req.Validate()
@@ -304,11 +331,8 @@ func (c *client) ControllerPublishVolume(ctx context.Context, req *ControllerPub
 }
 
 func (c *client) ControllerUnpublishVolume(ctx context.Context, req *ControllerUnpublishVolumeRequest, opts ...grpc.CallOption) (*ControllerUnpublishVolumeResponse, error) {
-	if c == nil {
-		return nil, fmt.Errorf("Client not initialized")
-	}
-	if c.controllerClient == nil {
-		return nil, fmt.Errorf("controllerClient not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
 	}
 	err := req.Validate()
 	if err != nil {
@@ -337,13 +361,9 @@ func (c *client) ControllerUnpublishVolume(ctx context.Context, req *ControllerU
 }
 
 func (c *client) ControllerValidateCapabilities(ctx context.Context, req *ControllerValidateVolumeRequest, opts ...grpc.CallOption) error {
-	if c == nil {
-		return fmt.Errorf("Client not initialized")
-	}
-	if c.controllerClient == nil {
-		return fmt.Errorf("controllerClient not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return err
 	}
-
 	if req.ExternalID == "" {
 		return fmt.Errorf("missing volume ID")
 	}
@@ -390,6 +410,10 @@ func (c *client) ControllerValidateCapabilities(ctx context.Context, req *Contro
 }
 
 func (c *client) ControllerCreateVolume(ctx context.Context, req *ControllerCreateVolumeRequest, opts ...grpc.CallOption) (*ControllerCreateVolumeResponse, error) {
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
+	}
+
 	err := req.Validate()
 	if err != nil {
 		return nil, err
@@ -433,6 +457,10 @@ func (c *client) ControllerCreateVolume(ctx context.Context, req *ControllerCrea
 }
 
 func (c *client) ControllerListVolumes(ctx context.Context, req *ControllerListVolumesRequest, opts ...grpc.CallOption) (*ControllerListVolumesResponse, error) {
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
+	}
+
 	err := req.Validate()
 	if err != nil {
 		return nil, err
@@ -455,6 +483,10 @@ func (c *client) ControllerListVolumes(ctx context.Context, req *ControllerListV
 }
 
 func (c *client) ControllerDeleteVolume(ctx context.Context, req *ControllerDeleteVolumeRequest, opts ...grpc.CallOption) error {
+	if err := c.ensureConnected(ctx); err != nil {
+		return err
+	}
+
 	err := req.Validate()
 	if err != nil {
 		return err
@@ -552,6 +584,10 @@ NEXT_CAP:
 }
 
 func (c *client) ControllerCreateSnapshot(ctx context.Context, req *ControllerCreateSnapshotRequest, opts ...grpc.CallOption) (*ControllerCreateSnapshotResponse, error) {
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
+	}
+
 	err := req.Validate()
 	if err != nil {
 		return nil, err
@@ -596,6 +632,10 @@ func (c *client) ControllerCreateSnapshot(ctx context.Context, req *ControllerCr
 }
 
 func (c *client) ControllerDeleteSnapshot(ctx context.Context, req *ControllerDeleteSnapshotRequest, opts ...grpc.CallOption) error {
+	if err := c.ensureConnected(ctx); err != nil {
+		return err
+	}
+
 	err := req.Validate()
 	if err != nil {
 		return err
@@ -626,6 +666,10 @@ func (c *client) ControllerDeleteSnapshot(ctx context.Context, req *ControllerDe
 }
 
 func (c *client) ControllerListSnapshots(ctx context.Context, req *ControllerListSnapshotsRequest, opts ...grpc.CallOption) (*ControllerListSnapshotsResponse, error) {
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
+	}
+
 	err := req.Validate()
 	if err != nil {
 		return nil, err
@@ -657,11 +701,8 @@ func (c *client) ControllerListSnapshots(ctx context.Context, req *ControllerLis
 //
 
 func (c *client) NodeGetCapabilities(ctx context.Context) (*NodeCapabilitySet, error) {
-	if c == nil {
-		return nil, fmt.Errorf("Client not initialized")
-	}
-	if c.nodeClient == nil {
-		return nil, fmt.Errorf("Client not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
 	}
 
 	// note: no grpc retries needed here, as this is called in
@@ -675,11 +716,8 @@ func (c *client) NodeGetCapabilities(ctx context.Context) (*NodeCapabilitySet, e
 }
 
 func (c *client) NodeGetInfo(ctx context.Context) (*NodeGetInfoResponse, error) {
-	if c == nil {
-		return nil, fmt.Errorf("Client not initialized")
-	}
-	if c.nodeClient == nil {
-		return nil, fmt.Errorf("Client not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return nil, err
 	}
 
 	result := &NodeGetInfoResponse{}
@@ -706,11 +744,8 @@ func (c *client) NodeGetInfo(ctx context.Context) (*NodeGetInfoResponse, error)
 }
 
 func (c *client) NodeStageVolume(ctx context.Context, req *NodeStageVolumeRequest, opts ...grpc.CallOption) error {
-	if c == nil {
-		return fmt.Errorf("Client not initialized")
-	}
-	if c.nodeClient == nil {
-		return fmt.Errorf("Client not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return err
 	}
 	err := req.Validate()
 	if err != nil {
@@ -741,11 +776,8 @@ func (c *client) NodeStageVolume(ctx context.Context, req *NodeStageVolumeReques
 }
 
 func (c *client) NodeUnstageVolume(ctx context.Context, volumeID string, stagingTargetPath string, opts ...grpc.CallOption) error {
-	if c == nil {
-		return fmt.Errorf("Client not initialized")
-	}
-	if c.nodeClient == nil {
-		return fmt.Errorf("Client not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return err
 	}
 	// These errors should not be returned during production use but exist as aids
 	// during Nomad development
@@ -779,13 +811,9 @@ func (c *client) NodeUnstageVolume(ctx context.Context, volumeID string, staging
 }
 
 func (c *client) NodePublishVolume(ctx context.Context, req *NodePublishVolumeRequest, opts ...grpc.CallOption) error {
-	if c == nil {
-		return fmt.Errorf("Client not initialized")
-	}
-	if c.nodeClient == nil {
-		return fmt.Errorf("Client not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return err
 	}
-
 	if err := req.Validate(); err != nil {
 		return fmt.Errorf("validation error: %v", err)
 	}
@@ -813,13 +841,9 @@ func (c *client) NodePublishVolume(ctx context.Context, req *NodePublishVolumeRe
 }
 
 func (c *client) NodeUnpublishVolume(ctx context.Context, volumeID, targetPath string, opts ...grpc.CallOption) error {
-	if c == nil {
-		return fmt.Errorf("Client not initialized")
-	}
-	if c.nodeClient == nil {
-		return fmt.Errorf("Client not initialized")
+	if err := c.ensureConnected(ctx); err != nil {
+		return err
 	}
-
 	// These errors should not be returned during production use but exist as aids
 	// during Nomad development
 	if volumeID == "" {
diff --git a/plugins/csi/client_test.go b/plugins/csi/client_test.go
index 40da479b7dc368f93cc8386c3012c687c88be8f5..1c951b2bac961846781c0321ad142d519ef48371 100644
--- a/plugins/csi/client_test.go
+++ b/plugins/csi/client_test.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"path/filepath"
 	"testing"
 
 	csipbv1 "github.com/container-storage-interface/spec/lib/go/csi"
@@ -11,15 +12,26 @@ import (
 	"github.com/hashicorp/nomad/nomad/structs"
 	fake "github.com/hashicorp/nomad/plugins/csi/testing"
 	"github.com/stretchr/testify/require"
+	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
 )
 
-func newTestClient() (*fake.IdentityClient, *fake.ControllerClient, *fake.NodeClient, CSIPlugin) {
+func newTestClient(t *testing.T) (*fake.IdentityClient, *fake.ControllerClient, *fake.NodeClient, CSIPlugin) {
 	ic := fake.NewIdentityClient()
 	cc := fake.NewControllerClient()
 	nc := fake.NewNodeClient()
+
+	// we've set this as non-blocking so it won't connect to the
+	// socket unless an RPC is invoked
+	conn, err := grpc.DialContext(context.Background(),
+		filepath.Join(t.TempDir(), "csi.sock"), grpc.WithInsecure())
+	if err != nil {
+		t.Errorf("failed: %v", err)
+	}
+
 	client := &client{
+		conn:             conn,
 		identityClient:   ic,
 		controllerClient: cc,
 		nodeClient:       nc,
@@ -69,7 +81,7 @@ func TestClient_RPC_PluginProbe(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			ic, _, _, client := newTestClient()
+			ic, _, _, client := newTestClient(t)
 			defer client.Close()
 
 			ic.NextErr = tc.ResponseErr
@@ -121,7 +133,7 @@ func TestClient_RPC_PluginInfo(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			ic, _, _, client := newTestClient()
+			ic, _, _, client := newTestClient(t)
 			defer client.Close()
 
 			ic.NextErr = tc.ResponseErr
@@ -186,7 +198,7 @@ func TestClient_RPC_PluginGetCapabilities(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			ic, _, _, client := newTestClient()
+			ic, _, _, client := newTestClient(t)
 			defer client.Close()
 
 			ic.NextErr = tc.ResponseErr
@@ -284,7 +296,7 @@ func TestClient_RPC_ControllerGetCapabilities(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, cc, _, client := newTestClient()
+			_, cc, _, client := newTestClient(t)
 			defer client.Close()
 
 			cc.NextErr = tc.ResponseErr
@@ -342,7 +354,7 @@ func TestClient_RPC_NodeGetCapabilities(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, _, nc, client := newTestClient()
+			_, _, nc, client := newTestClient(t)
 			defer client.Close()
 
 			nc.NextErr = tc.ResponseErr
@@ -407,7 +419,7 @@ func TestClient_RPC_ControllerPublishVolume(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, cc, _, client := newTestClient()
+			_, cc, _, client := newTestClient(t)
 			defer client.Close()
 
 			cc.NextErr = tc.ResponseErr
@@ -453,7 +465,7 @@ func TestClient_RPC_ControllerUnpublishVolume(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, cc, _, client := newTestClient()
+			_, cc, _, client := newTestClient(t)
 			defer client.Close()
 
 			cc.NextErr = tc.ResponseErr
@@ -661,7 +673,7 @@ func TestClient_RPC_ControllerValidateVolume(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, cc, _, client := newTestClient()
+			_, cc, _, client := newTestClient(t)
 			defer client.Close()
 
 			requestedCaps := []*VolumeCapability{{
@@ -758,7 +770,7 @@ func TestClient_RPC_ControllerCreateVolume(t *testing.T) {
 	}
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, cc, _, client := newTestClient()
+			_, cc, _, client := newTestClient(t)
 			defer client.Close()
 
 			req := &ControllerCreateVolumeRequest{
@@ -828,7 +840,7 @@ func TestClient_RPC_ControllerDeleteVolume(t *testing.T) {
 	}
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, cc, _, client := newTestClient()
+			_, cc, _, client := newTestClient(t)
 			defer client.Close()
 
 			cc.NextErr = tc.ResponseErr
@@ -871,7 +883,7 @@ func TestClient_RPC_ControllerListVolume(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, cc, _, client := newTestClient()
+			_, cc, _, client := newTestClient(t)
 			defer client.Close()
 
 			cc.NextErr = tc.ResponseErr
@@ -979,7 +991,7 @@ func TestClient_RPC_ControllerCreateSnapshot(t *testing.T) {
 	}
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, cc, _, client := newTestClient()
+			_, cc, _, client := newTestClient(t)
 			defer client.Close()
 
 			cc.NextErr = tc.ResponseErr
@@ -1025,7 +1037,7 @@ func TestClient_RPC_ControllerDeleteSnapshot(t *testing.T) {
 	}
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, cc, _, client := newTestClient()
+			_, cc, _, client := newTestClient(t)
 			defer client.Close()
 
 			cc.NextErr = tc.ResponseErr
@@ -1068,7 +1080,7 @@ func TestClient_RPC_ControllerListSnapshots(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, cc, _, client := newTestClient()
+			_, cc, _, client := newTestClient(t)
 			defer client.Close()
 
 			cc.NextErr = tc.ResponseErr
@@ -1124,7 +1136,7 @@ func TestClient_RPC_NodeStageVolume(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, _, nc, client := newTestClient()
+			_, _, nc, client := newTestClient(t)
 			defer client.Close()
 
 			nc.NextErr = tc.ResponseErr
@@ -1165,7 +1177,7 @@ func TestClient_RPC_NodeUnstageVolume(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, _, nc, client := newTestClient()
+			_, _, nc, client := newTestClient(t)
 			defer client.Close()
 
 			nc.NextErr = tc.ResponseErr
@@ -1221,7 +1233,7 @@ func TestClient_RPC_NodePublishVolume(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, _, nc, client := newTestClient()
+			_, _, nc, client := newTestClient(t)
 			defer client.Close()
 
 			nc.NextErr = tc.ResponseErr
@@ -1274,7 +1286,7 @@ func TestClient_RPC_NodeUnpublishVolume(t *testing.T) {
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			_, _, nc, client := newTestClient()
+			_, _, nc, client := newTestClient(t)
 			defer client.Close()
 
 			nc.NextErr = tc.ResponseErr