Skip to content

Commit

Permalink
support ttl check and update ttl
Browse files Browse the repository at this point in the history
  • Loading branch information
nnothing1 committed Dec 2, 2024
1 parent 232fbcb commit f74fe5e
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 8 deletions.
55 changes: 51 additions & 4 deletions consul_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
package consul

import (
"context"
"errors"
"fmt"
"strings"
"time"

"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/pkg/registry"
"github.com/hashicorp/consul/api"
)
Expand All @@ -30,8 +33,9 @@ type options struct {
}

type consulRegistry struct {
consulClient *api.Client
opts options
consulClient *api.Client
opts options
cancelUpdateTTL context.CancelFunc
}

const kvJoinChar = ":"
Expand Down Expand Up @@ -135,12 +139,23 @@ func (c *consulRegistry) Register(info *registry.Info) error {
Check: c.opts.check,
}

if c.opts.check != nil {
if c.opts.check != nil && c.opts.check.TTL == "" {
c.opts.check.TCP = fmt.Sprintf("%s:%d", host, port)
svcInfo.Check = c.opts.check
}

return c.consulClient.Agent().ServiceRegister(svcInfo)
if err := c.consulClient.Agent().ServiceRegister(svcInfo); err != nil {
return err
}

if c.opts.check.TTL != "" {
if ttl, err := time.ParseDuration(c.opts.check.TTL); err != nil {
return err
} else {
return c.startTTLHeartbeat(ttl)
}
}
return nil
}

// Deregister deregister a service from consul.
Expand All @@ -149,9 +164,41 @@ func (c *consulRegistry) Deregister(info *registry.Info) error {
if err != nil {
return err
}

if c.cancelUpdateTTL != nil {
defer c.cancelUpdateTTL()
}
return c.consulClient.Agent().ServiceDeregister(svcID)
}

// startTTLHeartbeat start a goroutine to periodically update TTL.
func (c *consulRegistry) startTTLHeartbeat(ttl time.Duration) error {
if ttl <= 1*time.Second {
return errors.New("consul check ttl must be greater than one second")
}

ctx, cancel := context.WithCancel(context.Background())
c.cancelUpdateTTL = cancel
go func() {
if err := c.consulClient.Agent().UpdateTTL(c.opts.check.CheckID, "online", api.HealthPassing); err != nil {
klog.Errorf("update ttl to consul failed, err=%v", err)
}
ticker := time.NewTicker(ttl - 1*time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := c.consulClient.Agent().UpdateTTL(c.opts.check.CheckID, "online", api.HealthPassing); err != nil {
klog.Errorf("update ttl to consul failed, err=%v", err)
}
case <-ctx.Done():
return
}
}
}()
return nil
}

func validateRegistryInfo(info *registry.Info) error {
if info.ServiceName == "" {
return errors.New("missing service name in consul register")
Expand Down
69 changes: 65 additions & 4 deletions consul_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ const (
)

var (
consulClient *consulapi.Client
cRegistry registry.Registry
cResolver discovery.Resolver
localIpAddr string
consulClient *consulapi.Client
cRegistry registry.Registry
cRegistryWithTTL registry.Registry
cResolver discovery.Resolver
localIpAddr string
)

func init() {
Expand All @@ -60,6 +61,19 @@ func init() {
}
cRegistry = r

r, err = NewConsulRegister(consulAddr, WithCheck(
&consulapi.AgentServiceCheck{
CheckID: "TEST-MY-CHECK-ID1",
TTL: "5s",
Timeout: "5s",
DeregisterCriticalServiceAfter: "1m",
},
))
if err != nil {
return
}
cRegistryWithTTL = r

resolver, err := NewConsulResolver(consulAddr)
if err != nil {
return
Expand Down Expand Up @@ -169,6 +183,53 @@ func TestRegister(t *testing.T) {
}
}

// TestRegisterWithTTLCheck tests the Register function with ttl check.
func TestRegisterWithTTLCheck(t *testing.T) {
var (
testSvcName = strconv.Itoa(int(time.Now().Unix())) + ".svc.local"
testSvcPort = 8081
testSvcWeight = 777
tagMap = map[string]string{
"k1": "vv1",
"k2": "vv2",
"k3": "vv3",
}
tagList = []string{"k1:vv1", "k2:vv2", "k3:vv3"}
)

// listen on the port, and wait for the health check to connect
addr := fmt.Sprintf("%s:%d", localIpAddr, testSvcPort)
lis, err := net.Listen("tcp", addr)
if err != nil {
t.Errorf("listen tcp %s failed!", addr)
t.Fail()
}
defer lis.Close()

testSvcAddr, _ := net.ResolveTCPAddr("tcp", addr)
info := &registry.Info{
ServiceName: testSvcName,
Weight: testSvcWeight,
Addr: testSvcAddr,
Tags: tagMap,
}
err = cRegistryWithTTL.Register(info)
assert.Nil(t, err)
// wait for health check passing
time.Sleep(time.Second * 6)

list, _, err := consulClient.Health().Service(testSvcName, "", true, nil)
assert.Nil(t, err)
if assert.Equal(t, 1, len(list)) {
ss := list[0]
gotSvc := ss.Service
assert.Equal(t, testSvcName, gotSvc.Service)
assert.Equal(t, testSvcAddr.String(), fmt.Sprintf("%s:%d", gotSvc.Address, gotSvc.Port))
assert.Equal(t, testSvcWeight, gotSvc.Weights.Passing)
assert.Equal(t, tagList, gotSvc.Tags)
}
}

// TestConsulDiscovery tests the ConsulDiscovery function.
func TestConsulDiscovery(t *testing.T) {
var (
Expand Down

0 comments on commit f74fe5e

Please sign in to comment.