Skip to content

Commit

Permalink
deps/mautrix: convert all functions to use new interfaces
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Mar 26, 2024
1 parent c7736e2 commit 73ddbe6
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 168 deletions.
16 changes: 8 additions & 8 deletions chatwoot-handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func SendMessage(ctx context.Context, roomID id.RoomID, content *event.MessageEv
}

r, err := DoRetry(ctx, "send message to "+roomID.String(), func(ctx context.Context) (*mautrix.RespSendEvent, error) {
return client.SendMessageEvent(roomID, event.EventMessage, &wrappedContent)
return client.SendMessageEvent(ctx, roomID, event.EventMessage, &wrappedContent)
})
if err != nil {
// give up
Expand Down Expand Up @@ -161,8 +161,8 @@ func handleAttachment(ctx context.Context, roomID id.RoomID, chatwootMessageID i
info.ThumbnailFile.EncryptInPlace(thumbnailData)

// Upload the thumbnail
uploadedThumbnail, err := DoRetry(ctx, "upload thumbnail to Matrix", func(context.Context) (*mautrix.RespMediaUpload, error) {
return client.UploadMedia(mautrix.ReqUploadMedia{
uploadedThumbnail, err := DoRetry(ctx, "upload thumbnail to Matrix", func(ctx context.Context) (*mautrix.RespMediaUpload, error) {
return client.UploadMedia(ctx, mautrix.ReqUploadMedia{
ContentBytes: thumbnailData,
ContentLength: int64(len(thumbnailData)),
ContentType: "application/octet-stream",
Expand Down Expand Up @@ -197,8 +197,8 @@ func handleAttachment(ctx context.Context, roomID id.RoomID, chatwootMessageID i
}

// Upload it to the media repo
uploaded, err := DoRetry(ctx, fmt.Sprintf("upload %s to Matrix", filename), func(context.Context) (*mautrix.RespMediaUpload, error) {
return client.UploadMedia(mautrix.ReqUploadMedia{
uploaded, err := DoRetry(ctx, fmt.Sprintf("upload %s to Matrix", filename), func(ctx context.Context) (*mautrix.RespMediaUpload, error) {
return client.UploadMedia(ctx, mautrix.ReqUploadMedia{
ContentBytes: attachmentData,
ContentLength: int64(len(attachmentData)),
ContentType: "application/octet-stream",
Expand Down Expand Up @@ -307,7 +307,7 @@ func HandleMessageCreated(ctx context.Context, mc chatwootapi.MessageCreated) er
return err
}

_, err = client.State(sncResp.RoomID)
_, err = client.State(ctx, sncResp.RoomID)
if err != nil {
log.Err(err).Msg("failed to get room state")
return err
Expand All @@ -334,13 +334,13 @@ func HandleMessageCreated(ctx context.Context, mc chatwootapi.MessageCreated) er
log.Info().Int("message_id", mc.ID).Msg("message deleted")
var errs []error
for _, eventID := range eventIDs {
event, err := client.GetEvent(roomID, eventID)
event, err := client.GetEvent(ctx, roomID, eventID)
if err == nil && event.Unsigned.RedactedBecause != nil {
// Already redacted
log.Info().Int("message_id", mc.ID).Msg("message was already redacted")
continue
}
_, err = client.RedactEvent(roomID, eventID)
_, err = client.RedactEvent(ctx, roomID, eventID)
if err != nil {
errs = append(errs, err)
}
Expand Down
64 changes: 33 additions & 31 deletions chatwoot.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,18 @@ func main() {
globallog.Fatal().Err(err).Msg("Failed to compile logging configuration")
}

log.Info().Interface("configuration", configuration).Msg("Config loaded")
getLogger := func(evt *event.Event) zerolog.Logger {
return log.With().
Stringer("event_type", &evt.Type).
Stringer("sender", evt.Sender).
Str("room_id", string(evt.RoomID)).
Str("event_id", string(evt.ID)).
Logger()
}

log.Info().Interface("configuration", configuration).Msg("Config loaded")
log.Info().Msg("Chatwoot service starting...")
ctx := log.WithContext(context.TODO())

// Open the chatwoot database
db, err := dbutil.NewFromConfig("chatwoot", configuration.Database, dbutil.ZeroLogger(*log))
Expand All @@ -97,7 +106,7 @@ func main() {
roomSendlocks = map[id.RoomID]*sync.Mutex{}

stateStore = database.NewDatabase(db)
if err := stateStore.DB.Upgrade(); err != nil {
if err := stateStore.DB.Upgrade(ctx); err != nil {
log.Fatal().Err(err).Msg("failed to upgrade the Chatwoot database")
}

Expand All @@ -118,15 +127,6 @@ func main() {
accessToken,
)

getLogger := func(evt *event.Event) zerolog.Logger {
return log.With().
Stringer("event_type", &evt.Type).
Stringer("sender", evt.Sender).
Str("room_id", string(evt.RoomID)).
Str("event_id", string(evt.ID)).
Logger()
}

cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("chatwoot_cryptostore_key"), db)
if err != nil {
log.Fatal().Err(err).Msg("Failed to create crypto helper")
Expand Down Expand Up @@ -165,42 +165,44 @@ func main() {
})
}

err = cryptoHelper.Init()
err = cryptoHelper.Init(ctx)
if err != nil {
log.Fatal().Err(err).Msg("Failed to initialize crypto helper")
}
cryptoHelper.Machine().AllowKeyShare = AllowKeyShare
client.Crypto = cryptoHelper

addEvtContext := func(ctx context.Context, evt *event.Event) context.Context {
return zerolog.Ctx(ctx).With().
Stringer("event_type", &evt.Type).
Stringer("sender", evt.Sender).
Str("room_id", string(evt.RoomID)).
Str("event_id", string(evt.ID)).
Logger().
WithContext(ctx)
}
syncer := client.Syncer.(*mautrix.DefaultSyncer)
syncer.OnEventType(event.EventMessage, func(source mautrix.EventSource, evt *event.Event) {
log := getLogger(evt)
ctx := log.WithContext(context.TODO())

syncer.OnEventType(event.EventMessage, func(ctx context.Context, evt *event.Event) {
ctx = addEvtContext(ctx, evt)
stateStore.UpdateMostRecentEventIDForRoom(ctx, evt.RoomID, evt.ID)
if VerifyFromAuthorizedUser(ctx, evt.Sender) {
go HandleBeeperClientInfo(ctx, evt)
go HandleMessage(ctx, source, evt)
go HandleMessage(ctx, evt)
}
})
syncer.OnEventType(event.EventReaction, func(source mautrix.EventSource, evt *event.Event) {
log := getLogger(evt)
ctx := log.WithContext(context.TODO())
syncer.OnEventType(event.EventReaction, func(ctx context.Context, evt *event.Event) {
ctx = addEvtContext(ctx, evt)

stateStore.UpdateMostRecentEventIDForRoom(ctx, evt.RoomID, evt.ID)
if VerifyFromAuthorizedUser(ctx, evt.Sender) {
go HandleBeeperClientInfo(ctx, evt)
go HandleReaction(ctx, source, evt)
go HandleReaction(ctx, evt)
}
})
syncer.OnEventType(event.EventRedaction, func(source mautrix.EventSource, evt *event.Event) {
log := getLogger(evt)
ctx := log.WithContext(context.TODO())
syncer.OnEventType(event.EventRedaction, func(ctx context.Context, evt *event.Event) {
ctx = addEvtContext(ctx, evt)

stateStore.UpdateMostRecentEventIDForRoom(ctx, evt.RoomID, evt.ID)
if VerifyFromAuthorizedUser(ctx, evt.Sender) {
go HandleBeeperClientInfo(ctx, evt)
go HandleRedaction(ctx, source, evt)
go HandleRedaction(ctx, evt)
}
})

Expand Down Expand Up @@ -232,7 +234,7 @@ func main() {

log.Info().Msg("starting to create conversations for rooms that don't have a conversation yet")

joined, err := client.JoinedRooms()
joined, err := client.JoinedRooms(ctx)
if err != nil {
log.Fatal().Err(err).Msg("Failed to get joined rooms")
}
Expand All @@ -253,7 +255,7 @@ func main() {
// If we already have a Chatwoot conversation, make sure that
// the room has a state event with the Chatwoot conversation
// ID.
_, err = client.SendStateEvent(roomID, chatwootConversationIDType, "", ChatwootConversationIDEventContent{
_, err = client.SendStateEvent(ctx, roomID, chatwootConversationIDType, "", ChatwootConversationIDEventContent{
ConversationID: chatwootConversationID,
})
if err != nil {
Expand Down Expand Up @@ -308,7 +310,7 @@ func backfillConversationForRoom(ctx context.Context, roomID id.RoomID) error {

log.Info().Msg("Creating conversation for room")

messages, err := client.Messages(roomID, "", "", mautrix.DirectionBackward, nil, 50)
messages, err := client.Messages(ctx, roomID, "", "", mautrix.DirectionBackward, nil, 50)
if err != nil {
log.Err(err).Msg("Failed to get messages for room")
return err
Expand Down
60 changes: 23 additions & 37 deletions database/chatwoot-conversation-to-matrix-room.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package database
import (
"context"
"database/sql"
"fmt"

"github.com/rs/zerolog"
"maunium.net/go/mautrix/id"
)

func (store *Database) GetChatwootConversationIDFromMatrixRoom(ctx context.Context, roomID id.RoomID) (int, error) {
row := store.DB.QueryRowContext(ctx, `
row := store.DB.QueryRow(ctx, `
SELECT chatwoot_conversation_id
FROM chatwoot_conversation_to_matrix_room
WHERE matrix_room_id = $1`, roomID)
Expand All @@ -21,7 +22,7 @@ func (store *Database) GetChatwootConversationIDFromMatrixRoom(ctx context.Conte
}

func (store *Database) GetMatrixRoomFromChatwootConversation(ctx context.Context, conversationID int) (id.RoomID, id.EventID, error) {
row := store.DB.QueryRowContext(ctx, `
row := store.DB.QueryRow(ctx, `
SELECT matrix_room_id, most_recent_event_id
FROM chatwoot_conversation_to_matrix_room
WHERE chatwoot_conversation_id = $1`, conversationID)
Expand All @@ -45,24 +46,17 @@ func (store *Database) UpdateMostRecentEventIDForRoom(ctx context.Context, roomI
ctx = log.WithContext(ctx)

log.Debug().Msg("setting most recent event ID for room")
tx, err := store.DB.Begin()
if err != nil {
tx.Rollback()
return err
}

update := `
UPDATE chatwoot_conversation_to_matrix_room
SET most_recent_event_id = $2
WHERE matrix_room_id = $1
`
if _, err := tx.ExecContext(ctx, update, roomID, mostRecentEventID); err != nil {
tx.Rollback()
log.Err(err).Msg("failed to update most recent event ID")
return err
}

return tx.Commit()
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
update := `
UPDATE chatwoot_conversation_to_matrix_room
SET most_recent_event_id = $2
WHERE matrix_room_id = $1
`
if _, err := store.DB.Exec(ctx, update, roomID, mostRecentEventID); err != nil {
return fmt.Errorf("failed to update most recent event ID: %w", err)
}
return nil
})
}

func (store *Database) UpdateConversationIDForRoom(ctx context.Context, roomID id.RoomID, conversationID int) error {
Expand All @@ -73,22 +67,14 @@ func (store *Database) UpdateConversationIDForRoom(ctx context.Context, roomID i
ctx = log.WithContext(ctx)

log.Debug().Msg("setting conversation ID for room")
tx, err := store.DB.Begin()
if err != nil {
tx.Rollback()
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
upsert := `
INSERT INTO chatwoot_conversation_to_matrix_room (matrix_room_id, chatwoot_conversation_id)
VALUES ($1, $2)
ON CONFLICT (matrix_room_id) DO UPDATE
SET chatwoot_conversation_id = $2
`
_, err := store.DB.Exec(ctx, upsert, roomID, conversationID)
return err
}

upsert := `
INSERT INTO chatwoot_conversation_to_matrix_room (matrix_room_id, chatwoot_conversation_id)
VALUES ($1, $2)
ON CONFLICT (matrix_room_id) DO UPDATE
SET chatwoot_conversation_id = $2
`
if _, err := tx.ExecContext(ctx, upsert, roomID, conversationID); err != nil {
tx.Rollback()
return err
}

return tx.Commit()
})
}
33 changes: 14 additions & 19 deletions database/chatwoot-message-to-matrix-event.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package database

import (
"context"
"fmt"

"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
Expand All @@ -16,31 +17,25 @@ func (store *Database) SetChatwootMessageIDForMatrixEvent(ctx context.Context, e
ctx = log.WithContext(ctx)

log.Debug().Msg("setting chatwoot message ID for matrix event")
tx, err := store.DB.Begin()
if err != nil {
tx.Rollback()
return err
}

insert := `
INSERT INTO chatwoot_message_to_matrix_event (matrix_event_id, chatwoot_message_id)
VALUES ($1, $2)
`
if _, err := tx.ExecContext(ctx, insert, eventID, chatwootMessageID); err != nil {
log.Err(err).Msg("failed to set chatwoot message ID for matrix event")
tx.Rollback()
return err
}

return tx.Commit()
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
insert := `
INSERT INTO chatwoot_message_to_matrix_event (matrix_event_id, chatwoot_message_id)
VALUES ($1, $2)
`
_, err := store.DB.Exec(ctx, insert, eventID, chatwootMessageID)
if err != nil {
return fmt.Errorf("failed to insert chatwoot message ID for matrix event: %w", err)
}
return nil
})
}

func (store *Database) GetMatrixEventIDsForChatwootMessage(ctx context.Context, chatwootMessageID int) []id.EventID {
log := zerolog.Ctx(ctx).With().Int("message_id", chatwootMessageID).Logger()
ctx = log.WithContext(ctx)

log.Debug().Msg("getting Matrix event IDs for chatwoot message")
rows, err := store.DB.QueryContext(ctx, `
rows, err := store.DB.Query(ctx, `
SELECT matrix_event_id
FROM chatwoot_message_to_matrix_event
WHERE chatwoot_message_id = $1`, chatwootMessageID)
Expand All @@ -65,7 +60,7 @@ func (store *Database) GetChatwootMessageIDsForMatrixEventID(ctx context.Context

log.Debug().Msg("getting chatwoot message IDs for matrix event ID")
var rows dbutil.Rows
rows, err = store.DB.QueryContext(ctx, `
rows, err = store.DB.Query(ctx, `
SELECT chatwoot_message_id
FROM chatwoot_message_to_matrix_event
WHERE matrix_event_id = $1`, matrixEventID)
Expand Down
Loading

0 comments on commit 73ddbe6

Please sign in to comment.