diff --git a/pkg/streaming/walimpls/impls/kafka/builder.go b/pkg/streaming/walimpls/impls/kafka/builder.go new file mode 100644 index 0000000000000..256c52686907e --- /dev/null +++ b/pkg/streaming/walimpls/impls/kafka/builder.go @@ -0,0 +1,108 @@ +package kafka + +import ( + "github.com/confluentinc/confluent-kafka-go/kafka" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/walimpls" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/registry" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +const ( + walName = "kafka" +) + +func init() { + // register the builder to the wal registry. + registry.RegisterBuilder(&builderImpl{}) + // register the unmarshaler to the message registry. + message.RegisterMessageIDUnmsarshaler(walName, UnmarshalMessageID) +} + +// builderImpl is the builder for pulsar wal. +type builderImpl struct{} + +// Name returns the name of the wal. +func (b *builderImpl) Name() string { + return walName +} + +// Build build a wal instance. +func (b *builderImpl) Build() (walimpls.OpenerImpls, error) { + producerConfig, consumerConfig := b.getProducerAndConsumerConfig() + + p, err := kafka.NewProducer(&producerConfig) + if err != nil { + return nil, err + } + return &openerImpl{ + p: p, + consumerConfig: consumerConfig, + }, nil +} + +// getProducerAndConsumerConfig returns the producer and consumer config. +func (b *builderImpl) getProducerAndConsumerConfig() (producerConfig kafka.ConfigMap, consumerConfig kafka.ConfigMap) { + config := ¶mtable.Get().KafkaCfg + producerConfig = getBasicConfig(config) + consumerConfig = cloneKafkaConfig(producerConfig) + + producerConfig.SetKey("message.max.bytes", 10485760) + producerConfig.SetKey("compression.codec", "zstd") + // we want to ensure tt send out as soon as possible + producerConfig.SetKey("linger.ms", 5) + for k, v := range config.ProducerExtraConfig.GetValue() { + producerConfig.SetKey(k, v) + } + + consumerConfig.SetKey("allow.auto.create.topics", true) + for k, v := range config.ConsumerExtraConfig.GetValue() { + consumerConfig.SetKey(k, v) + } + + return producerConfig, consumerConfig +} + +// getBasicConfig returns the basic kafka config. +func getBasicConfig(config *paramtable.KafkaConfig) kafka.ConfigMap { + basicConfig := kafka.ConfigMap{ + "bootstrap.servers": config.Address.GetValue(), + "api.version.request": true, + "reconnect.backoff.ms": 20, + "reconnect.backoff.max.ms": 5000, + } + + if (config.SaslUsername.GetValue() == "" && config.SaslPassword.GetValue() != "") || + (config.SaslUsername.GetValue() != "" && config.SaslPassword.GetValue() == "") { + panic("enable security mode need config username and password at the same time!") + } + + if config.SecurityProtocol.GetValue() != "" { + basicConfig.SetKey("security.protocol", config.SecurityProtocol.GetValue()) + } + + if config.SaslUsername.GetValue() != "" && config.SaslPassword.GetValue() != "" { + basicConfig.SetKey("sasl.mechanisms", config.SaslMechanisms.GetValue()) + basicConfig.SetKey("sasl.username", config.SaslUsername.GetValue()) + basicConfig.SetKey("sasl.password", config.SaslPassword.GetValue()) + } + + if config.KafkaUseSSL.GetAsBool() { + basicConfig.SetKey("ssl.certificate.location", config.KafkaTLSCert.GetValue()) + basicConfig.SetKey("ssl.key.location", config.KafkaTLSKey.GetValue()) + basicConfig.SetKey("ssl.ca.location", config.KafkaTLSCACert.GetValue()) + if config.KafkaTLSKeyPassword.GetValue() != "" { + basicConfig.SetKey("ssl.key.password", config.KafkaTLSKeyPassword.GetValue()) + } + } + return basicConfig +} + +// cloneKafkaConfig clones a kafka config. +func cloneKafkaConfig(config kafka.ConfigMap) kafka.ConfigMap { + newConfig := make(kafka.ConfigMap) + for k, v := range config { + newConfig[k] = v + } + return newConfig +} diff --git a/pkg/streaming/walimpls/impls/kafka/kafka_test.go b/pkg/streaming/walimpls/impls/kafka/kafka_test.go new file mode 100644 index 0000000000000..ce2c32d18d13c --- /dev/null +++ b/pkg/streaming/walimpls/impls/kafka/kafka_test.go @@ -0,0 +1,53 @@ +package kafka + +import ( + "testing" + + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/walimpls" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/registry" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/zeebo/assert" +) + +func TestMain(m *testing.M) { + paramtable.Init() + m.Run() +} + +func TestRegistry(t *testing.T) { + registeredB := registry.MustGetBuilder(walName) + assert.NotNil(t, registeredB) + assert.Equal(t, walName, registeredB.Name()) + + id, err := message.UnmarshalMessageID(walName, + kafkaID(123).Marshal()) + assert.NoError(t, err) + assert.True(t, id.EQ(kafkaID(123))) +} + +func TestKafka(t *testing.T) { + walimpls.NewWALImplsTestFramework(t, 1000, &builderImpl{}).Run() +} + +func TestGetBasicConfig(t *testing.T) { + config := ¶mtable.Get().KafkaCfg + oldSecurityProtocol := config.SecurityProtocol.SwapTempValue("test") + oldSaslUsername := config.SaslUsername.SwapTempValue("test") + oldSaslPassword := config.SaslPassword.SwapTempValue("test") + oldkafkaUseSSL := config.KafkaUseSSL.SwapTempValue("true") + oldKafkaTLSKeyPassword := config.KafkaTLSKeyPassword.SwapTempValue("test") + defer func() { + config.SecurityProtocol.SwapTempValue(oldSecurityProtocol) + config.SaslUsername.SwapTempValue(oldSaslUsername) + config.SaslPassword.SwapTempValue(oldSaslPassword) + config.KafkaUseSSL.SwapTempValue(oldkafkaUseSSL) + config.KafkaTLSKeyPassword.SwapTempValue(oldKafkaTLSKeyPassword) + }() + basicConfig := getBasicConfig(config) + + assert.NotNil(t, basicConfig["ssl.key.password"]) + assert.NotNil(t, basicConfig["ssl.certificate.location"]) + assert.NotNil(t, basicConfig["sasl.username"]) + assert.NotNil(t, basicConfig["security.protocol"]) +} diff --git a/pkg/streaming/walimpls/impls/kafka/message_id.go b/pkg/streaming/walimpls/impls/kafka/message_id.go new file mode 100644 index 0000000000000..6ce14f19522ac --- /dev/null +++ b/pkg/streaming/walimpls/impls/kafka/message_id.go @@ -0,0 +1,63 @@ +package kafka + +import ( + "strconv" + + "github.com/confluentinc/confluent-kafka-go/kafka" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/pkg/errors" +) + +func UnmarshalMessageID(data string) (message.MessageID, error) { + id, err := unmarshalMessageID(data) + if err != nil { + return nil, err + } + return id, nil +} + +func unmarshalMessageID(data string) (kafkaID, error) { + v, err := message.DecodeUint64(data) + if err != nil { + return 0, errors.Wrapf(message.ErrInvalidMessageID, "decode kafkaID fail with err: %s, id: %s", err.Error(), data) + } + return kafkaID(v), nil +} + +type kafkaID kafka.Offset + +// RmqID returns the message id for conversion +// Don't delete this function until conversion logic removed. +// TODO: remove in future. +func (id kafkaID) KafkaID() kafka.Offset { + return kafka.Offset(id) +} + +// WALName returns the name of message id related wal. +func (id kafkaID) WALName() string { + return walName +} + +// LT less than. +func (id kafkaID) LT(other message.MessageID) bool { + return id < other.(kafkaID) +} + +// LTE less than or equal to. +func (id kafkaID) LTE(other message.MessageID) bool { + return id <= other.(kafkaID) +} + +// EQ Equal to. +func (id kafkaID) EQ(other message.MessageID) bool { + return id == other.(kafkaID) +} + +// Marshal marshal the message id. +func (id kafkaID) Marshal() string { + return message.EncodeInt64(int64(id)) +} + +func (id kafkaID) String() string { + return strconv.FormatInt(int64(id), 10) +} diff --git a/pkg/streaming/walimpls/impls/kafka/message_id_test.go b/pkg/streaming/walimpls/impls/kafka/message_id_test.go new file mode 100644 index 0000000000000..ae1184b254d14 --- /dev/null +++ b/pkg/streaming/walimpls/impls/kafka/message_id_test.go @@ -0,0 +1,31 @@ +package kafka + +import ( + "testing" + + "github.com/confluentinc/confluent-kafka-go/kafka" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/stretchr/testify/assert" +) + +func TestMessageID(t *testing.T) { + assert.Equal(t, kafka.Offset(1), message.MessageID(kafkaID(1)).(interface{ KafkaID() kafka.Offset }).KafkaID()) + + assert.Equal(t, walName, kafkaID(1).WALName()) + + assert.True(t, kafkaID(1).LT(kafkaID(2))) + assert.True(t, kafkaID(1).EQ(kafkaID(1))) + assert.True(t, kafkaID(1).LTE(kafkaID(1))) + assert.True(t, kafkaID(1).LTE(kafkaID(2))) + assert.False(t, kafkaID(2).LT(kafkaID(1))) + assert.False(t, kafkaID(2).EQ(kafkaID(1))) + assert.False(t, kafkaID(2).LTE(kafkaID(1))) + assert.True(t, kafkaID(2).LTE(kafkaID(2))) + + msgID, err := UnmarshalMessageID(kafkaID(1).Marshal()) + assert.NoError(t, err) + assert.Equal(t, kafkaID(1), msgID) + + _, err = UnmarshalMessageID(string([]byte{0x01, 0x02, 0x03, 0x04})) + assert.Error(t, err) +} diff --git a/pkg/streaming/walimpls/impls/kafka/opener.go b/pkg/streaming/walimpls/impls/kafka/opener.go new file mode 100644 index 0000000000000..6bb9a2954761a --- /dev/null +++ b/pkg/streaming/walimpls/impls/kafka/opener.go @@ -0,0 +1,28 @@ +package kafka + +import ( + "context" + + "github.com/confluentinc/confluent-kafka-go/kafka" + "github.com/milvus-io/milvus/pkg/streaming/walimpls" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/helper" +) + +var _ walimpls.OpenerImpls = (*openerImpl)(nil) + +type openerImpl struct { + p *kafka.Producer + consumerConfig kafka.ConfigMap +} + +func (o *openerImpl) Open(ctx context.Context, opt *walimpls.OpenOption) (walimpls.WALImpls, error) { + return &walImpl{ + WALHelper: helper.NewWALHelper(opt), + p: o.p, + consumerConfig: o.consumerConfig, + }, nil +} + +func (o *openerImpl) Close() { + o.p.Close() +} diff --git a/pkg/streaming/walimpls/impls/kafka/scanner.go b/pkg/streaming/walimpls/impls/kafka/scanner.go new file mode 100644 index 0000000000000..a55981d9dc7f8 --- /dev/null +++ b/pkg/streaming/walimpls/impls/kafka/scanner.go @@ -0,0 +1,87 @@ +package kafka + +import ( + "time" + + "github.com/confluentinc/confluent-kafka-go/kafka" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/walimpls" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/helper" +) + +var _ walimpls.ScannerImpls = (*scannerImpl)(nil) + +// newScanner creates a new scanner. +func newScanner(scannerName string, exclude *kafkaID, consumer *kafka.Consumer) *scannerImpl { + s := &scannerImpl{ + ScannerHelper: helper.NewScannerHelper(scannerName), + consumer: consumer, + msgChannel: make(chan message.ImmutableMessage, 1), + exclude: exclude, + } + go s.executeConsume() + return s +} + +// scannerImpl is the implementation of ScannerImpls for kafka. +type scannerImpl struct { + *helper.ScannerHelper + consumer *kafka.Consumer + msgChannel chan message.ImmutableMessage + exclude *kafkaID +} + +// Chan returns the channel of message. +func (s *scannerImpl) Chan() <-chan message.ImmutableMessage { + return s.msgChannel +} + +// Close the scanner, release the underlying resources. +// Return the error same with `Error` +func (s *scannerImpl) Close() error { + s.consumer.Unassign() + err := s.ScannerHelper.Close() + s.consumer.Close() + return err +} + +func (s *scannerImpl) executeConsume() { + defer close(s.msgChannel) + for { + msg, err := s.consumer.ReadMessage(200 * time.Millisecond) + if err != nil { + if s.Context().Err() != nil { + // context canceled, means the the scanner is closed. + s.Finish(nil) + return + } + if c, ok := err.(kafka.Error); ok && c.Code() == kafka.ErrTimedOut { + continue + } + s.Finish(err) + return + } + messageID := kafkaID(msg.TopicPartition.Offset) + if s.exclude != nil && messageID.EQ(*s.exclude) { + // Skip the message that is exclude for StartAfter semantics. + continue + } + + properties := make(map[string]string, len(msg.Headers)) + for _, header := range msg.Headers { + properties[header.Key] = string(header.Value) + } + + newImmutableMessage := message.NewImmutableMesasge( + messageID, + msg.Value, + properties, + ) + select { + case <-s.Context().Done(): + s.Finish(nil) + return + case s.msgChannel <- newImmutableMessage: + } + } +} diff --git a/pkg/streaming/walimpls/impls/kafka/wal.go b/pkg/streaming/walimpls/impls/kafka/wal.go new file mode 100644 index 0000000000000..af83a9e710036 --- /dev/null +++ b/pkg/streaming/walimpls/impls/kafka/wal.go @@ -0,0 +1,106 @@ +package kafka + +import ( + "context" + + "github.com/cockroachdb/errors" + "github.com/confluentinc/confluent-kafka-go/kafka" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/walimpls" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/helper" +) + +var _ walimpls.WALImpls = (*walImpl)(nil) + +type walImpl struct { + *helper.WALHelper + p *kafka.Producer + consumerConfig kafka.ConfigMap +} + +func (w *walImpl) WALName() string { + return walName +} + +func (w *walImpl) Append(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) { + properties := msg.Properties().ToRawMap() + headers := make([]kafka.Header, 0, len(properties)) + for key, value := range properties { + header := kafka.Header{Key: key, Value: []byte(value)} + headers = append(headers, header) + } + ch := make(chan kafka.Event, 1) + topic := w.Channel().Name + + if err := w.p.Produce(&kafka.Message{ + TopicPartition: kafka.TopicPartition{Topic: &topic, Partition: 0}, + Value: msg.Payload(), + Headers: headers, + }, ch); err != nil { + return nil, err + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case event := <-ch: + relatedMsg := event.(*kafka.Message) + if relatedMsg.TopicPartition.Error != nil { + return nil, relatedMsg.TopicPartition.Error + } + return kafkaID(relatedMsg.TopicPartition.Offset), nil + } +} + +func (w *walImpl) Read(ctx context.Context, opt walimpls.ReadOption) (s walimpls.ScannerImpls, err error) { + // The scanner is stateless, so we can create a scanner with an anonymous consumer. + // and there's no commit opeartions. + consumerConfig := cloneKafkaConfig(w.consumerConfig) + consumerConfig.SetKey("group.id", opt.Name) + switch opt.DeliverPolicy.GetPolicy().(type) { + } + c, err := kafka.NewConsumer(&consumerConfig) + if err != nil { + return nil, errors.Wrap(err, "failed to create kafka consumer") + } + + topic := w.Channel().Name + seekPosition := kafka.TopicPartition{ + Topic: &topic, + Partition: 0, + } + var exclude *kafkaID + switch t := opt.DeliverPolicy.GetPolicy().(type) { + case *streamingpb.DeliverPolicy_All: + seekPosition.Offset = kafka.OffsetBeginning + case *streamingpb.DeliverPolicy_Latest: + seekPosition.Offset = kafka.OffsetEnd + case *streamingpb.DeliverPolicy_StartFrom: + id, err := unmarshalMessageID(t.StartFrom.GetId()) + if err != nil { + return nil, err + } + seekPosition.Offset = kafka.Offset(id) + case *streamingpb.DeliverPolicy_StartAfter: + id, err := unmarshalMessageID(t.StartAfter.GetId()) + if err != nil { + return nil, err + } + seekPosition.Offset = kafka.Offset(id) + exclude = &id + default: + panic("unknown deliver policy") + } + + if err := c.Assign([]kafka.TopicPartition{seekPosition}); err != nil { + return nil, errors.Wrap(err, "failed to assign kafka consumer") + } + return newScanner(opt.Name, exclude, c), nil +} + +func (w *walImpl) Close() { + // The lifetime control of the producer is delegated to the wal adaptor. + // So we just make resource cleanup here. + // But kafka producer is not topic level, so we don't close it here. +} diff --git a/pkg/streaming/walimpls/impls/pulsar/message_id_test.go b/pkg/streaming/walimpls/impls/pulsar/message_id_test.go index 5635c939c9666..4475f5d9e22cc 100644 --- a/pkg/streaming/walimpls/impls/pulsar/message_id_test.go +++ b/pkg/streaming/walimpls/impls/pulsar/message_id_test.go @@ -4,11 +4,19 @@ import ( "testing" "github.com/apache/pulsar-client-go/pulsar" + "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" ) func TestMessageID(t *testing.T) { + pid := message.MessageID(newMessageIDOfPulsar(1, 2, 3)).(interface{ PulsarID() pulsar.MessageID }).PulsarID() + assert.Equal(t, walName, newMessageIDOfPulsar(1, 2, 3).WALName()) + + assert.Equal(t, int64(1), pid.LedgerID()) + assert.Equal(t, int64(2), pid.EntryID()) + assert.Equal(t, int32(3), pid.BatchIdx()) + ids := []pulsarID{ newMessageIDOfPulsar(0, 0, 0), newMessageIDOfPulsar(0, 0, 1), diff --git a/pkg/streaming/walimpls/impls/rmq/message_id_test.go b/pkg/streaming/walimpls/impls/rmq/message_id_test.go index b757e57ab67ea..54cbfbd5ceb04 100644 --- a/pkg/streaming/walimpls/impls/rmq/message_id_test.go +++ b/pkg/streaming/walimpls/impls/rmq/message_id_test.go @@ -3,10 +3,14 @@ package rmq import ( "testing" + "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/stretchr/testify/assert" ) func TestMessageID(t *testing.T) { + assert.Equal(t, int64(1), message.MessageID(rmqID(1)).(interface{ RmqID() int64 }).RmqID()) + assert.Equal(t, walName, rmqID(1).WALName()) + assert.True(t, rmqID(1).LT(rmqID(2))) assert.True(t, rmqID(1).EQ(rmqID(1))) assert.True(t, rmqID(1).LTE(rmqID(1))) diff --git a/pkg/util/paramtable/param_item.go b/pkg/util/paramtable/param_item.go index b8718b65ba848..1f03e098f05a4 100644 --- a/pkg/util/paramtable/param_item.go +++ b/pkg/util/paramtable/param_item.go @@ -101,12 +101,18 @@ func (pi *ParamItem) getWithRaw() (result, raw string, err error) { // SetTempValue set the value for this ParamItem, // Once value set, ParamItem will use the value instead of underlying config manager. // Usage: should only use for unittest, swap empty string will remove the value. -func (pi *ParamItem) SwapTempValue(s string) *string { +func (pi *ParamItem) SwapTempValue(s string) string { if s == "" { - return pi.tempValue.Swap(nil) + if old := pi.tempValue.Swap(nil); old != nil { + return *old + } + return "" } pi.manager.EvictCachedValue(pi.Key) - return pi.tempValue.Swap(&s) + if old := pi.tempValue.Swap(&s); old != nil { + return *old + } + return "" } func (pi *ParamItem) GetValue() string {