Skip to content

Commit

Permalink
enhance: implement kafka for wal
Browse files Browse the repository at this point in the history
Signed-off-by: chyezh <[email protected]>
  • Loading branch information
chyezh committed Dec 20, 2024
1 parent 90de37e commit 6a5ef9c
Show file tree
Hide file tree
Showing 10 changed files with 497 additions and 3 deletions.
108 changes: 108 additions & 0 deletions pkg/streaming/walimpls/impls/kafka/builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package kafka

import (
"github.com/confluentinc/confluent-kafka-go/kafka"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/milvus-io/milvus/pkg/streaming/walimpls"
"github.com/milvus-io/milvus/pkg/streaming/walimpls/registry"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)

const (
// walName is the wal name of this implementation; it is returned by
// builderImpl.Name and used as the key when registering the message ID unmarshaler.
walName = "kafka"
)

// init registers the kafka builder and the kafka message ID unmarshaler
// when the package is loaded.
func init() {
// register the builder to the wal registry.
registry.RegisterBuilder(&builderImpl{})
// register the unmarshaler to the message registry.
message.RegisterMessageIDUnmsarshaler(walName, UnmarshalMessageID)
}

// builderImpl is the builder for kafka wal.
// It carries no state; all settings are read from paramtable when Build is called.
type builderImpl struct{}

// Name returns the name of the wal, which is the walName constant.
func (b *builderImpl) Name() string {
return walName
}

// Build creates a kafka producer from the current configuration and returns an
// opener that holds this producer together with the consumer configuration.
// If the producer cannot be created, the error from kafka.NewProducer is
// returned unchanged.
func (b *builderImpl) Build() (walimpls.OpenerImpls, error) {
	producerCfg, consumerCfg := b.getProducerAndConsumerConfig()

	producer, err := kafka.NewProducer(&producerCfg)
	if err != nil {
		return nil, err
	}

	opener := &openerImpl{
		p:              producer,
		consumerConfig: consumerCfg,
	}
	return opener, nil
}

// getProducerAndConsumerConfig builds the kafka client configuration for both sides.
// Both maps start from the same basic config; the consumer works on its own copy,
// so keys set on one side never show up on the other. On each side the built-in
// defaults are written first and the user-provided extra config afterwards, so
// the extra config wins on conflicting keys.
func (b *builderImpl) getProducerAndConsumerConfig() (producerConfig kafka.ConfigMap, consumerConfig kafka.ConfigMap) {
	kafkaCfg := &paramtable.Get().KafkaCfg

	producerConfig = getBasicConfig(kafkaCfg)
	consumerConfig = cloneKafkaConfig(producerConfig)

	// Built-in producer defaults, applied in declaration order.
	producerDefaults := []struct {
		key   string
		value interface{}
	}{
		{"message.max.bytes", 10485760},
		{"compression.codec", "zstd"},
		// keep the linger short so that tt (presumably time tick) messages
		// are sent out as soon as possible
		{"linger.ms", 5},
	}
	for _, entry := range producerDefaults {
		producerConfig.SetKey(entry.key, entry.value)
	}
	for key, value := range kafkaCfg.ProducerExtraConfig.GetValue() {
		producerConfig.SetKey(key, value)
	}

	// Built-in consumer defaults, followed by the user overrides.
	consumerConfig.SetKey("allow.auto.create.topics", true)
	for key, value := range kafkaCfg.ConsumerExtraConfig.GetValue() {
		consumerConfig.SetKey(key, value)
	}

	return producerConfig, consumerConfig
}

// getBasicConfig assembles the connection and security settings used as the
// starting point for the kafka client configuration.
//
// The result always contains the bootstrap servers and the reconnect backoff
// settings. The security protocol, the SASL credentials and the SSL material are
// added only when the corresponding configuration items are set.
//
// It panics when exactly one of SASL username and password is configured.
func getBasicConfig(config *paramtable.KafkaConfig) kafka.ConfigMap {
	basicConfig := kafka.ConfigMap{
		"bootstrap.servers":        config.Address.GetValue(),
		"api.version.request":      true,
		"reconnect.backoff.ms":     20,
		"reconnect.backoff.max.ms": 5000,
	}

	// SASL credentials are all-or-nothing: either both are set or neither is.
	username := config.SaslUsername.GetValue()
	password := config.SaslPassword.GetValue()
	if (username == "") != (password == "") {
		panic("enable security mode need config username and password at the same time!")
	}

	if protocol := config.SecurityProtocol.GetValue(); protocol != "" {
		basicConfig.SetKey("security.protocol", protocol)
	}

	if username != "" && password != "" {
		basicConfig.SetKey("sasl.mechanisms", config.SaslMechanisms.GetValue())
		basicConfig.SetKey("sasl.username", username)
		basicConfig.SetKey("sasl.password", password)
	}

	if !config.KafkaUseSSL.GetAsBool() {
		return basicConfig
	}

	// SSL is enabled: certificate, key and CA locations are always written,
	// the key password only when one is configured.
	basicConfig.SetKey("ssl.certificate.location", config.KafkaTLSCert.GetValue())
	basicConfig.SetKey("ssl.key.location", config.KafkaTLSKey.GetValue())
	basicConfig.SetKey("ssl.ca.location", config.KafkaTLSCACert.GetValue())
	if keyPassword := config.KafkaTLSKeyPassword.GetValue(); keyPassword != "" {
		basicConfig.SetKey("ssl.key.password", keyPassword)
	}
	return basicConfig
}

// cloneKafkaConfig returns a new ConfigMap holding the same key/value pairs as config.
// The copy is shallow: values are copied as-is, not duplicated recursively.
func cloneKafkaConfig(config kafka.ConfigMap) kafka.ConfigMap {
	cloned := make(kafka.ConfigMap, len(config))
	for key := range config {
		cloned[key] = config[key]
	}
	return cloned
}
53 changes: 53 additions & 0 deletions pkg/streaming/walimpls/impls/kafka/kafka_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package kafka

import (
"testing"

"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/milvus-io/milvus/pkg/streaming/walimpls"
"github.com/milvus-io/milvus/pkg/streaming/walimpls/registry"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/zeebo/assert"
)

// TestMain initializes the paramtable before running the tests of this package,
// because the tests read the kafka settings through paramtable.Get().
func TestMain(m *testing.M) {
paramtable.Init()
m.Run()
}

// TestRegistry checks that the package registered itself under walName:
// the builder can be looked up by name, and a marshaled kafkaID can be decoded
// again through the message-level unmarshaler.
func TestRegistry(t *testing.T) {
	builder := registry.MustGetBuilder(walName)
	assert.NotNil(t, builder)
	assert.Equal(t, walName, builder.Name())

	want := kafkaID(123)
	got, err := message.UnmarshalMessageID(walName, want.Marshal())
	assert.NoError(t, err)
	assert.True(t, got.EQ(want))
}

// TestKafka runs the shared walimpls test framework against the kafka builder.
// NOTE(review): the meaning of the 1000 argument is defined by
// NewWALImplsTestFramework (presumably a message count) — confirm there.
func TestKafka(t *testing.T) {
walimpls.NewWALImplsTestFramework(t, 1000, &builderImpl{}).Run()
}

// TestGetBasicConfig checks that getBasicConfig emits the optional security,
// SASL and SSL keys once the corresponding configuration items are set.
func TestGetBasicConfig(t *testing.T) {
	kafkaCfg := &paramtable.Get().KafkaCfg

	// Override the relevant items for the duration of the test and remember
	// what was there before, so it can be put back on exit.
	prevProtocol := kafkaCfg.SecurityProtocol.SwapTempValue("test")
	prevUsername := kafkaCfg.SaslUsername.SwapTempValue("test")
	prevPassword := kafkaCfg.SaslPassword.SwapTempValue("test")
	prevUseSSL := kafkaCfg.KafkaUseSSL.SwapTempValue("true")
	prevKeyPassword := kafkaCfg.KafkaTLSKeyPassword.SwapTempValue("test")
	defer func() {
		kafkaCfg.SecurityProtocol.SwapTempValue(prevProtocol)
		kafkaCfg.SaslUsername.SwapTempValue(prevUsername)
		kafkaCfg.SaslPassword.SwapTempValue(prevPassword)
		kafkaCfg.KafkaUseSSL.SwapTempValue(prevUseSSL)
		kafkaCfg.KafkaTLSKeyPassword.SwapTempValue(prevKeyPassword)
	}()

	got := getBasicConfig(kafkaCfg)

	expectedKeys := []string{
		"ssl.key.password",
		"ssl.certificate.location",
		"sasl.username",
		"security.protocol",
	}
	for _, key := range expectedKeys {
		assert.NotNil(t, got[key])
	}
}
63 changes: 63 additions & 0 deletions pkg/streaming/walimpls/impls/kafka/message_id.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package kafka

import (
"strconv"

"github.com/confluentinc/confluent-kafka-go/kafka"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/pkg/errors"
)

// UnmarshalMessageID decodes a marshaled kafka message id.
// On failure it returns a nil MessageID together with the decoding error.
func UnmarshalMessageID(data string) (message.MessageID, error) {
	var (
		decoded kafkaID
		err     error
	)
	if decoded, err = unmarshalMessageID(data); err != nil {
		// Return an untyped nil so the interface result is really nil on error.
		return nil, err
	}
	return decoded, nil
}

// unmarshalMessageID decodes data into a kafkaID.
// A decoding failure is reported as message.ErrInvalidMessageID, wrapped with
// the underlying error text and the offending input.
func unmarshalMessageID(data string) (kafkaID, error) {
	offset, decodeErr := message.DecodeUint64(data)
	if decodeErr != nil {
		return 0, errors.Wrapf(message.ErrInvalidMessageID, "decode kafkaID fail with err: %s, id: %s", decodeErr.Error(), data)
	}
	return kafkaID(offset), nil
}

// kafkaID is the message id of the kafka wal; its underlying type is kafka.Offset.
type kafkaID kafka.Offset

// KafkaID returns the message id as a kafka.Offset for conversion.
// Don't delete this function until conversion logic removed.
// TODO: remove in future.
func (id kafkaID) KafkaID() kafka.Offset {
return kafka.Offset(id)
}

// WALName returns the name of message id related wal, i.e. the walName constant.
func (id kafkaID) WALName() string {
return walName
}

// LT reports whether id is strictly less than other.
// It panics if other is not a kafkaID.
func (id kafkaID) LT(other message.MessageID) bool {
	rhs := other.(kafkaID)
	return id < rhs
}

// LTE reports whether id is less than or equal to other.
// It panics if other is not a kafkaID.
func (id kafkaID) LTE(other message.MessageID) bool {
	rhs := other.(kafkaID)
	return id <= rhs
}

// EQ reports whether id is equal to other.
// It panics if other is not a kafkaID.
func (id kafkaID) EQ(other message.MessageID) bool {
	rhs := other.(kafkaID)
	return id == rhs
}

// Marshal marshal the message id.
// The id is converted to int64 and encoded with message.EncodeInt64.
func (id kafkaID) Marshal() string {
return message.EncodeInt64(int64(id))
}

// String returns the id formatted as a base-10 integer.
func (id kafkaID) String() string {
return strconv.FormatInt(int64(id), 10)
}
31 changes: 31 additions & 0 deletions pkg/streaming/walimpls/impls/kafka/message_id_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package kafka

import (
"testing"

"github.com/confluentinc/confluent-kafka-go/kafka"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/stretchr/testify/assert"
)

// TestMessageID covers the kafkaID accessors, its comparison methods and the
// marshal/unmarshal round trip.
func TestMessageID(t *testing.T) {
	one, two := kafkaID(1), kafkaID(2)

	// The raw offset must stay reachable through the KafkaID accessor even when
	// the value is only held as a message.MessageID.
	type offsetAccessor interface{ KafkaID() kafka.Offset }
	var generic message.MessageID = one
	assert.Equal(t, kafka.Offset(1), generic.(offsetAccessor).KafkaID())

	assert.Equal(t, walName, one.WALName())

	// Ordering and equality.
	assert.True(t, one.LT(two))
	assert.True(t, one.EQ(one))
	assert.True(t, one.LTE(one))
	assert.True(t, one.LTE(two))
	assert.False(t, two.LT(one))
	assert.False(t, two.EQ(one))
	assert.False(t, two.LTE(one))
	assert.True(t, two.LTE(two))

	// A marshaled id decodes back to the same value.
	decoded, err := UnmarshalMessageID(one.Marshal())
	assert.NoError(t, err)
	assert.Equal(t, one, decoded)

	// Input that is not a marshaled id is expected to produce an error.
	_, err = UnmarshalMessageID(string([]byte{0x01, 0x02, 0x03, 0x04}))
	assert.Error(t, err)
}
28 changes: 28 additions & 0 deletions pkg/streaming/walimpls/impls/kafka/opener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package kafka

import (
"context"

"github.com/confluentinc/confluent-kafka-go/kafka"
"github.com/milvus-io/milvus/pkg/streaming/walimpls"
"github.com/milvus-io/milvus/pkg/streaming/walimpls/helper"
)

// Compile-time check that *openerImpl satisfies walimpls.OpenerImpls.
var _ walimpls.OpenerImpls = (*openerImpl)(nil)

// openerImpl is the kafka implementation of walimpls.OpenerImpls.
type openerImpl struct {
// p is the kafka producer; the same instance is passed to every wal returned by Open
// and is closed by Close.
p *kafka.Producer
// consumerConfig is the kafka consumer configuration passed to every wal returned by Open.
consumerConfig kafka.ConfigMap
}

// Open creates a wal instance for the given option.
// The returned wal uses the opener's producer and consumer configuration.
// ctx is not used, and the returned error is always nil.
func (o *openerImpl) Open(ctx context.Context, opt *walimpls.OpenOption) (walimpls.WALImpls, error) {
	wal := &walImpl{
		WALHelper:      helper.NewWALHelper(opt),
		p:              o.p,
		consumerConfig: o.consumerConfig,
	}
	return wal, nil
}

// Close closes the kafka producer held by the opener.
// The same producer instance is handed to every wal created by Open.
func (o *openerImpl) Close() {
o.p.Close()
}
87 changes: 87 additions & 0 deletions pkg/streaming/walimpls/impls/kafka/scanner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package kafka

import (
"time"

"github.com/confluentinc/confluent-kafka-go/kafka"
"github.com/milvus-io/milvus/pkg/streaming/util/message"
"github.com/milvus-io/milvus/pkg/streaming/walimpls"
"github.com/milvus-io/milvus/pkg/streaming/walimpls/helper"
)

// Compile-time check that *scannerImpl satisfies walimpls.ScannerImpls.
var _ walimpls.ScannerImpls = (*scannerImpl)(nil)

// newScanner creates a scanner on top of consumer and starts its consume loop
// in a background goroutine.
// When exclude is non-nil, the message with exactly that id is skipped.
func newScanner(scannerName string, exclude *kafkaID, consumer *kafka.Consumer) *scannerImpl {
	scanner := &scannerImpl{
		ScannerHelper: helper.NewScannerHelper(scannerName),
		consumer:      consumer,
		msgChannel:    make(chan message.ImmutableMessage, 1),
		exclude:       exclude,
	}
	go scanner.executeConsume()
	return scanner
}

// scannerImpl is the implementation of ScannerImpls for kafka.
type scannerImpl struct {
*helper.ScannerHelper
// consumer is the kafka consumer the messages are read from.
consumer *kafka.Consumer
// msgChannel carries the converted messages to the reader; it is closed
// when the consume loop exits.
msgChannel chan message.ImmutableMessage
// exclude, when non-nil, is the id of a message that must not be delivered.
exclude *kafkaID
}

// Chan returns the channel of message.
// The channel is closed when the consume loop of the scanner exits.
func (s *scannerImpl) Chan() <-chan message.ImmutableMessage {
return s.msgChannel
}

// Close the scanner, release the underlying resources.
// Return the error same with `Error`
//
// The steps run in a fixed order: the consumer is unassigned first, then the
// embedded ScannerHelper is closed, and only afterwards the kafka consumer
// itself is closed. The returned value is the one produced by
// ScannerHelper.Close; the results of Unassign and consumer.Close are not inspected.
// NOTE(review): presumably ScannerHelper.Close cancels the scanner context and
// waits for executeConsume to call Finish before the consumer is closed —
// confirm against the helper package.
func (s *scannerImpl) Close() error {
s.consumer.Unassign()
err := s.ScannerHelper.Close()
s.consumer.Close()
return err
}

// executeConsume is the consume loop of the scanner.
//
// It polls the kafka consumer, converts every received message into an
// immutable message (kafka headers become the message properties) and pushes it
// into msgChannel. The loop ends, after reporting the outcome through Finish,
// when the scanner context is done (nil) or when reading fails with anything
// other than a poll timeout (the read error). msgChannel is closed on exit.
func (s *scannerImpl) executeConsume() {
	defer close(s.msgChannel)

	// Upper bound for a single poll, so that a done context is noticed regularly.
	const pollTimeout = 200 * time.Millisecond

	for {
		kafkaMsg, readErr := s.consumer.ReadMessage(pollTimeout)
		if readErr != nil {
			kafkaErr, isKafkaErr := readErr.(kafka.Error)
			switch {
			case s.Context().Err() != nil:
				// The context is done, which means the scanner is closed.
				s.Finish(nil)
				return
			case isKafkaErr && kafkaErr.Code() == kafka.ErrTimedOut:
				// Nothing arrived within the poll timeout; poll again.
				continue
			default:
				s.Finish(readErr)
				return
			}
		}

		id := kafkaID(kafkaMsg.TopicPartition.Offset)
		if s.exclude != nil && id.EQ(*s.exclude) {
			// Skip the excluded message to get StartAfter semantics.
			continue
		}

		props := make(map[string]string, len(kafkaMsg.Headers))
		for i := range kafkaMsg.Headers {
			props[kafkaMsg.Headers[i].Key] = string(kafkaMsg.Headers[i].Value)
		}
		immutableMsg := message.NewImmutableMesasge(id, kafkaMsg.Value, props)

		// Hand the message over, but give up as soon as the context is done.
		select {
		case <-s.Context().Done():
			s.Finish(nil)
			return
		case s.msgChannel <- immutableMsg:
		}
	}
}
Loading

0 comments on commit 6a5ef9c

Please sign in to comment.