diff --git a/internal/releasesjson/downloader.go b/internal/releasesjson/downloader.go index b50fea0..2f937f1 100644 --- a/internal/releasesjson/downloader.go +++ b/internal/releasesjson/downloader.go @@ -64,23 +64,9 @@ func (d *Downloader) DownloadAndUnpack(ctx context.Context, pv *ProductVersion, client := httpclient.NewHTTPClient() - archiveURL := pb.URL - if d.BaseURL != "" { - // If custom URL is set, use that instead of the one from the JSON. - // Also ensures that absolute download links from mocked responses - // are still pointing to the mock server if one is set. - baseURL, err := url.Parse(d.BaseURL) - if err != nil { - return nil, err - } - - u, err := url.Parse(archiveURL) - if err != nil { - return nil, err - } - u.Scheme = baseURL.Scheme - u.Host = baseURL.Host - archiveURL = u.String() + archiveURL, err := determineArchiveURL(pb.URL, d.BaseURL) + if err != nil { + return nil, err } d.Logger.Printf("downloading archive from %s", archiveURL) @@ -237,3 +223,28 @@ func isLicenseFile(filename string) bool { } return false } + +// determineArchiveURL determines the archive URL based on the base URL provided. +func determineArchiveURL(archiveURL, baseURL string) (string, error) { + // If custom URL is set, use that instead of the one from the JSON. + // Also ensures that absolute download links from mocked responses + // are still pointing to the mock server if one is set. + if baseURL == "" { + return archiveURL, nil + } + + base, err := url.Parse(baseURL) + if err != nil { + return "", err + } + + u, err := url.Parse(archiveURL) + if err != nil { + return "", err + } + + // Use base URL path and append the path from the archive URL. + newArchiveURL := base.JoinPath(u.Path) + + return newArchiveURL.String(), nil +} diff --git a/internal/releasesjson/downloader_test.go b/internal/releasesjson/downloader_test.go new file mode 100644 index 0000000..043caf4 --- /dev/null +++ b/internal/releasesjson/downloader_test.go @@ -0,0 +1,47 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package releasesjson + +import "testing" + +func TestDetermineArchiveURL(t *testing.T) { + tests := []struct { + name string + archiveURL string + baseURL string + want string + }{ + { + name: "with custom base URL + path", + archiveURL: "https://releases.hashicorp.com/terraform/1.8.2/terraform_1.8.2_darwin_amd64.zip", + baseURL: "https://myartifactory.company.com/artifactory/hashicorp-remote", + want: "https://myartifactory.company.com/artifactory/hashicorp-remote/terraform/1.8.2/terraform_1.8.2_darwin_amd64.zip", + }, + { + name: "with custom base URL + port + path", + archiveURL: "https://releases.hashicorp.com/terraform/1.8.2/terraform_1.8.2_darwin_amd64.zip", + baseURL: "https://myartifactory.company.com:443/artifactory/hashicorp-remote", + want: "https://myartifactory.company.com:443/artifactory/hashicorp-remote/terraform/1.8.2/terraform_1.8.2_darwin_amd64.zip", + }, + { + name: "without custom base URL", + archiveURL: "https://releases.hashicorp.com/terraform/1.8.2/terraform_1.8.2_darwin_amd64.zip", + baseURL: "", + want: "https://releases.hashicorp.com/terraform/1.8.2/terraform_1.8.2_darwin_amd64.zip", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := determineArchiveURL(tt.archiveURL, tt.baseURL) + if err != nil { + t.Errorf("determineArchiveURL() error = %v", err) + return + } + if got != tt.want { + t.Errorf("determineArchiveURL() = %v, want %v", got, tt.want) + } + }) + } +}