Skip to content

Commit

Permalink
fix(feature/msmsg): Message sending (encryption) done!
Browse files Browse the repository at this point in the history
make sure to use PR tulir/libsignal-protocol-go#4
  • Loading branch information
purpshell committed Jul 2, 2024
1 parent 517493f commit ae28c33
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 33 deletions.
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ var (
ErrUnknownServer = errors.New("can't send message to unknown server")
ErrRecipientADJID = errors.New("message recipient must be a user JID with no device part")
ErrServerReturnedError = errors.New("server returned error")
ErrInvalidInlineBotID = errors.New("invalid inline bot ID")
)

type DownloadHTTPError struct {
Expand Down
1 change: 1 addition & 0 deletions mdtest/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
*.json
*.png
*.jpe
vendor/
47 changes: 35 additions & 12 deletions mdtest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -1024,23 +1024,46 @@ func handleCmd(cmd string, args []string) {
}
case "sendbotmsg":
if len(args) < 1 {
log.Errorf("Usage: sendBotMsg <text>")
log.Errorf("Usage: sendBotMsg <inline jid (optional)> <text>")
return
}
text := strings.Join(args, " ")
personaID := "867051314767696$760019659443059"
msg := &waE2E.Message{
Conversation: &text,
var inlineJID types.JID
if len(args) > 1 {
jid, ok := parseJID(args[0])
if ok {
inlineJID = jid
} else {
inlineJID = types.EmptyJID
}
}

personaID := "867051314767696$760019659443059" // default meta bot personality: "Assistant"

// todo: make all of this as part of extras, and hide message secret generation
MessageContextInfo: &waE2E.MessageContextInfo{
BotMetadata: &waE2E.BotMetadata{
PersonaID: &personaID,
var resp, err = whatsmeow.SendResponse{}, error(nil)
if !inlineJID.IsEmpty() {
text := fmt.Sprintf("@%s %s", types.MetaAIJID.User, strings.Join(args[1:], " "))
msg := &waE2E.Message{
ExtendedTextMessage: &waE2E.ExtendedTextMessage{
Text: &text,
ContextInfo: &waE2E.ContextInfo{
MentionedJID: []string{types.MetaAIJID.String()},
},
},
//MessageSecret: random.Bytes(32),
},
}

resp, err = cli.SendMessage(context.Background(), inlineJID, msg, whatsmeow.SendRequestExtra{
BotPersonaID: personaID,
InlineBotJID: types.MetaAIJID,
})
} else {
text := strings.Join(args, " ")
msg := &waE2E.Message{
Conversation: &text,
}
resp, err = cli.SendMessage(context.Background(), types.MetaAIJID, msg, whatsmeow.SendRequestExtra{
BotPersonaID: personaID,
})
}
resp, err := cli.SendMessage(context.Background(), types.MetaAIJID, msg)
if err != nil {
log.Errorf("Error sending bot message: %v", err)
} else {
Expand Down
29 changes: 22 additions & 7 deletions prekeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,34 @@ func nodeToPreKeyBundle(deviceID uint32, node waBinary.Node) (*prekey.Bundle, er
}
identityKeyPub := *(*[32]byte)(identityKeyRaw)

preKey, err := nodeToPreKey(keysNode.GetChildByTag("key"))
if err != nil {
return nil, fmt.Errorf("invalid prekey in prekey response: %w", err)
preKeyNode, ok := keysNode.GetOptionalChildByTag("key")
preKey := &keys.PreKey{}
if ok {
var err error
preKey, err = nodeToPreKey(preKeyNode)
if err != nil {
return nil, fmt.Errorf("invalid prekey in prekey response: %w", err)
}
}

signedPreKey, err := nodeToPreKey(keysNode.GetChildByTag("skey"))
if err != nil {
return nil, fmt.Errorf("invalid signed prekey in prekey response: %w", err)
}

return prekey.NewBundle(registrationID, deviceID,
optional.NewOptionalUint32(preKey.KeyID), signedPreKey.KeyID,
ecc.NewDjbECPublicKey(*preKey.Pub), ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub))), nil
var bundle *prekey.Bundle
if ok {
bundle = prekey.NewBundle(registrationID, deviceID,
optional.NewOptionalUint32(preKey.KeyID), signedPreKey.KeyID,
ecc.NewDjbECPublicKey(*preKey.Pub), ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub)))
} else {
bundle = prekey.NewBundle(registrationID, deviceID, optional.NewEmptyUint32(), signedPreKey.KeyID,
nil, ecc.NewDjbECPublicKey(*signedPreKey.Pub), *signedPreKey.Signature,
identity.NewKey(ecc.NewDjbECPublicKey(identityKeyPub)))
}

return bundle, nil
}

func nodeToPreKey(node waBinary.Node) (*keys.PreKey, error) {
Expand Down
2 changes: 1 addition & 1 deletion retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func (cli *Client) handleRetryReceipt(receipt *events.Receipt, node *waBinary.No
}
var content []waBinary.Node
if msg.wa != nil {
content = cli.getMessageContent(*encrypted, msg.wa, attrs, includeDeviceIdentity)
content = cli.getMessageContent(*encrypted, msg.wa, attrs, includeDeviceIdentity, waBinary.Node{})
} else {
content = []waBinary.Node{
*encrypted,
Expand Down
94 changes: 81 additions & 13 deletions send.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ type SendResponse struct {
type SendRequestExtra struct {
// The message ID to use when sending. If this is not provided, a random message ID will be generated
ID types.MessageID
// JID of the bot to be invoked (optional)
InlineBotJID types.JID
// Persona ID for the bot if you are accessing a bot (optional)
BotPersonaID string
// Should the message be sent as a peer message (protocol messages to your own devices, e.g. app state key requests)
Peer bool
// A timeout for the send request. Unlike timeouts using the context parameter, this only applies
Expand Down Expand Up @@ -164,7 +168,7 @@ type SendRequestExtra struct {
// in binary/proto/def.proto may be useful to find out all the allowed fields. Printing the RawMessage
// field in incoming message events to figure out what it contains is also a good way to learn how to
// send the same kind of message.
func (cli *Client) SendMessage(ctx context.Context, to types.JID, message *waProto.Message, extra ...SendRequestExtra) (resp SendResponse, err error) {
func (cli *Client) SendMessage(ctx context.Context, to types.JID, message *waE2E.Message, extra ...SendRequestExtra) (resp SendResponse, err error) {
var req SendRequestExtra
if len(extra) > 1 {
err = errors.New("only one extra parameter may be provided to SendMessage")
Expand Down Expand Up @@ -198,6 +202,64 @@ func (cli *Client) SendMessage(ctx context.Context, to types.JID, message *waPro
}
resp.ID = req.ID

isInlineBotMode := false

if !req.InlineBotJID.IsEmpty() {
if !req.InlineBotJID.IsBot() {
err = ErrInvalidInlineBotID
return
}
isInlineBotMode = true
}

isBotMode := isInlineBotMode || to.IsBot()
var botNode = waBinary.Node{
Tag: "bot",
}

if isBotMode {
// common code for both inline and not inline modes
message.MessageContextInfo = &waE2E.MessageContextInfo{
BotMetadata: &waE2E.BotMetadata{
PersonaID: &req.BotPersonaID,
},
MessageSecret: random.Bytes(32),
}

if isInlineBotMode {
// inline mode specific code
messageSecret := message.GetMessageContextInfo().GetMessageSecret()
message = &waE2E.Message{
BotInvokeMessage: &waE2E.FutureProofMessage{
Message: &waE2E.Message{
ExtendedTextMessage: message.ExtendedTextMessage,
MessageContextInfo: &waE2E.MessageContextInfo{
BotMetadata: message.MessageContextInfo.BotMetadata,
},
},
},
MessageContextInfo: message.MessageContextInfo,
}

botMessage := &waE2E.Message{
BotInvokeMessage: message.BotInvokeMessage,
MessageContextInfo: &waE2E.MessageContextInfo{
BotMetadata: message.MessageContextInfo.BotMetadata,
BotMessageSecret: applyBotMessageHKDF(messageSecret),
},
}

messagePlaintext, deviceSentMessagePlaintext, marshalErr := marshalMessage(req.InlineBotJID, botMessage)
if marshalErr != nil {
err = marshalErr
return
}

participantNodes, _ := cli.encryptMessageForDevices(ctx, []types.JID{req.InlineBotJID}, ownID, resp.ID, messagePlaintext, deviceSentMessagePlaintext, waBinary.Attrs{})
botNode.Content = participantNodes
}
}

start := time.Now()
// Sending multiple messages at a time can cause weird issues and makes it harder to retry safely
cli.messageSendLock.Lock()
Expand All @@ -209,6 +271,7 @@ func (cli *Client) SendMessage(ctx context.Context, to types.JID, message *waPro
if !req.Peer {
cli.addRecentMessage(to, req.ID, message, nil)
}

if message.GetMessageContextInfo().GetMessageSecret() != nil {
err = cli.Store.MsgSecrets.PutMessageSecret(to, ownID, req.ID, message.GetMessageContextInfo().GetMessageSecret())
if err != nil {
Expand All @@ -221,12 +284,12 @@ func (cli *Client) SendMessage(ctx context.Context, to types.JID, message *waPro
var data []byte
switch to.Server {
case types.GroupServer, types.BroadcastServer:
phash, data, err = cli.sendGroup(ctx, to, ownID, req.ID, message, &resp.DebugTimings)
phash, data, err = cli.sendGroup(ctx, to, ownID, req.ID, message, &resp.DebugTimings, botNode)
case types.DefaultUserServer:
if req.Peer {
data, err = cli.sendPeerMessage(to, req.ID, message, &resp.DebugTimings)
} else {
data, err = cli.sendDM(ctx, to, ownID, req.ID, message, &resp.DebugTimings)
data, err = cli.sendDM(ctx, to, ownID, req.ID, message, &resp.DebugTimings, botNode)
}
case types.NewsletterServer:
data, err = cli.sendNewsletter(to, req.ID, message, req.MediaHandle, &resp.DebugTimings)
Expand Down Expand Up @@ -532,7 +595,7 @@ func (cli *Client) sendNewsletter(to types.JID, id types.MessageID, message *waP
return data, nil
}

func (cli *Client) sendGroup(ctx context.Context, to, ownID types.JID, id types.MessageID, message *waProto.Message, timings *MessageDebugTimings) (string, []byte, error) {
func (cli *Client) sendGroup(ctx context.Context, to, ownID types.JID, id types.MessageID, message *waProto.Message, timings *MessageDebugTimings, botNode waBinary.Node) (string, []byte, error) {
var participants []types.JID
var err error
start := time.Now()
Expand Down Expand Up @@ -563,8 +626,8 @@ func (cli *Client) sendGroup(ctx context.Context, to, ownID types.JID, id types.
if err != nil {
return "", nil, fmt.Errorf("failed to create sender key distribution message to send %s to %s: %w", id, to, err)
}
skdMessage := &waProto.Message{
SenderKeyDistributionMessage: &waProto.SenderKeyDistributionMessage{
skdMessage := &waE2E.Message{
SenderKeyDistributionMessage: &waE2E.SenderKeyDistributionMessage{
GroupID: proto.String(to.String()),
AxolotlSenderKeyDistributionMessage: signalSKDMessage.Serialize(),
},
Expand All @@ -582,7 +645,7 @@ func (cli *Client) sendGroup(ctx context.Context, to, ownID types.JID, id types.
ciphertext := encrypted.SignedSerialize()
timings.GroupEncrypt = time.Since(start)

node, allDevices, err := cli.prepareMessageNode(ctx, to, ownID, id, message, participants, skdPlaintext, nil, timings)
node, allDevices, err := cli.prepareMessageNode(ctx, to, ownID, id, message, participants, skdPlaintext, nil, timings, botNode)
if err != nil {
return "", nil, err
}
Expand All @@ -608,7 +671,7 @@ func (cli *Client) sendGroup(ctx context.Context, to, ownID types.JID, id types.
return phash, data, nil
}

func (cli *Client) sendPeerMessage(to types.JID, id types.MessageID, message *waProto.Message, timings *MessageDebugTimings) ([]byte, error) {
func (cli *Client) sendPeerMessage(to types.JID, id types.MessageID, message *waE2E.Message, timings *MessageDebugTimings) ([]byte, error) {
node, err := cli.preparePeerMessageNode(to, id, message, timings)
if err != nil {
return nil, err
Expand All @@ -622,15 +685,15 @@ func (cli *Client) sendPeerMessage(to types.JID, id types.MessageID, message *wa
return data, nil
}

func (cli *Client) sendDM(ctx context.Context, to, ownID types.JID, id types.MessageID, message *waProto.Message, timings *MessageDebugTimings) ([]byte, error) {
func (cli *Client) sendDM(ctx context.Context, to, ownID types.JID, id types.MessageID, message *waE2E.Message, timings *MessageDebugTimings, botNode waBinary.Node) ([]byte, error) {
start := time.Now()
messagePlaintext, deviceSentMessagePlaintext, err := marshalMessage(to, message)
timings.Marshal = time.Since(start)
if err != nil {
return nil, err
}

node, _, err := cli.prepareMessageNode(ctx, to, ownID, id, message, []types.JID{to, ownID.ToNonAD()}, messagePlaintext, deviceSentMessagePlaintext, timings)
node, _, err := cli.prepareMessageNode(ctx, to, ownID, id, message, []types.JID{to, ownID.ToNonAD()}, messagePlaintext, deviceSentMessagePlaintext, timings, botNode)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -829,7 +892,7 @@ func (cli *Client) preparePeerMessageNode(to types.JID, id types.MessageID, mess
}, nil
}

func (cli *Client) getMessageContent(baseNode waBinary.Node, message *waProto.Message, msgAttrs waBinary.Attrs, includeIdentity bool) []waBinary.Node {
func (cli *Client) getMessageContent(baseNode waBinary.Node, message *waE2E.Message, msgAttrs waBinary.Attrs, includeIdentity bool, botNode waBinary.Node) []waBinary.Node {
content := []waBinary.Node{baseNode}
if includeIdentity {
content = append(content, cli.makeDeviceIdentityNode())
Expand All @@ -846,6 +909,10 @@ func (cli *Client) getMessageContent(baseNode waBinary.Node, message *waProto.Me
},
})
}

if botNode.Tag == "bot" && len(botNode.Content.([]waBinary.Node)) > 0 {
content = append(content, botNode)
}
if buttonType := getButtonTypeFromMessage(message); buttonType != "" {
content = append(content, waBinary.Node{
Tag: "biz",
Expand All @@ -858,7 +925,7 @@ func (cli *Client) getMessageContent(baseNode waBinary.Node, message *waProto.Me
return content
}

func (cli *Client) prepareMessageNode(ctx context.Context, to, ownID types.JID, id types.MessageID, message *waProto.Message, participants []types.JID, plaintext, dsmPlaintext []byte, timings *MessageDebugTimings) (*waBinary.Node, []types.JID, error) {
func (cli *Client) prepareMessageNode(ctx context.Context, to, ownID types.JID, id types.MessageID, message *waE2E.Message, participants []types.JID, plaintext, dsmPlaintext []byte, timings *MessageDebugTimings, botNode waBinary.Node) (*waBinary.Node, []types.JID, error) {
start := time.Now()
allDevices, err := cli.GetUserDevicesContext(ctx, participants)
timings.GetDevices = time.Since(start)
Expand Down Expand Up @@ -895,7 +962,7 @@ func (cli *Client) prepareMessageNode(ctx context.Context, to, ownID types.JID,
return &waBinary.Node{
Tag: "message",
Attrs: attrs,
Content: cli.getMessageContent(participantNode, message, attrs, includeIdentity),
Content: cli.getMessageContent(participantNode, message, attrs, includeIdentity, botNode),
}, allDevices, nil
}

Expand Down Expand Up @@ -956,6 +1023,7 @@ func (cli *Client) encryptMessageForDevices(ctx context.Context, allDevices []ty
cli.Log.Warnf("Failed to encrypt %s for %s: %v", id, jid, err)
continue
}

participantNodes = append(participantNodes, *encrypted)
if isPreKey {
includeIdentity = true
Expand Down
3 changes: 3 additions & 0 deletions user.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ func (cli *Client) GetUserDevicesContext(ctx context.Context, jids []types.JID)
devices = append(devices, cached.devices...)
} else if jid.Server == types.MessengerServer {
fbJIDsToSync = append(fbJIDsToSync, jid)
} else if jid.IsBot() {
// Bot JIDs do not have devices, the usync query is empty
devices = append(devices, jid)
} else {
jidsToSync = append(jidsToSync, jid)
}
Expand Down

0 comments on commit ae28c33

Please sign in to comment.