Minor SendToDevice fix (#2565)

* Avoid unnecessary marshalling if sending to the local server

* Fix ordering of ToDevice messages

* Revive SendToDevice test
This commit is contained in:
Till 2022-07-12 08:23:58 +02:00 committed by GitHub
parent 3ea21273bc
commit 09f0ff14c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 129 additions and 85 deletions

View File

@ -95,6 +95,11 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msg *nats.Ms
return true return true
} }
// The SyncAPI is already handling sendToDevice for the local server
if destServerName == t.ServerName {
return true
}
// Pack the EDU and marshal it // Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{ edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MDirectToDevice, Type: gomatrixserverlib.MDirectToDevice,

View File

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus"
) )
const sendToDeviceSchema = ` const sendToDeviceSchema = `
@ -51,7 +52,7 @@ const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id ASC
` `
const deleteSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = `
@ -112,17 +113,18 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil { if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
return return
} }
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
ID: id, ID: id,
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
} }
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
continue continue
} }
if id > lastPos {
lastPos = id
}
events = append(events, event) events = append(events, event)
} }
if lastPos == 0 { if lastPos == 0 {

View File

@ -49,7 +49,7 @@ const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id ASC
` `
const deleteSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = `
@ -120,9 +120,6 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
logrus.WithError(err).Errorf("Failed to retrieve send-to-device message") logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
return return
} }
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
ID: id, ID: id,
UserID: userID, UserID: userID,
@ -132,6 +129,9 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message") logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
continue continue
} }
if id > lastPos {
lastPos = id
}
events = append(events, event) events = append(events, event)
} }
if lastPos == 0 { if lastPos == 0 {

View File

@ -1,7 +1,9 @@
package storage_test package storage_test
import ( import (
"bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
@ -394,28 +396,34 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
from = topologyTokenBefore(t, db, paginatedEvents[len(paginatedEvents)-1].EventID()) from = topologyTokenBefore(t, db, paginatedEvents[len(paginatedEvents)-1].EventID())
} }
} }
*/
func TestSendToDeviceBehaviour(t *testing.T) { func TestSendToDeviceBehaviour(t *testing.T) {
//t.Parallel() t.Parallel()
db := MustCreateDatabase(t) alice := test.NewUser(t)
bob := test.NewUser(t)
deviceID := "one"
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
defer close()
// At this point there should be no messages. We haven't sent anything // At this point there should be no messages. We haven't sent anything
// yet. // yet.
_, events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{}) _, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { if len(events) != 0 {
t.Fatal("first call should have no updates") t.Fatal("first call should have no updates")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{})
err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, 100)
if err != nil { if err != nil {
return return
} }
// Try sending a message. // Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{ streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
Sender: "bob", Sender: bob.ID,
Type: "m.type", Type: "m.type",
Content: json.RawMessage("{}"), Content: json.RawMessage("{}"),
}) })
@ -426,14 +434,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should get exactly one message. We're sending the sync position // At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated // that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at. // in the database to reflect that this was the sync position we sent the message at.
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos}) streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 { if count := len(events); count != 1 {
t.Fatal("second call should have one update") t.Fatalf("second call should have one update, got %d", count)
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos}) err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos)
if err != nil { if err != nil {
return return
} }
@ -441,43 +449,72 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have one message because we haven't progressed the // At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying // sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position. // with the same position.
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos}) streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 { if len(events) != 1 {
t.Fatal("third call should have one update still") t.Fatal("third call should have one update still")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos}) err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos+1)
if err != nil { if err != nil {
return return
} }
// At this point we should now have no updates, because we've progressed the sync // At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again. // position. Therefore the update from before will not be sent again.
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1}) _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos+1, streamPos+2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 { if len(events) != 0 {
t.Fatal("fourth call should have no updates") t.Fatal("fourth call should have no updates")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos + 1}) err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos+1)
if err != nil { if err != nil {
return return
} }
// At this point we should still have no updates, because no new updates have been // At this point we should still have no updates, because no new updates have been
// sent. // sent.
_, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2}) _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { if len(events) != 0 {
t.Fatal("fifth call should have no updates") t.Fatal("fifth call should have no updates")
} }
// Send some more messages and verify the ordering is correct ("in order of arrival")
var lastPos types.StreamPosition = 0
for i := 0; i < 10; i++ {
streamPos, err = db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
Sender: bob.ID,
Type: "m.type",
Content: json.RawMessage(fmt.Sprintf(`{ "count": %d }`, i)),
})
if err != nil {
t.Fatal(err)
}
lastPos = streamPos
} }
_, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
if err != nil {
t.Fatalf("unable to get events: %v", err)
}
for i := 0; i < 10; i++ {
want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
got := events[i].Content
if !bytes.Equal(got, want) {
t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
}
}
})
}
/*
func TestInviteBehaviour(t *testing.T) { func TestInviteBehaviour(t *testing.T) {
db := MustCreateDatabase(t) db := MustCreateDatabase(t)
inviteRoom1 := "!inviteRoom1:somewhere" inviteRoom1 := "!inviteRoom1:somewhere"