Simplify send-to-device messaging (#1702)

* Simplify send-to-device messaging

* Don't return error if there's no work to do

* Remove SQLite migrations for now

* Tweak Postgres migrations

* Tweaks

* Fixes

* Cleanup separately

* Fix SQLite migration
This commit is contained in:
Neil Alexander 2021-01-13 17:29:46 +00:00 committed by GitHub
parent bb9e6a1281
commit d8fba52e97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 248 additions and 254 deletions

View File

@ -33,6 +33,7 @@ type Database interface {
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
@ -117,26 +118,14 @@ type Database interface {
// matches the streamevent.transactionID device then the transaction ID gets // matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event. // added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists: // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
// - "events": a list of send-to-device events that should be included in the sync // relevant events within the given ranges for the supplied user ID and device ID.
// - "changes": a list of send-to-device events that should be updated in the database by SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
// CleanSendToDeviceUpdates
// - "deletions": a list of send-to-device events which have been confirmed as sent and
// can be deleted altogether by CleanSendToDeviceUpdates
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (pos types.StreamPosition, events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the // CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows // from position, preventing the send-to-device table from growing indefinitely.
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error)
// starting to wait for an incremental sync with timeout).
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
// GetFilter looks up the filter associated with a given local user and filter ID. // GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists // Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database. // or if there was an error talking to the database.

View File

@ -24,6 +24,7 @@ import (
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences) goose.AddMigration(UpFixSequences, DownFixSequences)
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
} }
func LoadFixSequences(m *sqlutil.Migrations) { func LoadFixSequences(m *sqlutil.Migrations) {

View File

@ -0,0 +1,48 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device
DROP COLUMN IF EXISTS sent_by_token;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device
ADD COLUMN IF NOT EXISTS sent_by_token TEXT;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View File

@ -19,7 +19,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
@ -38,11 +37,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The device ID to send the message to. -- The device ID to send the message to.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
-- The event content JSON. -- The event content JSON.
content TEXT NOT NULL, content TEXT NOT NULL
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
); );
` `
@ -52,34 +47,26 @@ const insertSendToDeviceMessageSQL = `
RETURNING id RETURNING id
` `
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = ` const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id DESC
` `
const updateSentSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1 DELETE FROM syncapi_send_to_device
WHERE id = ANY($2) WHERE user_id = $1 AND device_id = $2 AND id < $3
` `
const deleteSendToDeviceMessagesSQL = ` const selectMaxSendToDeviceIDSQL = "" +
DELETE FROM syncapi_send_to_device WHERE id = ANY($1) "SELECT MAX(id) FROM syncapi_send_to_device"
`
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
updateSentSendToDeviceMessagesStmt *sql.Stmt
deleteSendToDeviceMessagesStmt *sql.Stmt deleteSendToDeviceMessagesStmt *sql.Stmt
selectMaxSendToDeviceIDStmt *sql.Stmt
} }
func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
@ -91,16 +78,13 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil { if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
return nil, err return nil, err
} }
return s, nil return s, nil
@ -113,64 +97,55 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
return return
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil { if err != nil {
return return
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() { for rows.Next() {
var id types.SendToDeviceNID var id types.StreamPosition
var userID, deviceID, content string var userID, deviceID, content string
var sentByToken *string if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
return return
} }
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
ID: id, ID: id,
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
} }
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return continue
}
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
} }
events = append(events, event) events = append(events, event)
if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id)
} }
if lastPos == 0 {
lastPos = to
} }
return lastPos, events, rows.Err() return lastPos, events, rows.Err()
} }
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids)) _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
return return
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx,
) (err error) { ) (id int64, err error) {
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return return
} }

View File

@ -89,6 +89,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m) deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil { if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err return nil, err
} }

View File

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -85,6 +86,14 @@ func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.Strea
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil)
if err != nil {
return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
}
return types.StreamPosition(id), nil
}
func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil) id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil)
if err != nil { if err != nil {
@ -168,30 +177,6 @@ func (d *Database) GetEventsInStreamingRange(
return events, err return events, err
} }
/*
func (d *Database) AddTypingUser(
userID, roomID string, expireTime *time.Time,
) types.StreamPosition {
return types.StreamPosition(d.EDUCache.AddTypingUser(userID, roomID, expireTime))
}
func (d *Database) RemoveTypingUser(
userID, roomID string,
) types.StreamPosition {
return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID))
}
func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
d.EDUCache.SetTimeoutCallback(fn)
}
*/
/*
func (d *Database) AddSendToDevice() types.StreamPosition {
return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage())
}
*/
func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsers(ctx) return d.CurrentRoomState.SelectJoinedUsers(ctx)
} }
@ -891,16 +876,6 @@ func (d *Database) currentStateStreamEventsForRoom(
return s, nil return s, nil
} }
func (d *Database) SendToDeviceUpdatesWaiting(
ctx context.Context, userID, deviceID string,
) (bool, error) {
count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil {
return false, err
}
return count > 0, nil
}
func (d *Database) StoreNewSendForDeviceMessage( func (d *Database) StoreNewSendForDeviceMessage(
ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (newPos types.StreamPosition, err error) { ) (newPos types.StreamPosition, err error) {
@ -919,78 +894,38 @@ func (d *Database) StoreNewSendForDeviceMessage(
if err != nil { if err != nil {
return 0, err return 0, err
} }
return 0, nil return newPos, nil
} }
func (d *Database) SendToDeviceUpdatesForSync( func (d *Database) SendToDeviceUpdatesForSync(
ctx context.Context, ctx context.Context,
userID, deviceID string, userID, deviceID string,
token types.StreamingToken, from, to types.StreamPosition,
) (types.StreamPosition, []types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) { ) (types.StreamPosition, []types.SendToDeviceEvent, error) {
// First of all, get our send-to-device updates for this user. // First of all, get our send-to-device updates for this user.
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID) lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to)
if err != nil { if err != nil {
return 0, nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
} }
// If there's nothing to do then stop here. // If there's nothing to do then stop here.
if len(events) == 0 { if len(events) == 0 {
return 0, nil, nil, nil, nil return to, nil, nil
} }
return lastPos, events, nil
// Work out whether we need to update any of the database entries.
toReturn := []types.SendToDeviceEvent{}
toUpdate := []types.SendToDeviceNID{}
toDelete := []types.SendToDeviceNID{}
for _, event := range events {
if event.SentByToken == nil {
// If the event has no sent-by token yet then we haven't attempted to send
// it. Record the current requested sync token in the database.
toUpdate = append(toUpdate, event.ID)
toReturn = append(toReturn, event)
event.SentByToken = &token
} else if token.IsAfter(*event.SentByToken) {
// The event had a sync token, therefore we've sent it before. The current
// sync token is now after the stored one so we can assume that the client
// successfully completed the previous sync (it would re-request it otherwise)
// so we can remove the entry from the database.
toDelete = append(toDelete, event.ID)
} else {
// It looks like the sync is being re-requested, maybe it timed out or
// failed. Re-send any that should have been acknowledged by now.
toReturn = append(toReturn, event)
}
}
return lastPos, toReturn, toUpdate, toDelete, nil
} }
func (d *Database) CleanSendToDeviceUpdates( func (d *Database) CleanSendToDeviceUpdates(
ctx context.Context, ctx context.Context,
toUpdate, toDelete []types.SendToDeviceNID, userID, deviceID string, before types.StreamPosition,
token types.StreamingToken,
) (err error) { ) (err error) {
if len(toUpdate) == 0 && len(toDelete) == 0 { if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before)
}); err != nil {
logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID)
return err
}
return nil return nil
} }
// If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place.
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
}
// Now update any outstanding send-to-device messages with the new sync token.
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
}
return nil
})
return
}
// getMembershipFromEvent returns the value of content.membership iff the event is a state event // getMembershipFromEvent returns the value of content.membership iff the event is a state event
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. // with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.

View File

@ -24,6 +24,7 @@ import (
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences) goose.AddMigration(UpFixSequences, DownFixSequences)
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
} }
func LoadFixSequences(m *sqlutil.Migrations) { func LoadFixSequences(m *sqlutil.Migrations) {

View File

@ -0,0 +1,67 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device;
CREATE TABLE syncapi_send_to_device(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL
);
INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup;
DROP TABLE syncapi_send_to_device_backup;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device;
CREATE TABLE syncapi_send_to_device(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL,
sent_by_token TEXT
);
INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup;
DROP TABLE syncapi_send_to_device_backup;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}

View File

@ -18,12 +18,12 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus"
) )
const sendToDeviceSchema = ` const sendToDeviceSchema = `
@ -36,11 +36,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The device ID to send the message to. -- The device ID to send the message to.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
-- The event content JSON. -- The event content JSON.
content TEXT NOT NULL, content TEXT NOT NULL
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
); );
` `
@ -49,33 +45,27 @@ const insertSendToDeviceMessageSQL = `
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
` `
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = ` const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id DESC
` `
const updateSentSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1 DELETE FROM syncapi_send_to_device
WHERE id IN ($2) WHERE user_id = $1 AND device_id = $2 AND id < $3
` `
const deleteSendToDeviceMessagesSQL = ` const selectMaxSendToDeviceIDSQL = "" +
DELETE FROM syncapi_send_to_device WHERE id IN ($1) "SELECT MAX(id) FROM syncapi_send_to_device"
`
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
db *sql.DB db *sql.DB
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt deleteSendToDeviceMessagesStmt *sql.Stmt
selectMaxSendToDeviceIDStmt *sql.Stmt
} }
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
@ -86,15 +76,18 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -111,75 +104,57 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
return return
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil { if err != nil {
return return
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() { for rows.Next() {
var id types.SendToDeviceNID var id types.StreamPosition
var userID, deviceID, content string var userID, deviceID, content string
var sentByToken *string if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil { logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
return return
} }
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
ID: id, ID: id,
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
} }
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
} continue
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
} }
events = append(events, event) events = append(events, event)
if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id)
} }
if lastPos == 0 {
lastPos = to
} }
return lastPos, events, rows.Err() return lastPos, events, rows.Err()
} }
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
) (err error) { ) (err error) {
query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", sqlutil.QueryVariadic(1+len(nids)), 1) _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
params := make([]interface{}, 1+len(nids))
params[0] = token
for k, v := range nids {
params[k+1] = v
}
_, err = txn.ExecContext(ctx, query, params...)
return return
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx,
) (err error) { ) (id int64, err error) {
query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) var nullableID sql.NullInt64
params := make([]interface{}, 1+len(nids)) stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
for k, v := range nids { err = stmt.QueryRowContext(ctx).Scan(&nullableID)
params[k] = v if nullableID.Valid {
id = nullableID.Int64
} }
_, err = txn.ExecContext(ctx, query, params...)
return return
} }

View File

@ -102,6 +102,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m) deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil { if err = m.RunDeltas(d.db, dbProperties); err != nil {
return err return err
} }

View File

@ -147,10 +147,9 @@ type BackwardsExtremities interface {
// sync response, as the client is seemingly trying to repeat the same /sync. // sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface { type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error) InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from types.StreamPosition) (err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) SelectMaxSendToDeviceMessageID(ctx context.Context, txn *sql.Tx) (id int64, err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
} }
type Filter interface { type Filter interface {

View File

@ -10,6 +10,16 @@ type SendToDeviceStreamProvider struct {
StreamProvider StreamProvider
} }
func (p *SendToDeviceStreamProvider) Setup() {
p.StreamProvider.Setup()
id, err := p.DB.MaxStreamPositionForSendToDeviceMessages(context.Background())
if err != nil {
panic(err)
}
p.latest = id
}
func (p *SendToDeviceStreamProvider) CompleteSync( func (p *SendToDeviceStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
req *types.SyncRequest, req *types.SyncRequest,
@ -23,24 +33,19 @@ func (p *SendToDeviceStreamProvider) IncrementalSync(
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// See if we have any new tasks to do for the send-to-device messaging. // See if we have any new tasks to do for the send-to-device messaging.
lastPos, events, updates, deletions, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, req.Since) lastPos, events, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed") req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed")
return from return from
} }
// Before we return the sync response, make sure that we take action on if len(events) > 0 {
// any send-to-device database updates or deletions that we need to do. // Clean up old send-to-device messages from before this stream position.
// Then add the updates into the sync response. if err := p.DB.CleanSendToDeviceUpdates(req.Context, req.Device.UserID, req.Device.ID, from); err != nil {
if len(updates) > 0 || len(deletions) > 0 {
// Handle the updates and deletions in the database.
err = p.DB.CleanSendToDeviceUpdates(context.Background(), updates, deletions, req.Since)
if err != nil {
req.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed") req.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed")
return from return from
} }
}
if len(events) > 0 {
// Add the updates into the sync response. // Add the updates into the sync response.
for _, event := range events { for _, event := range events {
req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent) req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent)

View File

@ -492,14 +492,11 @@ func NewLeaveResponse() *LeaveResponse {
return &res return &res
} }
type SendToDeviceNID int
type SendToDeviceEvent struct { type SendToDeviceEvent struct {
gomatrixserverlib.SendToDeviceEvent gomatrixserverlib.SendToDeviceEvent
ID SendToDeviceNID ID StreamPosition
UserID string UserID string
DeviceID string DeviceID string
SentByToken *StreamingToken
} }
type PeekingDevice struct { type PeekingDevice struct {