Skip to content

Commit

Permalink
feat(setter): print '(cached)' for cached results (#776)
Browse files Browse the repository at this point in the history
  • Loading branch information
favonia committed Jun 27, 2024
1 parent 3966de4 commit 1bcbbf0
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 53 deletions.
2 changes: 1 addition & 1 deletion internal/api/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
// A Handle represents a generic API to update DNS records. Currently, the only implementation is Cloudflare.
type Handle interface {
// ListRecords lists all matching DNS records.
ListRecords(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type) (map[string]netip.Addr, bool)
ListRecords(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type) (map[string]netip.Addr, bool, bool) //nolint:lll

// DeleteRecord deletes one DNS record.
DeleteRecord(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type, id string) bool
Expand Down
15 changes: 8 additions & 7 deletions internal/api/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,18 @@ zoneSearch:
return "", false
}

// ListRecords lists all matching DNS records.
// ListRecords lists all matching DNS records. The second return value indicates whether
// the lists are from cached responses.
func (h *CloudflareHandle) ListRecords(ctx context.Context, ppfmt pp.PP,
domain domain.Domain, ipNet ipnet.Type,
) (map[string]netip.Addr, bool) {
) (map[string]netip.Addr, bool, bool) {
if rmap := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rmap != nil {
return rmap.Value(), true
return rmap.Value(), true, true
}

zone, ok := h.ZoneOfDomain(ctx, ppfmt, domain)
if !ok {
return nil, false
return nil, false, false
}

//nolint:exhaustruct // Other fields are intentionally unspecified
Expand All @@ -196,21 +197,21 @@ func (h *CloudflareHandle) ListRecords(ctx context.Context, ppfmt pp.PP,
})
if err != nil {
ppfmt.Warningf(pp.EmojiError, "Failed to retrieve records of %q: %v", domain.Describe(), err)
return nil, false
return nil, false, false
}

rmap := map[string]netip.Addr{}
for i := range rs {
rmap[rs[i].ID], err = netip.ParseAddr(rs[i].Content)
if err != nil {
ppfmt.Warningf(pp.EmojiImpossible, "Failed to parse the IP address in records of %q: %v", domain.Describe(), err)
return nil, false
return nil, false, false
}
}

h.cache.listRecords[ipNet].Set(domain.DNSNameASCII(), rmap, ttlcache.DefaultTTL)

return rmap, true
return rmap, false, true
}

// DeleteRecord deletes one DNS record.
Expand Down
49 changes: 31 additions & 18 deletions internal/api/cloudflare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func newHandle(t *testing.T, emptyAccountID bool) (*http.ServeMux, api.Handle) {
func TestNewValid(t *testing.T) {
t.Parallel()

_, _ = newHandle(t, false)
newHandle(t, false)
}

func TestNewEmpty(t *testing.T) {
Expand Down Expand Up @@ -636,15 +636,17 @@ func TestListRecords(t *testing.T) {
expected := map[string]netip.Addr{"record1": mustIP("::1"), "record2": mustIP("::2")}
ipNet, ips, accessCount = ipnet.IP6, expected, 1
mockPP := mocks.NewMockPP(mockCtrl)
ips, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
ips, cached, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
require.True(t, ok)
require.False(t, cached)
require.Equal(t, expected, ips)
require.Equal(t, 0, accessCount)

// testing the caching
mockPP = mocks.NewMockPP(mockCtrl)
ips, ok = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
ips, cached, ok = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
require.True(t, ok)
require.True(t, cached)
require.Equal(t, expected, ips)
}

Expand Down Expand Up @@ -698,8 +700,9 @@ func TestListRecordsInvalidIPAddress(t *testing.T) {
"sub.test.org",
gomock.Any(),
)
ips, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
ips, cached, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
require.False(t, ok)
require.False(t, cached)
require.Nil(t, ips)
require.Equal(t, 0, accessCount)

Expand All @@ -711,8 +714,9 @@ func TestListRecordsInvalidIPAddress(t *testing.T) {
"sub.test.org",
gomock.Any(),
)
ips, ok = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
ips, cached, ok = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
require.False(t, ok)
require.False(t, cached)
require.Nil(t, ips)
require.Equal(t, 0, accessCount)
}
Expand Down Expand Up @@ -762,15 +766,17 @@ func TestListRecordsWildcard(t *testing.T) {
expected := map[string]netip.Addr{"record1": mustIP("::1"), "record2": mustIP("::2")}
ipNet, ips, accessCount = ipnet.IP6, expected, 1
mockPP := mocks.NewMockPP(mockCtrl)
ips, ok := h.ListRecords(context.Background(), mockPP, domain.Wildcard("test.org"), ipnet.IP6)
ips, cached, ok := h.ListRecords(context.Background(), mockPP, domain.Wildcard("test.org"), ipnet.IP6)
require.True(t, ok)
require.False(t, cached)
require.Equal(t, expected, ips)
require.Equal(t, 0, accessCount)

// testing the caching
mockPP = mocks.NewMockPP(mockCtrl)
ips, ok = h.ListRecords(context.Background(), mockPP, domain.Wildcard("test.org"), ipnet.IP6)
ips, cached, ok = h.ListRecords(context.Background(), mockPP, domain.Wildcard("test.org"), ipnet.IP6)
require.True(t, ok)
require.True(t, cached)
require.Equal(t, expected, ips)
}

Expand All @@ -785,14 +791,16 @@ func TestListRecordsInvalidDomain(t *testing.T) {

mockPP := mocks.NewMockPP(mockCtrl)
mockPP.EXPECT().Warningf(pp.EmojiError, "Failed to retrieve records of %q: %v", "sub.test.org", gomock.Any())
ips, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP4)
ips, cached, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP4)
require.False(t, ok)
require.False(t, cached)
require.Nil(t, ips)

mockPP = mocks.NewMockPP(mockCtrl)
mockPP.EXPECT().Warningf(pp.EmojiError, "Failed to retrieve records of %q: %v", "sub.test.org", gomock.Any())
ips, ok = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
ips, cached, ok = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
require.False(t, ok)
require.False(t, cached)
require.Nil(t, ips)
}

Expand All @@ -809,8 +817,9 @@ func TestListRecordsInvalidZone(t *testing.T) {
"sub.test.org",
gomock.Any(),
)
ips, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP4)
ips, cached, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP4)
require.False(t, ok)
require.False(t, cached)
require.Nil(t, ips)

mockPP = mocks.NewMockPP(mockCtrl)
Expand All @@ -820,8 +829,9 @@ func TestListRecordsInvalidZone(t *testing.T) {
"sub.test.org",
gomock.Any(),
)
ips, ok = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
ips, cached, ok = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
require.False(t, ok)
require.False(t, cached)
require.Nil(t, ips)
}

Expand Down Expand Up @@ -907,10 +917,11 @@ func TestDeleteRecordValid(t *testing.T) {

listAccessCount, deleteAccessCount = 1, 1
mockPP = mocks.NewMockPP(mockCtrl)
_, _ = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
_ = h.DeleteRecord(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6, "record1")
rs, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
rs, cached, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
require.True(t, ok)
require.True(t, cached)
require.Empty(t, rs)
}

Expand Down Expand Up @@ -1022,10 +1033,11 @@ func TestUpdateRecordValid(t *testing.T) {

listAccessCount, updateAccessCount = 1, 1
mockPP = mocks.NewMockPP(mockCtrl)
_, _ = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
_ = h.UpdateRecord(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6, "record1", mustIP("::2"))
rs, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
rs, cached, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
require.True(t, ok)
require.True(t, cached)
require.Equal(t, map[string]netip.Addr{"record1": mustIP("::2")}, rs)
}

Expand Down Expand Up @@ -1138,10 +1150,11 @@ func TestCreateRecordValid(t *testing.T) {

listAccessCount, createAccessCount = 1, 1
mockPP = mocks.NewMockPP(mockCtrl)
_, _ = h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
_, _ = h.CreateRecord(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6, mustIP("::1"), 100, false) //nolint:lll
rs, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
h.CreateRecord(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6, mustIP("::1"), 100, false) //nolint:lll
rs, cached, ok := h.ListRecords(context.Background(), mockPP, domain.FQDN("sub.test.org"), ipnet.IP6)
require.True(t, ok)
require.True(t, cached)
require.Equal(t, map[string]netip.Addr{"record1": mustIP("::1")}, rs)
}

Expand Down
13 changes: 7 additions & 6 deletions internal/mocks/mock_api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions internal/setter/setter.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (s *setter) Set(ctx context.Context, ppfmt pp.PP,
recordType := ipnet.RecordType()
domainDescription := domain.Describe()

rs, ok := s.Handle.ListRecords(ctx, ppfmt, domain, ipnet)
rs, cached, ok := s.Handle.ListRecords(ctx, ppfmt, domain, ipnet)
if !ok {
ppfmt.Errorf(pp.EmojiError, "Failed to retrieve the current %s records of %q", recordType, domainDescription)
return ResponseFailed
Expand All @@ -76,7 +76,11 @@ func (s *setter) Set(ctx context.Context, ppfmt pp.PP,

// If it's up to date and there are no other records, we are done!
if foundMatched && len(unprocessedMatched) == 0 && len(unprocessedUnmatched) == 0 {
ppfmt.Infof(pp.EmojiAlreadyDone, "The %s records of %q are already up to date", recordType, domainDescription)
if cached {
ppfmt.Infof(pp.EmojiAlreadyDone, "The %s records of %q are already up to date (cached)", recordType, domainDescription) //nolint:lll
} else {
ppfmt.Infof(pp.EmojiAlreadyDone, "The %s records of %q are already up to date", recordType, domainDescription)
}
return ResponseNoop
}

Expand Down Expand Up @@ -177,7 +181,7 @@ func (s *setter) Delete(ctx context.Context, ppfmt pp.PP, domain domain.Domain,
recordType := ipnet.RecordType()
domainDescription := domain.Describe()

rmap, ok := s.Handle.ListRecords(ctx, ppfmt, domain, ipnet)
rmap, cached, ok := s.Handle.ListRecords(ctx, ppfmt, domain, ipnet)
if !ok {
ppfmt.Errorf(pp.EmojiError, "Failed to retrieve the current %s records of %q", recordType, domainDescription)
return ResponseFailed
Expand All @@ -191,7 +195,11 @@ func (s *setter) Delete(ctx context.Context, ppfmt pp.PP, domain domain.Domain,
sort.Strings(unmatchedIDs)

if len(unmatchedIDs) == 0 {
ppfmt.Infof(pp.EmojiAlreadyDone, "The %s records of %q were already deleted", recordType, domainDescription)
if cached {
ppfmt.Infof(pp.EmojiAlreadyDone, "The %s records of %q were already deleted (cached)", recordType, domainDescription)
} else {
ppfmt.Infof(pp.EmojiAlreadyDone, "The %s records of %q were already deleted", recordType, domainDescription)
}
return ResponseNoop
}

Expand Down
Loading

0 comments on commit 1bcbbf0

Please sign in to comment.