diff --git a/chatwoot-handler.go b/chatwoot-handler.go index 4497bdc..2ac735a 100644 --- a/chatwoot-handler.go +++ b/chatwoot-handler.go @@ -3,7 +3,9 @@ package main import ( "bytes" "context" + "database/sql" "encoding/json" + "errors" "fmt" "image" _ "image/gif" @@ -229,6 +231,11 @@ func handleAttachment(ctx context.Context, roomID id.RoomID, chatwootMessageID i }) } +type StartNewChatResp struct { + RoomID id.RoomID `json:"room_id,omitempty"` + Error string `json:"error,omitempty"` +} + func HandleMessageCreated(ctx context.Context, mc chatwootapi.MessageCreated) error { log := zerolog.Ctx(ctx).With(). Str("component", "handle_message_created"). @@ -243,14 +250,61 @@ func HandleMessageCreated(ctx context.Context, mc chatwootapi.MessageCreated) er roomID, _, err := stateStore.GetMatrixRoomFromChatwootConversation(ctx, mc.Conversation.ID) if err != nil { - log.Err(err).Int("conversation_id", mc.Conversation.ID).Msg("no room found for conversation") - return err + if !errors.Is(err, sql.ErrNoRows) { + log.Err(err).Msg("couldn't find room for conversation") + return err + } + + if !configuration.StartNewChat.Enable { + log.Err(err).Msg("couldn't find room for conversation") + return err + } + + log := log.With().Bool("snc_enabled", true).Logger() + + // Create a new room for this conversation using the start new chat + // endpoint. + body, err := json.Marshal(mc.Conversation.Meta.Sender) + if err != nil { + log.Err(err).Msg("failed to marshal sender to JSON") + return err + } + req, err := http.NewRequest(http.MethodPost, configuration.StartNewChat.Endpoint, bytes.NewReader(body)) + if err != nil { + log.Err(err).Msg("failed to create request") + return err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", configuration.StartNewChat.Token)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Err(err).Msg("failed to make request") + return err + } + defer resp.Body.Close() + var sncResp StartNewChatResp + err = json.NewDecoder(resp.Body).Decode(&sncResp) + if err != nil { + log.Err(err).Msg("failed to read response body") + return err + } + + if resp.StatusCode != http.StatusOK { + log.Warn().Int("status_code", resp.StatusCode).Any("resp", sncResp).Msg("failed to create new chat") + return fmt.Errorf("failed to create new chat: %s", sncResp.Error) + } else if sncResp.RoomID == "" { + log.Warn().Any("resp", sncResp).Msg("invalid start new chat response") + return fmt.Errorf("invalid start new chat response: %s", sncResp.Error) + } + + log.Info().Str("room_id", sncResp.RoomID.String()).Msg("created new chat for conversation") + roomID = sncResp.RoomID } log = log.With().Str("room_id", roomID.String()).Logger() ctx = log.WithContext(ctx) - // Acquire the lock, so that we don't have race conditions with the - // matrix handler. + // Acquire the lock, so that we don't have race conditions with the matrix + // handler. if _, found := roomSendlocks[roomID]; !found { log.Debug().Msg("creating send lock") roomSendlocks[roomID] = &sync.Mutex{} diff --git a/chatwoot.go b/chatwoot.go index 6c2603c..de9717f 100644 --- a/chatwoot.go +++ b/chatwoot.go @@ -33,7 +33,6 @@ var configuration Configuration var stateStore *database.Database var chatwootAPI *chatwootapi.ChatwootAPI -var botHomeserver string var roomSendlocks map[id.RoomID]*sync.Mutex @@ -62,11 +61,12 @@ func main() { // Default configuration values configuration = Configuration{ - AllowMessagesFromUsersOnOtherHomeservers: false, - ChatwootBaseUrl: "https://app.chatwoot.com/", - ListenPort: 8080, - BridgeIfMembersLessThan: -1, - RenderMarkdown: false, + HomeserverWhitelist: HomeserverWhitelist{Enable: false}, + StartNewChat: StartNewChat{Enable: false}, + ChatwootBaseUrl: "https://app.chatwoot.com/", + ListenPort: 8080, + BridgeIfMembersLessThan: -1, + RenderMarkdown: false, Backfill: BackfillConfiguration{ ChatwootConversations: true, }, @@ -84,7 +84,6 @@ func main() { } log.Info().Interface("configuration", configuration).Msg("Config loaded") - botHomeserver = configuration.Username.Homeserver() log.Info().Msg("Chatwoot service starting...") @@ -148,7 +147,7 @@ func main() { log.Error().Err(decryptErr).Msg("Failed to decrypt message") stateStore.UpdateMostRecentEventIdForRoom(ctx, evt.RoomID, evt.ID) - if !VerifyFromAuthorizedUser(evt.Sender) { + if !VerifyFromAuthorizedUser(ctx, evt.Sender) { return } @@ -179,7 +178,7 @@ func main() { ctx := log.WithContext(context.TODO()) stateStore.UpdateMostRecentEventIdForRoom(ctx, evt.RoomID, evt.ID) - if VerifyFromAuthorizedUser(evt.Sender) { + if VerifyFromAuthorizedUser(ctx, evt.Sender) { go HandleBeeperClientInfo(ctx, evt) go HandleMessage(ctx, source, evt) } @@ -189,7 +188,7 @@ func main() { ctx := log.WithContext(context.TODO()) stateStore.UpdateMostRecentEventIdForRoom(ctx, evt.RoomID, evt.ID) - if VerifyFromAuthorizedUser(evt.Sender) { + if VerifyFromAuthorizedUser(ctx, evt.Sender) { go HandleBeeperClientInfo(ctx, evt) go HandleReaction(ctx, source, evt) } @@ -199,7 +198,7 @@ func main() { ctx := log.WithContext(context.TODO()) stateStore.UpdateMostRecentEventIdForRoom(ctx, evt.RoomID, evt.ID) - if VerifyFromAuthorizedUser(evt.Sender) { + if VerifyFromAuthorizedUser(ctx, evt.Sender) { go HandleBeeperClientInfo(ctx, evt) go HandleRedaction(ctx, source, evt) } @@ -371,14 +370,24 @@ func AllowKeyShare(ctx context.Context, device *id.Device, info event.RequestedK } } -func VerifyFromAuthorizedUser(sender id.UserID) bool { - if configuration.AllowMessagesFromUsersOnOtherHomeservers { +func VerifyFromAuthorizedUser(ctx context.Context, sender id.UserID) bool { + log := zerolog.Ctx(ctx) + if !configuration.HomeserverWhitelist.Enable { + log.Debug().Msg("homeserver whitelist disabled, allowing all messages") return true } _, homeserver, err := sender.Parse() if err != nil { + log.Warn().Err(err).Msg("failed to parse sender") return false } - return botHomeserver == homeserver + for _, allowedHS := range configuration.HomeserverWhitelist.Allowed { + if homeserver == allowedHS { + log.Debug().Str("sender_hs", allowedHS).Msg("allowing messages from whitelisted homeserver") + return true + } + } + log.Debug().Str("sender_hs", homeserver).Msg("rejecting messages from other homeserver") + return false } diff --git a/chatwootapi/api.go b/chatwootapi/api.go index b0f71c8..050e672 100644 --- a/chatwootapi/api.go +++ b/chatwootapi/api.go @@ -80,11 +80,19 @@ func (api *ChatwootAPI) CreateContact(ctx context.Context, userID id.UserID) (in Str("user_id", userID.String()). Logger() - log.Info().Msg("Creating contact") + name := userID.String() + if userID.Homeserver() == "beeper.local" { + decoded, err := id.DecodeUserLocalpart(strings.TrimPrefix(userID.Localpart(), "imessagego_1.")) + if err == nil { + name = decoded + } + } + + log.Info().Str("name", name).Msg("Creating contact") payload := CreateContactPayload{ InboxID: api.InboxID, - Name: userID.String(), - Identifier: userID.String(), + Name: name, + Identifier: name, } jsonValue, _ := json.Marshal(payload) req, err := http.NewRequest(http.MethodPost, api.MakeUri("contacts"), bytes.NewBuffer(jsonValue)) @@ -106,10 +114,8 @@ func (api *ChatwootAPI) CreateContact(ctx context.Context, userID id.UserID) (in return 0, fmt.Errorf("POST contacts returned non-200 status code: %d", resp.StatusCode) } - decoder := json.NewDecoder(resp.Body) var contactPayload ContactPayload - err = decoder.Decode(&contactPayload) - if err != nil { + if err := json.NewDecoder(resp.Body).Decode(&contactPayload); err != nil { return 0, err } @@ -117,14 +123,28 @@ func (api *ChatwootAPI) CreateContact(ctx context.Context, userID id.UserID) (in return contactPayload.Payload.Contact.ID, nil } -func (api *ChatwootAPI) ContactIDForMXID(userID id.UserID) (int, error) { +func (api *ChatwootAPI) ContactIDForMXID(ctx context.Context, userID id.UserID) (int, error) { + log := zerolog.Ctx(ctx) + query := userID.String() + if userID.Homeserver() == "beeper.local" { + // Special handling for bridged iMessages. + if strings.HasPrefix(userID.Localpart(), "imessagego_1.") { + decoded, err := id.DecodeUserLocalpart(strings.TrimPrefix(userID.Localpart(), "imessagego_1.")) + if err == nil { + query = decoded + } + } + } + + log.Info().Str("query", query).Msg("Searching for contact") + req, err := http.NewRequest(http.MethodGet, api.MakeUri("contacts/search"), nil) if err != nil { return 0, err } q := req.URL.Query() - q.Add("q", userID.String()) + q.Add("q", query) req.URL.RawQuery = q.Encode() resp, err := api.DoRequest(req) @@ -135,19 +155,22 @@ func (api *ChatwootAPI) ContactIDForMXID(userID id.UserID) (int, error) { return 0, fmt.Errorf("GET contacts/search returned non-200 status code: %d", resp.StatusCode) } - decoder := json.NewDecoder(resp.Body) var contactsPayload ContactsPayload - err = decoder.Decode(&contactsPayload) - if err != nil { + if err := json.NewDecoder(resp.Body).Decode(&contactsPayload); err != nil { return 0, err } + for _, contact := range contactsPayload.Payload { - if contact.Identifier == userID.String() { + if contact.Identifier == query { + return contact.ID, nil + } else if contact.Email == query { + return contact.ID, nil + } else if contact.PhoneNumber == query { return contact.ID, nil } } - return 0, fmt.Errorf("couldn't find user with user ID %s", userID) + return 0, fmt.Errorf("couldn't find user with user ID %s", query) } func (api *ChatwootAPI) GetChatwootConversation(conversationID int) (*Conversation, error) { @@ -164,13 +187,9 @@ func (api *ChatwootAPI) GetChatwootConversation(conversationID int) (*Conversati return nil, fmt.Errorf("GET conversations/%d returned non-200 status code: %d", conversationID, resp.StatusCode) } - decoder := json.NewDecoder(resp.Body) var conversation Conversation - err = decoder.Decode(&conversation) - if err != nil { - return nil, err - } - return &conversation, nil + err = json.NewDecoder(resp.Body).Decode(&conversation) + return &conversation, err } func (api *ChatwootAPI) CreateConversation(sourceID string, contactID int, additionalAttrs map[string]string) (*Conversation, error) { @@ -196,13 +215,9 @@ func (api *ChatwootAPI) CreateConversation(sourceID string, contactID int, addit return nil, fmt.Errorf("POST conversations returned non-200 status code: %d: %s", resp.StatusCode, string(content)) } - decoder := json.NewDecoder(resp.Body) var conversation Conversation - err = decoder.Decode(&conversation) - if err != nil { - return nil, err - } - return &conversation, nil + err = json.NewDecoder(resp.Body).Decode(&conversation) + return &conversation, err } func (api *ChatwootAPI) GetConversationLabels(conversationID int) ([]string, error) { @@ -220,9 +235,8 @@ func (api *ChatwootAPI) GetConversationLabels(conversationID int) ([]string, err return nil, fmt.Errorf("POST conversations returned non-200 status code: %d: %s", resp.StatusCode, string(content)) } - decoder := json.NewDecoder(resp.Body) var labels ConversationLabelsPayload - err = decoder.Decode(&labels) + err = json.NewDecoder(resp.Body).Decode(&labels) return labels.Payload, err } diff --git a/chatwootapi/objects.go b/chatwootapi/objects.go index bfe6bcf..47245ee 100644 --- a/chatwootapi/objects.go +++ b/chatwootapi/objects.go @@ -2,8 +2,10 @@ package chatwootapi // Contact type Contact struct { - ID int `json:"id"` - Identifier string `json:"identifier"` + ID int `json:"id"` + Identifier string `json:"identifier"` + PhoneNumber string `json:"phone_number,omitempty"` + Email string `json:"email,omitempty"` } type ContactsPayload struct { diff --git a/configuration.go b/configuration.go index 515cb04..eec12d0 100644 --- a/configuration.go +++ b/configuration.go @@ -15,6 +15,17 @@ type BackfillConfiguration struct { ConversationIDStateEvents bool `yaml:"conversation_id_state_events"` } +type HomeserverWhitelist struct { + Enable bool `yaml:"enable"` + Allowed []string `yaml:"allowed"` +} + +type StartNewChat struct { + Enable bool `yaml:"enable"` + Endpoint string `yaml:"endpoint"` + Token string `yaml:"token"` +} + type Configuration struct { // Authentication settings Homeserver string `yaml:"homeserver"` @@ -31,10 +42,11 @@ type Configuration struct { Database dbutil.Config `yaml:"database"` // Bot settings - AllowMessagesFromUsersOnOtherHomeservers bool `yaml:"allow_messages_from_users_on_other_homeservers"` - CanonicalDMPrefix string `yaml:"canonical_dm_prefix"` - BridgeIfMembersLessThan int `yaml:"bridge_if_members_less_than"` - RenderMarkdown bool `yaml:"render_markdown"` + HomeserverWhitelist HomeserverWhitelist `yaml:"homeserver_whitelist"` + StartNewChat StartNewChat `yaml:"start_new_chat"` + CanonicalDMPrefix string `yaml:"canonical_dm_prefix"` + BridgeIfMembersLessThan int `yaml:"bridge_if_members_less_than"` + RenderMarkdown bool `yaml:"render_markdown"` // Webhook listener settings ListenPort int `yaml:"listen_port"` diff --git a/example-config.yaml b/example-config.yaml index a5475a7..d529e1b 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -33,9 +33,47 @@ database: max_conn_lifetime: null # ===== Bot Settings ===== -# Boolean indicating whether or not to create conversations for messages -# originating from users on other homeservers. Defaults to false. -allow_messages_from_users_on_other_homeservers: false +# Whitelist for which homeservers should be allowed to create conversations. +homeserver_whitelist: + # Whether to enable the whitelist. + enable: false + # A list of allowed homeservers. + allowed: + - example.com +# Configure whether and how to start new chats from Chatwoot. +start_new_chat: + # Whether to enable starting new chats from Chatwoot. + enable: false + # Endpoint that will create the Matrix conversation. Auth should be with the + # Matrix access token. + # + # The endpoint must be a POST endpoint that accepts a JSON object with the + # following content (both fields optional, and potentially other fields + # included, but should not be depended on): + # + # { + # "phone_number":"+11234567890", + # "email":"email@example.com" + # } + # + # and should return 200 with a JSON object with the following content (and + # optionally other fields): + # + # { + # "room_id": "!abcd:beeper.local" + # } + # + # or a non-200 status code with an error (and optionally other fields): + # + # { + # "error": "+11234567890 is not on iMessage" + # } + # + # the behavior of the endpoint is undefined if the request payload is + # malformed. + endpoint: + # The Authentication token to use on the request. + token: # If not "", when creating a conversation, if the Matrix room name starts # with this prefix, it will be labeled with the "canonical-dm" label. Defaults # to "". diff --git a/matrix-handler.go b/matrix-handler.go index 1f70a97..c451e82 100644 --- a/matrix-handler.go +++ b/matrix-handler.go @@ -38,7 +38,7 @@ func createChatwootConversation(ctx context.Context, roomID id.RoomID, contactMX return conversationID, nil } - contactID, err := chatwootAPI.ContactIDForMXID(contactMXID) + contactID, err := chatwootAPI.ContactIDForMXID(ctx, contactMXID) if err != nil { log.Warn().Err(err).Msg("contact ID not found for user, will attempt to create one")