Add transaction ID to events if sending device (#368)

This commit is contained in:
Erik Johnston 2017-12-15 15:42:55 +00:00 committed by GitHub
parent de6529d766
commit b835e585c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 59 additions and 24 deletions

View File

@ -19,6 +19,9 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
@ -86,13 +89,17 @@ func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[str
// Events lookups a list of event by their event ID. // Events lookups a list of event by their event ID.
// Returns a list of events matching the requested IDs found in the database. // Returns a list of events matching the requested IDs found in the database.
// If an event is not found in the database then it will be omitted from the list. // If an event is not found in the database then it will be omitted from the list.
// Returns an error if there was a problem talking with the database // Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events.
func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return streamEventsToEvents(streamEvents), nil
// We don't include a device here as we only include transaction IDs in
// incremental syncs.
return streamEventsToEvents(nil, streamEvents), nil
} }
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
@ -208,10 +215,14 @@ func (d *SyncServerDatabase) syncStreamPositionTx(
return types.StreamPosition(maxID), nil return types.StreamPosition(maxID), nil
} }
// IncrementalSync returns all the data needed in order to create an incremental sync response. // IncrementalSync returns all the data needed in order to create an incremental
// sync response for the given user. Events returned will include any client
// transaction IDs associated with the given device. These transaction IDs come
// from when the device sent the event via an API that included a transaction
// ID.
func (d *SyncServerDatabase) IncrementalSync( func (d *SyncServerDatabase) IncrementalSync(
ctx context.Context, ctx context.Context,
userID string, device authtypes.Device,
fromPos, toPos types.StreamPosition, fromPos, toPos types.StreamPosition,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) (*types.Response, error) { ) (*types.Response, error) {
@ -226,21 +237,21 @@ func (d *SyncServerDatabase) IncrementalSync(
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
// This works out what the 'state' key should be for each room as well as which membership block // This works out what the 'state' key should be for each room as well as which membership block
// to put the room into. // to put the room into.
deltas, err := d.getStateDeltas(ctx, txn, fromPos, toPos, userID) deltas, err := d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
res := types.NewResponse(toPos) res := types.NewResponse(toPos)
for _, delta := range deltas { for _, delta := range deltas {
err = d.addRoomDeltaToResponse(ctx, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
// TODO: This should be done in getStateDeltas // TODO: This should be done in getStateDeltas
if err = d.addInvitesToResponse(ctx, txn, userID, fromPos, toPos, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, device.UserID, fromPos, toPos, res); err != nil {
return nil, err return nil, err
} }
@ -292,7 +303,10 @@ func (d *SyncServerDatabase) CompleteSync(
if err != nil { if err != nil {
return nil, err return nil, err
} }
recentEvents := streamEventsToEvents(recentStreamEvents)
// We don't include a device here as we don't need to send down
// transaction IDs for complete syncs
recentEvents := streamEventsToEvents(nil, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
@ -390,7 +404,9 @@ func (d *SyncServerDatabase) addInvitesToResponse(
// addRoomDeltaToResponse adds a room state delta to a sync response // addRoomDeltaToResponse adds a room state delta to a sync response
func (d *SyncServerDatabase) addRoomDeltaToResponse( func (d *SyncServerDatabase) addRoomDeltaToResponse(
ctx context.Context, txn *sql.Tx, ctx context.Context,
device *authtypes.Device,
txn *sql.Tx,
fromPos, toPos types.StreamPosition, fromPos, toPos types.StreamPosition,
delta stateDelta, delta stateDelta,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
@ -412,7 +428,7 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse(
if err != nil { if err != nil {
return err return err
} }
recentEvents := streamEventsToEvents(recentStreamEvents) recentEvents := streamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
// Don't bother appending empty room entries // Don't bother appending empty room entries
@ -529,7 +545,7 @@ func (d *SyncServerDatabase) fetchMissingStateEvents(
} }
func (d *SyncServerDatabase) getStateDeltas( func (d *SyncServerDatabase) getStateDeltas(
ctx context.Context, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos types.StreamPosition, userID string, fromPos, toPos types.StreamPosition, userID string,
) ([]stateDelta, error) { ) ([]stateDelta, error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
@ -578,7 +594,7 @@ func (d *SyncServerDatabase) getStateDeltas(
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: membership, membership: membership,
membershipPos: ev.streamPosition, membershipPos: ev.streamPosition,
stateEvents: streamEventsToEvents(stateStreamEvents), stateEvents: streamEventsToEvents(device, stateStreamEvents),
roomID: roomID, roomID: roomID,
}) })
break break
@ -594,7 +610,7 @@ func (d *SyncServerDatabase) getStateDeltas(
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: "join", membership: "join",
stateEvents: streamEventsToEvents(state[joinedRoomID]), stateEvents: streamEventsToEvents(device, state[joinedRoomID]),
roomID: joinedRoomID, roomID: joinedRoomID,
}) })
} }
@ -602,10 +618,25 @@ func (d *SyncServerDatabase) getStateDeltas(
return deltas, nil return deltas, nil
} }
func streamEventsToEvents(in []streamEvent) []gomatrixserverlib.Event { // streamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event.
func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event {
out := make([]gomatrixserverlib.Event, len(in)) out := make([]gomatrixserverlib.Event, len(in))
for i := 0; i < len(in); i++ { for i := 0; i < len(in); i++ {
out[i] = in[i].Event out[i] = in[i].Event
if device != nil && in[i].transactionID != nil {
if device.UserID == in[i].Sender() && device.ID == in[i].transactionID.DeviceID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].transactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
}
}
}
} }
return out return out
} }

View File

@ -123,7 +123,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener {
n.removeEmptyUserStreams() n.removeEmptyUserStreams()
return n.fetchUserStream(req.userID, true).GetListener(req.ctx) return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx)
} }
// Load the membership states required to notify users correctly. // Load the membership states required to notify users correctly.

View File

@ -21,6 +21,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -262,7 +264,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamPosition, error) {
select { select {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
return types.StreamPosition(0), fmt.Errorf( return types.StreamPosition(0), fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%d)", req.userID, req.since, "waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since,
) )
case <-listener.GetNotifyChannel(*req.since): case <-listener.GetNotifyChannel(*req.since):
p := listener.GetStreamPosition() p := listener.GetStreamPosition()
@ -280,7 +282,7 @@ func waitForBlocking(s *UserStream, numBlocking uint) {
func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest { func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest {
return syncRequest{ return syncRequest{
userID: userID, device: authtypes.Device{UserID: userID},
timeout: 1 * time.Minute, timeout: 1 * time.Minute,
since: &since, since: &since,
wantFullState: false, wantFullState: false,

View File

@ -20,6 +20,8 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/util" "github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -31,7 +33,7 @@ const defaultTimelineLimit = 20
// syncRequest represents a /sync request, with sensible defaults/sanity checks applied. // syncRequest represents a /sync request, with sensible defaults/sanity checks applied.
type syncRequest struct { type syncRequest struct {
ctx context.Context ctx context.Context
userID string device authtypes.Device
limit int limit int
timeout time.Duration timeout time.Duration
since *types.StreamPosition // nil means that no since token was supplied since *types.StreamPosition // nil means that no since token was supplied
@ -39,7 +41,7 @@ type syncRequest struct {
log *log.Entry log *log.Entry
} }
func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) { func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, error) {
timeout := getTimeout(req.URL.Query().Get("timeout")) timeout := getTimeout(req.URL.Query().Get("timeout"))
fullState := req.URL.Query().Get("full_state") fullState := req.URL.Query().Get("full_state")
wantFullState := fullState != "" && fullState != "false" wantFullState := fullState != "" && fullState != "false"
@ -50,7 +52,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) {
// TODO: Additional query params: set_presence, filter // TODO: Additional query params: set_presence, filter
return &syncRequest{ return &syncRequest{
ctx: req.Context(), ctx: req.Context(),
userID: userID, device: device,
timeout: timeout, timeout: timeout,
since: since, since: since,
wantFullState: wantFullState, wantFullState: wantFullState,

View File

@ -48,7 +48,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
// Extract values from request // Extract values from request
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
userID := device.UserID userID := device.UserID
syncReq, err := newSyncRequest(req, userID) syncReq, err := newSyncRequest(req, *device)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,
@ -122,16 +122,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) { func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) {
// TODO: handle ignored users // TODO: handle ignored users
if req.since == nil { if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.userID, req.limit) res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
} else { } else {
res, err = rp.db.IncrementalSync(req.ctx, req.userID, *req.since, currentPos, req.limit) res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, currentPos, req.limit)
} }
if err != nil { if err != nil {
return return
} }
res, err = rp.appendAccountData(res, req.userID, req, currentPos) res, err = rp.appendAccountData(res, req.device.UserID, req, currentPos)
return return
} }