Minor SendToDevice fix (#2565)

* Avoid unnecessary marshalling if sending to the local server

* Fix ordering of ToDevice messages

* Revive SendToDevice test
This commit is contained in:
Till 2022-07-12 08:23:58 +02:00 committed by GitHub
parent 3ea21273bc
commit 09f0ff14c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 129 additions and 85 deletions

View File

@ -95,6 +95,11 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msg *nats.Ms
return true
}
// The SyncAPI is already handling sendToDevice for the local server
if destServerName == t.ServerName {
return true
}
// Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MDirectToDevice,

View File

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus"
)
const sendToDeviceSchema = `
@ -51,7 +52,7 @@ const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC
ORDER BY id ASC
`
const deleteSendToDeviceMessagesSQL = `
@ -112,17 +113,18 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
return
}
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
DeviceID: deviceID,
}
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
continue
}
if id > lastPos {
lastPos = id
}
events = append(events, event)
}
if lastPos == 0 {

View File

@ -49,7 +49,7 @@ const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC
ORDER BY id ASC
`
const deleteSendToDeviceMessagesSQL = `
@ -120,9 +120,6 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
return
}
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
@ -132,6 +129,9 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
continue
}
if id > lastPos {
lastPos = id
}
events = append(events, event)
}
if lastPos == 0 {

View File

@ -1,7 +1,9 @@
package storage_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
"reflect"
"testing"
@ -394,28 +396,34 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
from = topologyTokenBefore(t, db, paginatedEvents[len(paginatedEvents)-1].EventID())
}
}
*/
func TestSendToDeviceBehaviour(t *testing.T) {
//t.Parallel()
db := MustCreateDatabase(t)
t.Parallel()
alice := test.NewUser(t)
bob := test.NewUser(t)
deviceID := "one"
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
defer close()
// At this point there should be no messages. We haven't sent anything
// yet.
_, events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{})
_, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
if len(events) != 0 {
t.Fatal("first call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{})
err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, 100)
if err != nil {
return
}
// Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
Sender: "bob",
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
Sender: bob.ID,
Type: "m.type",
Content: json.RawMessage("{}"),
})
@ -426,14 +434,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at.
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
if err != nil {
t.Fatal(err)
}
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
t.Fatal("second call should have one update")
if count := len(events); count != 1 {
t.Fatalf("second call should have one update, got %d", count)
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos)
if err != nil {
return
}
@ -441,43 +449,72 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position.
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
if err != nil {
t.Fatal(err)
}
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
if len(events) != 1 {
t.Fatal("third call should have one update still")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos+1)
if err != nil {
return
}
// At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again.
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1})
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos+1, streamPos+2)
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
if len(events) != 0 {
t.Fatal("fourth call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos + 1})
err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos+1)
if err != nil {
return
}
// At this point we should still have no updates, because no new updates have been
// sent.
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2})
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+2)
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
if len(events) != 0 {
t.Fatal("fifth call should have no updates")
}
// Send some more messages and verify the ordering is correct ("in order of arrival")
var lastPos types.StreamPosition = 0
for i := 0; i < 10; i++ {
streamPos, err = db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
Sender: bob.ID,
Type: "m.type",
Content: json.RawMessage(fmt.Sprintf(`{ "count": %d }`, i)),
})
if err != nil {
t.Fatal(err)
}
lastPos = streamPos
}
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
if err != nil {
t.Fatalf("unable to get events: %v", err)
}
for i := 0; i < 10; i++ {
want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
got := events[i].Content
if !bytes.Equal(got, want) {
t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
}
}
})
}
/*
func TestInviteBehaviour(t *testing.T) {
db := MustCreateDatabase(t)
inviteRoom1 := "!inviteRoom1:somewhere"