diff --git a/.golangci.yml b/.golangci.yml index 3d9c8b0fe..4868c3862 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -36,7 +36,7 @@ linters-settings: dupl: threshold: 150 funlen: - Lines: 120 + Lines: 130 Statements: 60 goconst: min-len: 2 diff --git a/go.mod b/go.mod index 23763a260..0a138f55a 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/edwarnicke/serialize v1.0.7 github.com/fsnotify/fsnotify v1.5.4 github.com/ghodss/yaml v1.0.0 + github.com/go-pkgz/expirable-cache/v3 v3.0.0 github.com/golang-jwt/jwt/v4 v4.5.1 github.com/golang/protobuf v1.5.3 github.com/google/go-cmp v0.6.0 diff --git a/go.sum b/go.sum index 627615c88..389cc1996 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,8 @@ github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-pkgz/expirable-cache/v3 v3.0.0 h1:u3/gcu3sabLYiTCevoRKv+WzjIn5oo7P8XtiXBeRDLw= +github.com/go-pkgz/expirable-cache/v3 v3.0.0/go.mod h1:2OQiDyEGQalYecLWmXprm3maPXeVb5/6/X7yRPYTzec= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= @@ -143,6 +145,7 @@ github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/C github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/raft v1.3.9 h1:9yuo1aR0bFTr1cw7pj3S2Bk6MhJCsnr2NAxvIBrP2x4= github.com/hashicorp/raft v1.3.9/go.mod h1:4Ak7FSPnuvmb0GV6vgIAJ4vYT4bek9bb6Q+7HVbyzqM= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= diff --git a/pkg/networkservice/chains/nsmgr/server.go b/pkg/networkservice/chains/nsmgr/server.go index c1196ec56..4ee6efc8d 100644 --- a/pkg/networkservice/chains/nsmgr/server.go +++ b/pkg/networkservice/chains/nsmgr/server.go @@ -55,6 +55,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/common/localbypass" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" + "github.com/networkservicemesh/sdk/pkg/registry/common/querycache" registryrecvfd "github.com/networkservicemesh/sdk/pkg/registry/common/recvfd" registrysendfd "github.com/networkservicemesh/sdk/pkg/registry/common/sendfd" "github.com/networkservicemesh/sdk/pkg/registry/common/updatepath" @@ -239,6 +240,7 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, options chain.NewNetworkServiceRegistryClient( clienturl.NewNetworkServiceRegistryClient(opts.regURL), begin.NewNetworkServiceRegistryClient(), + querycache.NewNetworkServiceRegistryClient(ctx), clientconn.NewNetworkServiceRegistryClient(), opts.authorizeNSRegistryClient, grpcmetadata.NewNetworkServiceRegistryClient(), @@ -263,6 +265,7 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, options registryconnect.NewNetworkServiceEndpointRegistryServer( chain.NewNetworkServiceEndpointRegistryClient( begin.NewNetworkServiceEndpointRegistryClient(), + querycache.NewNetworkServiceEndpointRegistryClient(ctx), clienturl.NewNetworkServiceEndpointRegistryClient(opts.regURL), clientconn.NewNetworkServiceEndpointRegistryClient(), opts.authorizeNSERegistryClient, diff --git a/pkg/networkservice/chains/nsmgr/single_test.go b/pkg/networkservice/chains/nsmgr/single_test.go index 53afb1458..cb60bff83 100644 --- a/pkg/networkservice/chains/nsmgr/single_test.go +++ b/pkg/networkservice/chains/nsmgr/single_test.go @@ -373,6 +373,14 @@ func Test_UsecasePoint2MultiPoint(t *testing.T) { _, err = nsc.Close(ctx, conn) require.NoError(t, err) + require.Eventually(t, func() bool { + conn, err = nsc.Request(ctx, request.Clone()) + if err != nil { + return false + } + return len(conn.Path.PathSegments) == 4 && conn.GetPath().GetPathSegments()[2].Name == "p2p forwarder" + }, time.Second, time.Second/10) + conn, err = nsc.Request(ctx, request.Clone()) require.NoError(t, err) require.NotNil(t, conn) @@ -484,6 +492,14 @@ func Test_RemoteUsecase_Point2MultiPoint(t *testing.T) { _, err = nsc.Close(ctx, conn) require.NoError(t, err) + require.Eventually(t, func() bool { + conn, err = nsc.Request(ctx, request.Clone()) + if err != nil { + return false + } + return len(conn.Path.PathSegments) == 6 && conn.GetPath().GetPathSegments()[2].Name == "p2p forwarder-0" && conn.GetPath().GetPathSegments()[4].Name == "p2p forwarder-1" + }, time.Second, time.Second/10) + conn, err = nsc.Request(ctx, request.Clone()) require.NoError(t, err) require.NotNil(t, conn) diff --git a/pkg/networkservice/common/discoverforwarder/server.go b/pkg/networkservice/common/discoverforwarder/server.go index 285a55f5d..63077edb2 100644 --- a/pkg/networkservice/common/discoverforwarder/server.go +++ b/pkg/networkservice/common/discoverforwarder/server.go @@ -197,9 +197,10 @@ func (d *discoverForwarderServer) matchForwarders(nsLabels map[string]string, ns var matchLabels = match.GetMetadata().GetLabels() if matchLabels == nil { - matchLabels = map[string]string{ - "p2p": "true", - } + matchLabels = make(map[string]string) + } + if len(matchLabels) == 0 { + matchLabels["p2p"] = "true" } for _, nse := range nses { var forwarderLabels = nse.GetNetworkServiceLabels()[d.forwarderServiceName] @@ -217,7 +218,6 @@ func (d *discoverForwarderServer) matchForwarders(nsLabels map[string]string, ns break } - return result } diff --git a/pkg/registry/chains/client/ns_client.go b/pkg/registry/chains/client/ns_client.go index 749151548..f5dff0d96 100644 --- a/pkg/registry/chains/client/ns_client.go +++ b/pkg/registry/chains/client/ns_client.go @@ -32,6 +32,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/common/heal" "github.com/networkservicemesh/sdk/pkg/registry/common/null" + "github.com/networkservicemesh/sdk/pkg/registry/common/querycache" "github.com/networkservicemesh/sdk/pkg/registry/common/retry" "github.com/networkservicemesh/sdk/pkg/registry/core/chain" "github.com/networkservicemesh/sdk/pkg/registry/utils/metadata" @@ -53,6 +54,7 @@ func NewNetworkServiceRegistryClient(ctx context.Context, opts ...Option) regist []registry.NetworkServiceRegistryClient{ begin.NewNetworkServiceRegistryClient(), metadata.NewNetworkServiceClient(), + querycache.NewNetworkServiceRegistryClient(ctx), retry.NewNetworkServiceRegistryClient(ctx), clientOpts.authorizeNSRegistryClient, heal.NewNetworkServiceRegistryClient(ctx), diff --git a/pkg/registry/chains/client/nse_client.go b/pkg/registry/chains/client/nse_client.go index 7ef1d7ae5..46ca61573 100644 --- a/pkg/registry/chains/client/nse_client.go +++ b/pkg/registry/chains/client/nse_client.go @@ -33,6 +33,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/common/heal" "github.com/networkservicemesh/sdk/pkg/registry/common/null" + "github.com/networkservicemesh/sdk/pkg/registry/common/querycache" "github.com/networkservicemesh/sdk/pkg/registry/common/refresh" "github.com/networkservicemesh/sdk/pkg/registry/common/retry" "github.com/networkservicemesh/sdk/pkg/registry/core/chain" @@ -55,6 +56,7 @@ func NewNetworkServiceEndpointRegistryClient(ctx context.Context, opts ...Option []registry.NetworkServiceEndpointRegistryClient{ begin.NewNetworkServiceEndpointRegistryClient(), metadata.NewNetworkServiceEndpointClient(), + querycache.NewNetworkServiceEndpointRegistryClient(ctx), retry.NewNetworkServiceEndpointRegistryClient(ctx), heal.NewNetworkServiceEndpointRegistryClient(ctx), refresh.NewNetworkServiceEndpointRegistryClient(ctx), diff --git a/pkg/registry/common/localbypass/server.go b/pkg/registry/common/localbypass/server.go index 4f7e19ba0..c6f4c22b1 100644 --- a/pkg/registry/common/localbypass/server.go +++ b/pkg/registry/common/localbypass/server.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2023-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -39,6 +39,7 @@ type localBypassNSEFindServer struct { } func (s *localBypassNSEFindServer) Send(nseResp *registry.NetworkServiceEndpointResponse) error { + nseResp = nseResp.Clone() if u, ok := s.nseURLs.Load(nseResp.NetworkServiceEndpoint.Name); ok { nseResp.NetworkServiceEndpoint.Url = u.String() } diff --git a/pkg/registry/common/querycache/cache.go b/pkg/registry/common/querycache/cache.go deleted file mode 100644 index f9558dc42..000000000 --- a/pkg/registry/common/querycache/cache.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// Copyright (c) 2023 Cisco and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package querycache - -import ( - "context" - "sync" - "time" - - "github.com/edwarnicke/genericsync" - "github.com/networkservicemesh/api/pkg/api/registry" - - "github.com/networkservicemesh/sdk/pkg/tools/clock" -) - -type cache struct { - expireTimeout time.Duration - entries genericsync.Map[string, *cacheEntry] - clockTime clock.Clock -} - -func newCache(ctx context.Context, opts ...Option) *cache { - c := &cache{ - expireTimeout: time.Minute, - clockTime: clock.FromContext(ctx), - } - - for _, opt := range opts { - opt(c) - } - - ticker := c.clockTime.Ticker(c.expireTimeout) - go func() { - for { - select { - case <-ctx.Done(): - ticker.Stop() - return - case <-ticker.C(): - c.entries.Range(func(_ string, e *cacheEntry) bool { - e.lock.Lock() - defer e.lock.Unlock() - - if c.clockTime.Until(e.expirationTime) < 0 { - e.cleanup() - } - - return true - }) - } - } - }() - - return c -} - -func (c *cache) LoadOrStore(key string, nse *registry.NetworkServiceEndpoint, cancel context.CancelFunc) (*cacheEntry, bool) { - var once sync.Once - return c.entries.LoadOrStore(key, &cacheEntry{ - nse: nse, - expirationTime: c.clockTime.Now().Add(c.expireTimeout), - cleanup: func() { - once.Do(func() { - c.entries.Delete(key) - cancel() - }) - }, - }) -} - -func (c *cache) Load(key string) (*registry.NetworkServiceEndpoint, bool) { - e, ok := c.entries.Load(key) - if !ok { - return nil, false - } - - e.lock.Lock() - defer e.lock.Unlock() - - if c.clockTime.Until(e.expirationTime) < 0 { - e.cleanup() - return nil, false - } - - e.expirationTime = c.clockTime.Now().Add(c.expireTimeout) - - return e.nse, true -} - -type cacheEntry struct { - nse *registry.NetworkServiceEndpoint - expirationTime time.Time - lock sync.Mutex - cleanup func() -} - -func (e *cacheEntry) Update(nse *registry.NetworkServiceEndpoint) { - e.lock.Lock() - defer e.lock.Unlock() - - e.nse = nse -} - -func (e *cacheEntry) Cleanup() { - e.lock.Lock() - defer e.lock.Unlock() - - e.cleanup() -} diff --git a/pkg/registry/common/querycache/ns_client.go b/pkg/registry/common/querycache/ns_client.go new file mode 100644 index 000000000..7c155487d --- /dev/null +++ b/pkg/registry/common/querycache/ns_client.go @@ -0,0 +1,88 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package querycache adds possible to cache Find queries +package querycache + +import ( + "context" + "time" + + cache "github.com/go-pkgz/expirable-cache/v3" + "github.com/golang/protobuf/ptypes/empty" + "google.golang.org/grpc" + + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/registry/core/streamchannel" +) + +type queryCacheNSClient struct { + chainContext context.Context + cache cache.Cache[string, []*registry.NetworkService] +} + +// NewNetworkServiceRegistryClient creates new querycache NS registry client that caches all resolved NSs +func NewNetworkServiceRegistryClient(ctx context.Context) registry.NetworkServiceRegistryClient { + var res = &queryCacheNSClient{ + chainContext: ctx, + cache: cache.NewCache[string, []*registry.NetworkService]().WithLRU().WithMaxKeys(32).WithTTL(time.Millisecond * 100), + } + return res +} + +func (q *queryCacheNSClient) Register(ctx context.Context, nse *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + resp, err := next.NetworkServiceRegistryClient(ctx).Register(ctx, nse, opts...) + if err == nil { + q.cache.Add(resp.GetName(), []*registry.NetworkService{resp}) + } + return resp, err +} + +func (q *queryCacheNSClient) Find(ctx context.Context, query *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + if query.Watch { + return next.NetworkServiceRegistryClient(ctx).Find(ctx, query, opts...) + } + + var list []*registry.NetworkService + if v, ok := q.cache.Get(query.GetNetworkService().GetName()); ok { + list = v + } else { + var streamClient, err = next.NetworkServiceRegistryClient(ctx).Find(ctx, query, opts...) + if err != nil { + return streamClient, err + } + list = registry.ReadNetworkServiceList(streamClient) + for _, item := range list { + q.cache.Add(item.GetName(), []*registry.NetworkService{item.Clone()}) + } + } + var resultStreamChannel = make(chan *registry.NetworkServiceResponse, len(list)) + for _, item := range list { + resultStreamChannel <- ®istry.NetworkServiceResponse{NetworkService: item} + } + close(resultStreamChannel) + return streamchannel.NewNetworkServiceFindClient(ctx, resultStreamChannel), nil +} + +func (q *queryCacheNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { + resp, err := next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) + if err == nil { + q.cache.Remove(in.GetName()) + } + return resp, err +} diff --git a/pkg/registry/common/querycache/nse_client.go b/pkg/registry/common/querycache/nse_client.go index efc37f204..14d671578 100644 --- a/pkg/registry/common/querycache/nse_client.go +++ b/pkg/registry/common/querycache/nse_client.go @@ -1,5 +1,7 @@ // Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2024 Cisco and/or its affiliates. +// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,7 +21,9 @@ package querycache import ( "context" + "time" + cache "github.com/go-pkgz/expirable-cache/v3" "github.com/golang/protobuf/ptypes/empty" "google.golang.org/grpc" @@ -30,20 +34,25 @@ import ( ) type queryCacheNSEClient struct { - ctx context.Context - cache *cache + chainContext context.Context + cache cache.Cache[string, []*registry.NetworkServiceEndpoint] } -// NewClient creates new querycache NSE registry client that caches all resolved NSEs -func NewClient(ctx context.Context, opts ...Option) registry.NetworkServiceEndpointRegistryClient { - return &queryCacheNSEClient{ - ctx: ctx, - cache: newCache(ctx, opts...), +// NewNetworkServiceEndpointRegistryClient creates new querycache NSE registry client that caches all resolved NSEs +func NewNetworkServiceEndpointRegistryClient(ctx context.Context) registry.NetworkServiceEndpointRegistryClient { + var res = &queryCacheNSEClient{ + chainContext: ctx, + cache: cache.NewCache[string, []*registry.NetworkServiceEndpoint]().WithLRU().WithMaxKeys(32).WithTTL(time.Millisecond * 300), } + return res } func (q *queryCacheNSEClient) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { - return next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, nse, opts...) + resp, err := next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, nse, opts...) + if err == nil { + q.cache.Add(resp.GetName(), []*registry.NetworkServiceEndpoint{resp}) + } + return resp, err } func (q *queryCacheNSEClient) Find(ctx context.Context, query *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { @@ -51,80 +60,31 @@ func (q *queryCacheNSEClient) Find(ctx context.Context, query *registry.NetworkS return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, query, opts...) } - if client, ok := q.findInCache(ctx, query.String()); ok { - return client, nil - } - - client, err := next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, query, opts...) - if err != nil { - return nil, err - } - - nses := registry.ReadNetworkServiceEndpointList(client) - - resultCh := make(chan *registry.NetworkServiceEndpointResponse, len(nses)) - for _, nse := range nses { - resultCh <- ®istry.NetworkServiceEndpointResponse{NetworkServiceEndpoint: nse} - q.storeInCache(ctx, nse.Clone(), opts...) - } - close(resultCh) - - return streamchannel.NewNetworkServiceEndpointFindClient(ctx, resultCh), nil -} - -func (q *queryCacheNSEClient) findInCache(ctx context.Context, key string) (registry.NetworkServiceEndpointRegistry_FindClient, bool) { - nse, ok := q.cache.Load(key) - if !ok { - return nil, false - } - - resultCh := make(chan *registry.NetworkServiceEndpointResponse, 1) - resultCh <- ®istry.NetworkServiceEndpointResponse{NetworkServiceEndpoint: nse.Clone()} - close(resultCh) - - return streamchannel.NewNetworkServiceEndpointFindClient(ctx, resultCh), true -} - -func (q *queryCacheNSEClient) storeInCache(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) { - nseQuery := ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ - Name: nse.Name, - }, - } - - key := nseQuery.String() - - findCtx, cancel := context.WithCancel(q.ctx) - - entry, loaded := q.cache.LoadOrStore(key, nse, cancel) - if loaded { - cancel() - return - } - - go func() { - defer entry.Cleanup() - - nseQuery.Watch = true - - stream, err := next.NetworkServiceEndpointRegistryClient(ctx).Find(findCtx, nseQuery, opts...) + var list []*registry.NetworkServiceEndpoint + if v, ok := q.cache.Get(query.GetNetworkServiceEndpoint().GetName()); ok { + list = v + } else { + var streamClient, err = next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, query, opts...) if err != nil { - return + return streamClient, err } - - for nseResp, err := stream.Recv(); err == nil; nseResp, err = stream.Recv() { - if nseResp.NetworkServiceEndpoint.Name != nseQuery.NetworkServiceEndpoint.Name { - continue - } - if nseResp.Deleted { - break - } - - entry.Update(nseResp.NetworkServiceEndpoint) + list = registry.ReadNetworkServiceEndpointList(streamClient) + for _, item := range list { + q.cache.Add(item.GetName(), []*registry.NetworkServiceEndpoint{item.Clone()}) } - }() + } + var resultStreamChannel = make(chan *registry.NetworkServiceEndpointResponse, len(list)) + for _, item := range list { + resultStreamChannel <- ®istry.NetworkServiceEndpointResponse{NetworkServiceEndpoint: item} + } + close(resultStreamChannel) + return streamchannel.NewNetworkServiceEndpointFindClient(ctx, resultStreamChannel), nil } func (q *queryCacheNSEClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { - return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) + resp, err := next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) + if err == nil { + q.cache.Remove(in.GetName()) + } + return resp, err } diff --git a/pkg/registry/common/querycache/nse_client_test.go b/pkg/registry/common/querycache/nse_client_test.go index 516a911ba..3c1787726 100644 --- a/pkg/registry/common/querycache/nse_client_test.go +++ b/pkg/registry/common/querycache/nse_client_test.go @@ -1,5 +1,7 @@ // Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2024 Cisco and/or its affiliates. +// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,12 +40,12 @@ import ( ) const ( - expireTimeout = time.Minute + expireTimeout = time.Second name = "nse" url1 = "tcp://1.1.1.1" url2 = "tcp://2.2.2.2" - testWait = 100 * time.Millisecond - testTick = testWait / 100 + testWait = time.Second + testTick = time.Second / 15 ) func testNSEQuery(nseName string) *registry.NetworkServiceEndpointQuery { @@ -53,7 +55,6 @@ func testNSEQuery(nseName string) *registry.NetworkServiceEndpointQuery { }, } } - func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) @@ -64,7 +65,7 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { failureClient := new(failureNSEClient) c := next.NewNetworkServiceEndpointRegistryClient( - querycache.NewClient(ctx, querycache.WithExpireTimeout(expireTimeout)), + querycache.NewNetworkServiceEndpointRegistryClient(ctx), failureClient, adapters.NetworkServiceEndpointServerToClient(mem), ) @@ -75,9 +76,6 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { }) require.NoError(t, err) - // Goroutines should be cleaned up on NSE unregister - t.Cleanup(func() { goleak.VerifyNone(t) }) - // 1. Find from memory atomic.StoreInt32(&failureClient.shouldFail, 0) @@ -86,28 +84,24 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { nseResp, err := stream.Recv() require.NoError(t, err) - require.Equal(t, name, nseResp.NetworkServiceEndpoint.Name) require.Equal(t, url1, nseResp.NetworkServiceEndpoint.Url) // 2. Find from cache atomic.StoreInt32(&failureClient.shouldFail, 1) - require.Eventually(t, func() bool { - if stream, err = c.Find(ctx, testNSEQuery(name)); err != nil { - return false - } - if nseResp, err = stream.Recv(); err != nil { - return false - } - return name == nseResp.NetworkServiceEndpoint.Name && url1 == nseResp.NetworkServiceEndpoint.Url - }, testWait, testTick) + stream, err = c.Find(ctx, testNSEQuery(name)) + require.NoError(t, err) + nseResp, err = stream.Recv() + require.NoError(t, err) + require.Equal(t, name, nseResp.NetworkServiceEndpoint.Name) + require.Equal(t, url1, nseResp.NetworkServiceEndpoint.Url) // 3. Update NSE in memory reg.Url = url2 - reg, err = mem.Register(ctx, reg) require.NoError(t, err) + atomic.StoreInt32(&failureClient.shouldFail, 0) require.Eventually(t, func() bool { if stream, err = c.Find(ctx, testNSEQuery(name)); err != nil { @@ -124,8 +118,11 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { require.NoError(t, err) require.Eventually(t, func() bool { - _, err = c.Find(ctx, testNSEQuery(name)) - return err != nil + s, err := c.Find(ctx, testNSEQuery(name)) + if err != nil { + return false + } + return len(registry.ReadNetworkServiceEndpointList(s)) == 0 }, testWait, testTick) } @@ -142,7 +139,7 @@ func Test_QueryCacheClient_ShouldCleanUpOnTimeout(t *testing.T) { failureClient := new(failureNSEClient) c := next.NewNetworkServiceEndpointRegistryClient( - querycache.NewClient(ctx, querycache.WithExpireTimeout(expireTimeout)), + querycache.NewNetworkServiceEndpointRegistryClient(ctx), failureClient, adapters.NetworkServiceEndpointServerToClient(mem), ) @@ -184,7 +181,7 @@ func Test_QueryCacheClient_ShouldCleanUpOnTimeout(t *testing.T) { } // 4. Wait for the expire to happen - clockMock.Add(expireTimeout) + time.Sleep(expireTimeout) _, err = c.Find(ctx, testNSEQuery(name)) require.Errorf(t, err, "find error") diff --git a/pkg/registry/common/querycache/option.go b/pkg/registry/common/querycache/option.go deleted file mode 100644 index f70220737..000000000 --- a/pkg/registry/common/querycache/option.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package querycache - -import "time" - -// Option is an option for cache -type Option func(c *cache) - -// WithExpireTimeout sets cache expire timeout -func WithExpireTimeout(expireTimeout time.Duration) Option { - return func(c *cache) { - c.expireTimeout = expireTimeout - } -}