From 179b32d426025f9cb94f57488634ee4e3d9699a3 Mon Sep 17 00:00:00 2001
From: Sander van Harmelen <sander@xanzy.io>
Date: Thu, 5 Jul 2018 21:28:29 +0200
Subject: [PATCH] Add a `CredentialsForHost` method to disco.Disco

By adding this method you now only have to pass a `*disco.Disco` object around in order to do discovery and use any configured credentials for the discovered hosts.

Of course you can also still pass around both a `*disco.Disco` and a `auth.CredentialsSource` object if there is a need or a reason for that!
---
 command/command_test.go           |  2 +-
 command/init.go                   |  2 +-
 command/meta.go                   |  7 +------
 commands.go                       |  8 ++-----
 config/module/module_test.go      |  2 +-
 config/module/storage.go          | 13 ++++--------
 config/module/storage_test.go     |  4 ++--
 configs/configload/loader.go      |  8 +------
 configs/configload/module_mgr.go  |  4 ----
 main.go                           |  5 ++++-
 registry/client.go                | 18 +++-------------
 registry/client_test.go           | 35 ++++++++++++++++---------------
 registry/test/mock_registry.go    |  6 +++---
 svchost/auth/credentials.go       |  3 +++
 svchost/auth/token_credentials.go |  5 +++++
 svchost/disco/disco.go            | 34 +++++++++++++++++++-----------
 svchost/disco/disco_test.go       | 18 ++++++++--------
 17 files changed, 80 insertions(+), 94 deletions(-)

diff --git a/command/command_test.go b/command/command_test.go
index 12d48761b7..c0a8529c6e 100644
--- a/command/command_test.go
+++ b/command/command_test.go
@@ -117,7 +117,7 @@ func testModule(t *testing.T, name string) *module.Tree {
 		t.Fatalf("err: %s", err)
 	}
 
-	s := module.NewStorage(tempDir(t), nil, nil)
+	s := module.NewStorage(tempDir(t), nil)
 	s.Mode = module.GetModeGet
 	if err := mod.Load(s); err != nil {
 		t.Fatalf("err: %s", err)
diff --git a/command/init.go b/command/init.go
index efa4b5724d..b96cdc6eca 100644
--- a/command/init.go
+++ b/command/init.go
@@ -129,7 +129,7 @@ func (c *InitCommand) Run(args []string) int {
 		)))
 		header = true
 
-		s := module.NewStorage("", c.Services, c.Credentials)
+		s := module.NewStorage("", c.Services)
 		if err := s.GetModule(path, src); err != nil {
 			c.Ui.Error(fmt.Sprintf("Error copying source module: %s", err))
 			return 1
diff --git a/command/meta.go b/command/meta.go
index 91f1008fed..f154f2d6c8 100644
--- a/command/meta.go
+++ b/command/meta.go
@@ -25,7 +25,6 @@ import (
 	"github.com/hashicorp/terraform/helper/experiment"
 	"github.com/hashicorp/terraform/helper/variables"
 	"github.com/hashicorp/terraform/helper/wrappedstreams"
-	"github.com/hashicorp/terraform/svchost/auth"
 	"github.com/hashicorp/terraform/svchost/disco"
 	"github.com/hashicorp/terraform/terraform"
 	"github.com/hashicorp/terraform/tfdiags"
@@ -51,10 +50,6 @@ type Meta struct {
 	// "terraform-native' services running at a specific user-facing hostname.
 	Services *disco.Disco
 
-	// Credentials provides access to credentials for "terraform-native"
-	// services, which are accessed by a service hostname.
-	Credentials auth.CredentialsSource
-
 	// RunningInAutomation indicates that commands are being run by an
 	// automated system rather than directly at a command prompt.
 	//
@@ -410,7 +405,7 @@ func (m *Meta) flagSet(n string) *flag.FlagSet {
 // moduleStorage returns the module.Storage implementation used to store
 // modules for commands.
 func (m *Meta) moduleStorage(root string, mode module.GetMode) *module.Storage {
-	s := module.NewStorage(filepath.Join(root, "modules"), m.Services, m.Credentials)
+	s := module.NewStorage(filepath.Join(root, "modules"), m.Services)
 	s.Ui = m.Ui
 	s.Mode = mode
 	return s
diff --git a/commands.go b/commands.go
index 3335d2cdb5..113c771eb7 100644
--- a/commands.go
+++ b/commands.go
@@ -30,15 +30,12 @@ const (
 	OutputPrefix = "o:"
 )
 
-func initCommands(config *Config) {
+func initCommands(config *Config, services *disco.Disco) {
 	var inAutomation bool
 	if v := os.Getenv(runningInAutomationEnvName); v != "" {
 		inAutomation = true
 	}
 
-	credsSrc := credentialsSource(config)
-	services := disco.NewDisco()
-	services.SetCredentialsSource(credsSrc)
 	for userHost, hostConfig := range config.Hosts {
 		host, err := svchost.ForComparison(userHost)
 		if err != nil {
@@ -57,8 +54,7 @@ func initCommands(config *Config) {
 		PluginOverrides:  &PluginOverrides,
 		Ui:               Ui,
 
-		Services:    services,
-		Credentials: credsSrc,
+		Services: services,
 
 		RunningInAutomation: inAutomation,
 		PluginCacheDir:      config.PluginCacheDir,
diff --git a/config/module/module_test.go b/config/module/module_test.go
index 62e7ed2a75..80e931e0b5 100644
--- a/config/module/module_test.go
+++ b/config/module/module_test.go
@@ -44,5 +44,5 @@ func testConfig(t *testing.T, n string) *config.Config {
 
 func testStorage(t *testing.T, d *disco.Disco) *Storage {
 	t.Helper()
-	return NewStorage(tempDir(t), d, nil)
+	return NewStorage(tempDir(t), d)
 }
diff --git a/config/module/storage.go b/config/module/storage.go
index fa5e1c621c..4b828dcb08 100644
--- a/config/module/storage.go
+++ b/config/module/storage.go
@@ -11,7 +11,6 @@ import (
 	getter "github.com/hashicorp/go-getter"
 	"github.com/hashicorp/terraform/registry"
 	"github.com/hashicorp/terraform/registry/regsrc"
-	"github.com/hashicorp/terraform/svchost/auth"
 	"github.com/hashicorp/terraform/svchost/disco"
 	"github.com/mitchellh/cli"
 )
@@ -64,14 +63,10 @@ type Storage struct {
 	// StorageDir is the full path to the directory where all modules will be
 	// stored.
 	StorageDir string
-	// Services is a required *disco.Disco, which may have services and
-	// credentials pre-loaded.
-	Services *disco.Disco
-	// Creds optionally provides credentials for communicating with service
-	// providers.
-	Creds auth.CredentialsSource
+
 	// Ui is an optional cli.Ui for user output
 	Ui cli.Ui
+
 	// Mode is the GetMode that will be used for various operations.
 	Mode GetMode
 
@@ -79,8 +74,8 @@ type Storage struct {
 }
 
 // NewStorage returns a new initialized Storage object.
-func NewStorage(dir string, services *disco.Disco, creds auth.CredentialsSource) *Storage {
-	regClient := registry.NewClient(services, creds, nil)
+func NewStorage(dir string, services *disco.Disco) *Storage {
+	regClient := registry.NewClient(services, nil)
 
 	return &Storage{
 		StorageDir: dir,
diff --git a/config/module/storage_test.go b/config/module/storage_test.go
index 10811190e3..cb41f6d65b 100644
--- a/config/module/storage_test.go
+++ b/config/module/storage_test.go
@@ -22,7 +22,7 @@ func TestGetModule(t *testing.T) {
 		t.Fatal(err)
 	}
 	defer os.RemoveAll(td)
-	storage := NewStorage(td, disco, nil)
+	storage := NewStorage(td, disco)
 
 	// this module exists in a test fixture, and is known by the test.Registry
 	// relative to our cwd.
@@ -139,7 +139,7 @@ func TestAccRegistryDiscover(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	s := NewStorage("/tmp", nil, nil)
+	s := NewStorage("/tmp", nil)
 	loc, err := s.registry.Location(module, "")
 	if err != nil {
 		t.Fatal(err)
diff --git a/configs/configload/loader.go b/configs/configload/loader.go
index 06ff27400c..45e60f77ce 100644
--- a/configs/configload/loader.go
+++ b/configs/configload/loader.go
@@ -5,7 +5,6 @@ import (
 
 	"github.com/hashicorp/terraform/configs"
 	"github.com/hashicorp/terraform/registry"
-	"github.com/hashicorp/terraform/svchost/auth"
 	"github.com/hashicorp/terraform/svchost/disco"
 	"github.com/spf13/afero"
 )
@@ -39,10 +38,6 @@ type Config struct {
 	// not supported, which should be true only in specialized circumstances
 	// such as in tests.
 	Services *disco.Disco
-
-	// Creds is a credentials store for communicating with remote module
-	// registry endpoints. If this is nil then no credentials will be used.
-	Creds auth.CredentialsSource
 }
 
 // NewLoader creates and returns a loader that reads configuration from the
@@ -54,7 +49,7 @@ type Config struct {
 func NewLoader(config *Config) (*Loader, error) {
 	fs := afero.NewOsFs()
 	parser := configs.NewParser(fs)
-	reg := registry.NewClient(config.Services, config.Creds, nil)
+	reg := registry.NewClient(config.Services, nil)
 
 	ret := &Loader{
 		parser: parser,
@@ -63,7 +58,6 @@ func NewLoader(config *Config) (*Loader, error) {
 			CanInstall: true,
 			Dir:        config.ModulesDir,
 			Services:   config.Services,
-			Creds:      config.Creds,
 			Registry:   reg,
 		},
 	}
diff --git a/configs/configload/module_mgr.go b/configs/configload/module_mgr.go
index ef17fda7a7..6b2a5199fb 100644
--- a/configs/configload/module_mgr.go
+++ b/configs/configload/module_mgr.go
@@ -2,7 +2,6 @@ package configload
 
 import (
 	"github.com/hashicorp/terraform/registry"
-	"github.com/hashicorp/terraform/svchost/auth"
 	"github.com/hashicorp/terraform/svchost/disco"
 	"github.com/spf13/afero"
 )
@@ -25,9 +24,6 @@ type moduleMgr struct {
 	// cached discovery information.
 	Services *disco.Disco
 
-	// Creds provides optional credentials for communicating with service hosts.
-	Creds auth.CredentialsSource
-
 	// Registry is a client for the module registry protocol, which is used
 	// when a module is requested from a registry source.
 	Registry *registry.Client
diff --git a/main.go b/main.go
index 1818a91c44..523863e7b8 100644
--- a/main.go
+++ b/main.go
@@ -16,6 +16,7 @@ import (
 	"github.com/hashicorp/go-plugin"
 	"github.com/hashicorp/terraform/command/format"
 	"github.com/hashicorp/terraform/helper/logging"
+	"github.com/hashicorp/terraform/svchost/disco"
 	"github.com/hashicorp/terraform/terraform"
 	"github.com/mattn/go-colorable"
 	"github.com/mattn/go-shellwords"
@@ -144,7 +145,9 @@ func wrappedMain() int {
 
 	// In tests, Commands may already be set to provide mock commands
 	if Commands == nil {
-		initCommands(config)
+		credsSrc := credentialsSource(config)
+		services := disco.NewWithCredentialsSource(credsSrc)
+		initCommands(config, services)
 	}
 
 	// Run checkpoint
diff --git a/registry/client.go b/registry/client.go
index fba59ec873..8e31a6a3e2 100644
--- a/registry/client.go
+++ b/registry/client.go
@@ -15,7 +15,6 @@ import (
 	"github.com/hashicorp/terraform/registry/regsrc"
 	"github.com/hashicorp/terraform/registry/response"
 	"github.com/hashicorp/terraform/svchost"
-	"github.com/hashicorp/terraform/svchost/auth"
 	"github.com/hashicorp/terraform/svchost/disco"
 	"github.com/hashicorp/terraform/version"
 )
@@ -37,20 +36,14 @@ type Client struct {
 	// services is a required *disco.Disco, which may have services and
 	// credentials pre-loaded.
 	services *disco.Disco
-
-	// Creds optionally provides credentials for communicating with service
-	// providers.
-	creds auth.CredentialsSource
 }
 
 // NewClient returns a new initialized registry client.
-func NewClient(services *disco.Disco, creds auth.CredentialsSource, client *http.Client) *Client {
+func NewClient(services *disco.Disco, client *http.Client) *Client {
 	if services == nil {
-		services = disco.NewDisco()
+		services = disco.New()
 	}
 
-	services.SetCredentialsSource(creds)
-
 	if client == nil {
 		client = httpclient.New()
 		client.Timeout = requestTimeout
@@ -61,7 +54,6 @@ func NewClient(services *disco.Disco, creds auth.CredentialsSource, client *http
 	return &Client{
 		client:   client,
 		services: services,
-		creds:    creds,
 	}
 }
 
@@ -138,11 +130,7 @@ func (c *Client) Versions(module *regsrc.Module) (*response.ModuleVersions, erro
 }
 
 func (c *Client) addRequestCreds(host svchost.Hostname, req *http.Request) {
-	if c.creds == nil {
-		return
-	}
-
-	creds, err := c.creds.ForHost(host)
+	creds, err := c.services.CredentialsForHost(host)
 	if err != nil {
 		log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err)
 		return
diff --git a/registry/client_test.go b/registry/client_test.go
index 279c5a4831..5ee712f7f2 100644
--- a/registry/client_test.go
+++ b/registry/client_test.go
@@ -15,7 +15,7 @@ func TestLookupModuleVersions(t *testing.T) {
 	server := test.Registry()
 	defer server.Close()
 
-	client := NewClient(test.Disco(server), nil, nil)
+	client := NewClient(test.Disco(server), nil)
 
 	// test with and without a hostname
 	for _, src := range []string{
@@ -59,7 +59,7 @@ func TestInvalidRegistry(t *testing.T) {
 	server := test.Registry()
 	defer server.Close()
 
-	client := NewClient(test.Disco(server), nil, nil)
+	client := NewClient(test.Disco(server), nil)
 
 	src := "non-existent.localhost.localdomain/test-versions/name/provider"
 	modsrc, err := regsrc.ParseModuleSource(src)
@@ -76,7 +76,7 @@ func TestRegistryAuth(t *testing.T) {
 	server := test.Registry()
 	defer server.Close()
 
-	client := NewClient(test.Disco(server), nil, nil)
+	client := NewClient(test.Disco(server), nil)
 
 	src := "private/name/provider"
 	mod, err := regsrc.ParseModuleSource(src)
@@ -84,25 +84,26 @@ func TestRegistryAuth(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	// both should fail without auth
 	_, err = client.Versions(mod)
-	if err == nil {
-		t.Fatal("expected error")
+	if err != nil {
+		t.Fatal(err)
 	}
 	_, err = client.Location(mod, "1.0.0")
-	if err == nil {
-		t.Fatal("expected error")
+	if err != nil {
+		t.Fatal(err)
 	}
 
-	client = NewClient(test.Disco(server), test.Credentials, nil)
+	// Also test without a credentials source
+	client.services.SetCredentialsSource(nil)
 
+	// both should fail without auth
 	_, err = client.Versions(mod)
-	if err != nil {
-		t.Fatal(err)
+	if err == nil {
+		t.Fatal("expected error")
 	}
 	_, err = client.Location(mod, "1.0.0")
-	if err != nil {
-		t.Fatal(err)
+	if err == nil {
+		t.Fatal("expected error")
 	}
 }
 
@@ -110,7 +111,7 @@ func TestLookupModuleLocationRelative(t *testing.T) {
 	server := test.Registry()
 	defer server.Close()
 
-	client := NewClient(test.Disco(server), nil, nil)
+	client := NewClient(test.Disco(server), nil)
 
 	src := "relative/foo/bar"
 	mod, err := regsrc.ParseModuleSource(src)
@@ -133,7 +134,7 @@ func TestAccLookupModuleVersions(t *testing.T) {
 	if os.Getenv("TF_ACC") == "" {
 		t.Skip()
 	}
-	regDisco := disco.NewDisco()
+	regDisco := disco.New()
 
 	// test with and without a hostname
 	for _, src := range []string{
@@ -145,7 +146,7 @@ func TestAccLookupModuleVersions(t *testing.T) {
 			t.Fatal(err)
 		}
 
-		s := NewClient(regDisco, nil, nil)
+		s := NewClient(regDisco, nil)
 		resp, err := s.Versions(modsrc)
 		if err != nil {
 			t.Fatal(err)
@@ -179,7 +180,7 @@ func TestLookupLookupModuleError(t *testing.T) {
 	server := test.Registry()
 	defer server.Close()
 
-	client := NewClient(test.Disco(server), nil, nil)
+	client := NewClient(test.Disco(server), nil)
 
 	// this should not be found in teh registry
 	src := "bad/local/path"
diff --git a/registry/test/mock_registry.go b/registry/test/mock_registry.go
index c1fabbc25b..bd3d80b7f0 100644
--- a/registry/test/mock_registry.go
+++ b/registry/test/mock_registry.go
@@ -27,7 +27,7 @@ func Disco(s *httptest.Server) *disco.Disco {
 		// TODO: add specific tests to enumerate both possibilities.
 		"modules.v1": fmt.Sprintf("%s/v1/modules", s.URL),
 	}
-	d := disco.NewDisco()
+	d := disco.NewWithCredentialsSource(credsSrc)
 
 	d.ForceHostServices(svchost.Hostname("registry.terraform.io"), services)
 	d.ForceHostServices(svchost.Hostname("localhost"), services)
@@ -48,8 +48,8 @@ const (
 )
 
 var (
-	regHost     = svchost.Hostname(regsrc.PublicRegistryHost.Normalized())
-	Credentials = auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
+	regHost  = svchost.Hostname(regsrc.PublicRegistryHost.Normalized())
+	credsSrc = auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
 		regHost: {"token": testCred},
 	})
 )
diff --git a/svchost/auth/credentials.go b/svchost/auth/credentials.go
index 0bc6db4f10..0372c16096 100644
--- a/svchost/auth/credentials.go
+++ b/svchost/auth/credentials.go
@@ -42,6 +42,9 @@ type HostCredentials interface {
 	// receiving credentials. The usual behavior of this method is to
 	// add some sort of Authorization header to the request.
 	PrepareRequest(req *http.Request)
+
+	// Token returns the authentication token.
+	Token() string
 }
 
 // ForHost iterates over the contained CredentialsSource objects and
diff --git a/svchost/auth/token_credentials.go b/svchost/auth/token_credentials.go
index 8f771b0d9b..9358bcb644 100644
--- a/svchost/auth/token_credentials.go
+++ b/svchost/auth/token_credentials.go
@@ -18,3 +18,8 @@ func (tc HostCredentialsToken) PrepareRequest(req *http.Request) {
 	}
 	req.Header.Set("Authorization", "Bearer "+string(tc))
 }
+
+// Token returns the authentication token.
+func (tc HostCredentialsToken) Token() string {
+	return string(tc)
+}
diff --git a/svchost/disco/disco.go b/svchost/disco/disco.go
index 76a1b3b0d5..7fc49da9cb 100644
--- a/svchost/disco/disco.go
+++ b/svchost/disco/disco.go
@@ -42,9 +42,15 @@ type Disco struct {
 	Transport http.RoundTripper
 }
 
-// NewDisco returns a new initialized Disco object.
-func NewDisco() *Disco {
-	return &Disco{}
+// New returns a new initialized discovery object.
+func New() *Disco {
+	return NewWithCredentialsSource(nil)
+}
+
+// NewWithCredentialsSource returns a new discovery object initialized with
+// the given credentials source.
+func NewWithCredentialsSource(credsSrc auth.CredentialsSource) *Disco {
+	return &Disco{credsSrc: credsSrc}
 }
 
 // SetCredentialsSource provides a credentials source that will be used to
@@ -56,6 +62,15 @@ func (d *Disco) SetCredentialsSource(src auth.CredentialsSource) {
 	d.credsSrc = src
 }
 
+// CredentialsForHost returns a non-nil HostCredentials if the embedded source has
+// credentials available for the host, and a nil HostCredentials if it does not.
+func (d *Disco) CredentialsForHost(host svchost.Hostname) (auth.HostCredentials, error) {
+	if d.credsSrc == nil {
+		return nil, nil
+	}
+	return d.credsSrc.ForHost(host)
+}
+
 // ForceHostServices provides a pre-defined set of services for a given
 // host, which prevents the receiver from attempting network-based discovery
 // for the given host. Instead, the given services map will be returned
@@ -145,15 +160,10 @@ func (d *Disco) discover(host svchost.Hostname) Host {
 		URL:    discoURL,
 	}
 
-	if d.credsSrc != nil {
-		creds, err := d.credsSrc.ForHost(host)
-		if err == nil {
-			if creds != nil {
-				creds.PrepareRequest(req) // alters req to include credentials
-			}
-		} else {
-			log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err)
-		}
+	if creds, err := d.CredentialsForHost(host); err != nil {
+		log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err)
+	} else if creds != nil {
+		creds.PrepareRequest(req) // alters req to include credentials
 	}
 
 	log.Printf("[DEBUG] Service discovery for %s at %s", host, discoURL)
diff --git a/svchost/disco/disco_test.go b/svchost/disco/disco_test.go
index 94d2a220f5..c8bc16c455 100644
--- a/svchost/disco/disco_test.go
+++ b/svchost/disco/disco_test.go
@@ -45,7 +45,7 @@ func TestDiscover(t *testing.T) {
 			t.Fatalf("test server hostname is invalid: %s", err)
 		}
 
-		d := NewDisco()
+		d := New()
 		discovered := d.Discover(host)
 		gotURL := discovered.ServiceURL("thingy.v1")
 		if gotURL == nil {
@@ -80,7 +80,7 @@ func TestDiscover(t *testing.T) {
 			t.Fatalf("test server hostname is invalid: %s", err)
 		}
 
-		d := NewDisco()
+		d := New()
 		discovered := d.Discover(host)
 		gotURL := discovered.ServiceURL("wotsit.v2")
 		if gotURL == nil {
@@ -107,7 +107,7 @@ func TestDiscover(t *testing.T) {
 			t.Fatalf("test server hostname is invalid: %s", err)
 		}
 
-		d := NewDisco()
+		d := New()
 		d.SetCredentialsSource(auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
 			host: map[string]interface{}{
 				"token": "abc123",
@@ -124,7 +124,7 @@ func TestDiscover(t *testing.T) {
 			"wotsit.v2": "/foo",
 		}
 
-		d := NewDisco()
+		d := New()
 		d.ForceHostServices(svchost.Hostname("example.com"), forced)
 
 		givenHost := "example.com"
@@ -167,7 +167,7 @@ func TestDiscover(t *testing.T) {
 			t.Fatalf("test server hostname is invalid: %s", err)
 		}
 
-		d := NewDisco()
+		d := New()
 		discovered := d.Discover(host)
 
 		// result should be empty, which we can verify only by reaching into
@@ -190,7 +190,7 @@ func TestDiscover(t *testing.T) {
 			t.Fatalf("test server hostname is invalid: %s", err)
 		}
 
-		d := NewDisco()
+		d := New()
 		discovered := d.Discover(host)
 
 		// result should be empty, which we can verify only by reaching into
@@ -217,7 +217,7 @@ func TestDiscover(t *testing.T) {
 			t.Fatalf("test server hostname is invalid: %s", err)
 		}
 
-		d := NewDisco()
+		d := New()
 		discovered := d.Discover(host)
 
 		if discovered.services == nil {
@@ -236,7 +236,7 @@ func TestDiscover(t *testing.T) {
 			t.Fatalf("test server hostname is invalid: %s", err)
 		}
 
-		d := NewDisco()
+		d := New()
 		discovered := d.Discover(host)
 
 		// result should be empty, which we can verify only by reaching into
@@ -267,7 +267,7 @@ func TestDiscover(t *testing.T) {
 			t.Fatalf("test server hostname is invalid: %s", err)
 		}
 
-		d := NewDisco()
+		d := New()
 		discovered := d.Discover(host)
 
 		gotURL := discovered.ServiceURL("thingy.v1")
-- 
GitLab