CS API: Support for /messages, fixes for /sync (#847)

* Merge forward

* Tidy up a bit

* TODO: What to do with NextBatch here?

* Replace SyncPosition with PaginationToken throughout syncapi

* Fix PaginationTokens

* Fix lint errors

* Add a couple of missing functions into the syncapi external storage interface

* Some updates based on review comments from @babolivier

* Some updates based on review comments from @babolivier

* argh whitespacing

* Fix opentracing span

* Remove dead code

* Don't overshadow err (fix lint issue)

* Handle extremities after inserting event into topology

* Try insert event topology as ON CONFLICT DO NOTHING

* Prevent OOB error in addRoomDeltaToResponse

* Thwarted by gocyclo again

* Fix NewPaginationTokenFromString, define unit test for it

* Update pagination token test

* Update sytest-whitelist

* Hopefully fix some of the sync batch tokens

* Remove extraneous sync position func

* Revert to topology tokens in addRoomDeltaToResponse etc

* Fix typo

* Remove prevPDUPos as dead now that backwardTopologyPos is used instead

* Fix selectEventsWithEventIDsSQL

* Update sytest-blacklist

* Update sytest-whitelist
This commit is contained in:
Neil Alexander 2020-01-23 17:51:10 +00:00 committed by GitHub
parent 43ecf8d1f9
commit 49f760a30b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 1601 additions and 286 deletions

View File

@ -70,7 +70,7 @@ func main() {
federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI) federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI)
mediaapi.SetupMediaAPIComponent(base, deviceDB) mediaapi.SetupMediaAPIComponent(base, deviceDB)
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB) publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB)
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query) syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg)
httpHandler := common.WrapHandlerInCORS(base.APIMux) httpHandler := common.WrapHandlerInCORS(base.APIMux)

View File

@ -26,10 +26,11 @@ func main() {
deviceDB := base.CreateDeviceDB() deviceDB := base.CreateDeviceDB()
accountDB := base.CreateAccountsDB() accountDB := base.CreateAccountsDB()
federation := base.CreateFederationClient()
_, _, query := base.CreateHTTPRoomserverAPIs() _, _, query := base.CreateHTTPRoomserverAPIs()
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query) syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg)
base.SetupAndServeHTTP(string(base.Cfg.Bind.SyncAPI), string(base.Cfg.Listen.SyncAPI)) base.SetupAndServeHTTP(string(base.Cfg.Bind.SyncAPI), string(base.Cfg.Listen.SyncAPI))

View File

@ -230,6 +230,20 @@ type QueryBackfillResponse struct {
Events []gomatrixserverlib.Event `json:"events"` Events []gomatrixserverlib.Event `json:"events"`
} }
// QueryServersInRoomAtEventRequest is a request to QueryServersInRoomAtEvent
type QueryServersInRoomAtEventRequest struct {
// ID of the room to retrieve member servers for.
RoomID string `json:"room_id"`
// ID of the event for which to retrieve member servers.
EventID string `json:"event_id"`
}
// QueryServersInRoomAtEventResponse is a response to QueryServersInRoomAtEvent
type QueryServersInRoomAtEventResponse struct {
// Servers present in the room for these events.
Servers []gomatrixserverlib.ServerName `json:"servers"`
}
// RoomserverQueryAPI is used to query information from the room server. // RoomserverQueryAPI is used to query information from the room server.
type RoomserverQueryAPI interface { type RoomserverQueryAPI interface {
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.
@ -303,6 +317,12 @@ type RoomserverQueryAPI interface {
request *QueryBackfillRequest, request *QueryBackfillRequest,
response *QueryBackfillResponse, response *QueryBackfillResponse,
) error ) error
QueryServersInRoomAtEvent(
ctx context.Context,
request *QueryServersInRoomAtEventRequest,
response *QueryServersInRoomAtEventResponse,
) error
} }
// RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. // RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API.
@ -332,8 +352,11 @@ const RoomserverQueryMissingEventsPath = "/api/roomserver/queryMissingEvents"
// RoomserverQueryStateAndAuthChainPath is the HTTP path for the QueryStateAndAuthChain API // RoomserverQueryStateAndAuthChainPath is the HTTP path for the QueryStateAndAuthChain API
const RoomserverQueryStateAndAuthChainPath = "/api/roomserver/queryStateAndAuthChain" const RoomserverQueryStateAndAuthChainPath = "/api/roomserver/queryStateAndAuthChain"
// RoomserverQueryBackfillPath is the HTTP path for the QueryBackfill API // RoomserverQueryBackfillPath is the HTTP path for the QueryBackfillPath API
const RoomserverQueryBackfillPath = "/api/roomserver/QueryBackfill" const RoomserverQueryBackfillPath = "/api/roomserver/queryBackfill"
// RoomserverQueryServersInRoomAtEventPath is the HTTP path for the QueryServersInRoomAtEvent API
const RoomserverQueryServersInRoomAtEventPath = "/api/roomserver/queryServersInRoomAtEvents"
// NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. // NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API.
// If httpClient is nil then it uses the http.DefaultClient // If httpClient is nil then it uses the http.DefaultClient
@ -478,3 +501,16 @@ func (h *httpRoomserverQueryAPI) QueryBackfill(
apiURL := h.roomserverURL + RoomserverQueryBackfillPath apiURL := h.roomserverURL + RoomserverQueryBackfillPath
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// QueryServersInRoomAtEvent implements RoomServerQueryAPI
func (h *httpRoomserverQueryAPI) QueryServersInRoomAtEvent(
ctx context.Context,
request *QueryServersInRoomAtEventRequest,
response *QueryServersInRoomAtEventResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServersInRoomAtEvent")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryServersInRoomAtEventPath
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}

View File

@ -660,6 +660,41 @@ func getAuthChain(
return authEvents, nil return authEvents, nil
} }
// QueryServersInRoomAtEvent implements api.RoomserverQueryAPI
func (r *RoomserverQueryAPI) QueryServersInRoomAtEvent(
ctx context.Context,
request *api.QueryServersInRoomAtEventRequest,
response *api.QueryServersInRoomAtEventResponse,
) error {
// getMembershipsBeforeEventNID requires a NID, so retrieving the NID for
// the event is necessary.
NIDs, err := r.DB.EventNIDs(ctx, []string{request.EventID})
if err != nil {
return err
}
// Retrieve all "m.room.member" state events of "join" membership, which
// contains the list of users in the room before the event, therefore all
// the servers in it at that moment.
events, err := r.getMembershipsBeforeEventNID(ctx, NIDs[request.EventID], true)
if err != nil {
return err
}
// Store the server names in a temporary map to avoid duplicates.
servers := make(map[gomatrixserverlib.ServerName]bool)
for _, event := range events {
servers[event.Origin()] = true
}
// Populate the response.
for server := range servers {
response.Servers = append(response.Servers, server)
}
return nil
}
// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux. // SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux.
// nolint: gocyclo // nolint: gocyclo
func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
@ -803,4 +838,18 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
servMux.Handle(
api.RoomserverQueryServersInRoomAtEventPath,
common.MakeInternalAPI("QueryServersInRoomAtEvent", func(req *http.Request) util.JSONResponse {
var request api.QueryServersInRoomAtEventRequest
var response api.QueryServersInRoomAtEventResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryServersInRoomAtEvent(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View File

@ -90,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data") }).Panicf("could not save account data")
} }
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.SyncPosition{PDUPosition: pduPos}) s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.PaginationToken{PDUPosition: pduPos})
return nil return nil
} }

View File

@ -133,6 +133,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
msg.AddsStateEventIDs, msg.AddsStateEventIDs,
msg.RemovesStateEventIDs, msg.RemovesStateEventIDs,
msg.TransactionID, msg.TransactionID,
false,
) )
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
@ -144,7 +145,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write event failure") }).Panicf("roomserver output log: write event failure")
return nil return nil
} }
s.notifier.OnNewEvent(&ev, "", nil, types.SyncPosition{PDUPosition: pduPos}) s.notifier.OnNewEvent(&ev, "", nil, types.PaginationToken{PDUPosition: pduPos})
return nil return nil
} }
@ -161,7 +162,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure") }).Panicf("roomserver output log: write invite failure")
return nil return nil
} }
s.notifier.OnNewEvent(&msg.Event, "", nil, types.SyncPosition{PDUPosition: pduPos}) s.notifier.OnNewEvent(&msg.Event, "", nil, types.PaginationToken{PDUPosition: pduPos})
return nil return nil
} }

View File

@ -63,7 +63,12 @@ func NewOutputTypingEventConsumer(
// Start consuming from typing api // Start consuming from typing api
func (s *OutputTypingEventConsumer) Start() error { func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent(nil, roomID, nil, types.SyncPosition{TypingPosition: latestSyncPosition}) s.notifier.OnNewEvent(
nil, roomID, nil,
types.PaginationToken{
EDUTypingPosition: types.StreamPosition(latestSyncPosition),
},
)
}) })
return s.typingConsumer.Start() return s.typingConsumer.Start()
@ -83,7 +88,7 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
"typing": output.Event.Typing, "typing": output.Event.Typing,
}).Debug("received data from typing server") }).Debug("received data from typing server")
var typingPos int64 var typingPos types.StreamPosition
typingEvent := output.Event typingEvent := output.Event
if typingEvent.Typing { if typingEvent.Typing {
typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime) typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime)
@ -91,6 +96,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
} }
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.SyncPosition{TypingPosition: typingPos}) s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.PaginationToken{EDUTypingPosition: typingPos})
return nil return nil
} }

482
syncapi/routing/messages.go Normal file
View File

@ -0,0 +1,482 @@
// Copyright 2018 New Vector Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"net/http"
"sort"
"strconv"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
)
type messagesReq struct {
ctx context.Context
db storage.Database
queryAPI api.RoomserverQueryAPI
federation *gomatrixserverlib.FederationClient
cfg *config.Dendrite
roomID string
from *types.PaginationToken
to *types.PaginationToken
wasToProvided bool
limit int
backwardOrdering bool
}
type messagesResp struct {
Start string `json:"start"`
End string `json:"end"`
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
}
const defaultMessagesLimit = 10
// OnIncomingMessagesRequest implements the /messages endpoint from the
// client-server API.
// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
func OnIncomingMessagesRequest(
req *http.Request, db storage.Database, roomID string,
federation *gomatrixserverlib.FederationClient,
queryAPI api.RoomserverQueryAPI,
cfg *config.Dendrite,
) util.JSONResponse {
var err error
// Extract parameters from the request's URL.
// Pagination tokens.
from, err := types.NewPaginationTokenFromString(req.URL.Query().Get("from"))
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()),
}
}
// Direction to return events from.
dir := req.URL.Query().Get("dir")
if dir != "b" && dir != "f" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Bad or missing dir query parameter (should be either 'b' or 'f')"),
}
}
// A boolean is easier to handle in this case, especially since dir is sure
// to have one of the two accepted values (so dir == "f" <=> !backwardOrdering).
backwardOrdering := (dir == "b")
// Pagination tokens. To is optional, and its default value depends on the
// direction ("b" or "f").
var to *types.PaginationToken
wasToProvided := true
if s := req.URL.Query().Get("to"); len(s) > 0 {
to, err = types.NewPaginationTokenFromString(s)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()),
}
}
} else {
// If "to" isn't provided, it defaults to either the earliest stream
// position (if we're going backward) or to the latest one (if we're
// going forward).
to, err = setToDefault(req.Context(), db, backwardOrdering, roomID)
if err != nil {
return httputil.LogThenError(req, err)
}
wasToProvided = false
}
// Maximum number of events to return; defaults to 10.
limit := defaultMessagesLimit
if len(req.URL.Query().Get("limit")) > 0 {
limit, err = strconv.Atoi(req.URL.Query().Get("limit"))
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()),
}
}
}
// TODO: Implement filtering (#587)
// Check the room ID's format.
if _, _, err = gomatrixserverlib.SplitID('!', roomID); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Bad room ID: " + err.Error()),
}
}
mReq := messagesReq{
ctx: req.Context(),
db: db,
queryAPI: queryAPI,
federation: federation,
cfg: cfg,
roomID: roomID,
from: from,
to: to,
wasToProvided: wasToProvided,
limit: limit,
backwardOrdering: backwardOrdering,
}
clientEvents, start, end, err := mReq.retrieveEvents()
if err != nil {
return httputil.LogThenError(req, err)
}
// Respond with the events.
return util.JSONResponse{
Code: http.StatusOK,
JSON: messagesResp{
Chunk: clientEvents,
Start: start.String(),
End: end.String(),
},
}
}
// retrieveEvents retrieve events from the local database for a request on
// /messages. If there's not enough events to retrieve, it asks another
// homeserver in the room for older events.
// Returns an error if there was an issue talking to the database or with the
// remote homeserver.
func (r *messagesReq) retrieveEvents() (
clientEvents []gomatrixserverlib.ClientEvent, start,
end *types.PaginationToken, err error,
) {
// Retrieve the events from the local database.
streamEvents, err := r.db.GetEventsInRange(
r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering,
)
if err != nil {
return
}
var events []gomatrixserverlib.Event
// There can be two reasons for streamEvents to be empty: either we've
// reached the oldest event in the room (or the most recent one, depending
// on the ordering), or we've reached a backward extremity.
if len(streamEvents) == 0 {
if events, err = r.handleEmptyEventsSlice(); err != nil {
return
}
} else {
if events, err = r.handleNonEmptyEventsSlice(streamEvents); err != nil {
return
}
}
// If we didn't get any event, we don't need to proceed any further.
if len(events) == 0 {
return []gomatrixserverlib.ClientEvent{}, r.from, r.to, nil
}
// Sort the events to ensure we send them in the right order. We currently
// do that based on the event's timestamp.
if r.backwardOrdering {
sort.SliceStable(events, func(i int, j int) bool {
// Backward ordering is antichronological (latest event to oldest
// one).
return sortEvents(&(events[j]), &(events[i]))
})
} else {
sort.SliceStable(events, func(i int, j int) bool {
// Forward ordering is chronological (oldest event to latest one).
return sortEvents(&(events[i]), &(events[j]))
})
}
// Convert all of the events into client events.
clientEvents = gomatrixserverlib.ToClientEvents(events, gomatrixserverlib.FormatAll)
// Get the position of the first and the last event in the room's topology.
// This position is currently determined by the event's depth, so we could
// also use it instead of retrieving from the database. However, if we ever
// change the way topological positions are defined (as depth isn't the most
// reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database.
startPos, err := r.db.EventPositionInTopology(
r.ctx, streamEvents[0].EventID(),
)
if err != nil {
return
}
endPos, err := r.db.EventPositionInTopology(
r.ctx, streamEvents[len(streamEvents)-1].EventID(),
)
if err != nil {
return
}
// Generate pagination tokens to send to the client using the positions
// retrieved previously.
start = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, startPos, 0,
)
end = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, endPos, 0,
)
if r.backwardOrdering {
// A stream/topological position is a cursor located between two events.
// While they are identified in the code by the event on their right (if
// we consider a left to right chronological order), tokens need to refer
// to them by the event on their left, therefore we need to decrement the
// end position we send in the response if we're going backward.
end.PDUPosition--
}
// The lowest token value is 1, therefore we need to manually set it to that
// value if we're below it.
if end.PDUPosition < types.StreamPosition(1) {
end.PDUPosition = types.StreamPosition(1)
}
return clientEvents, start, end, err
}
// handleEmptyEventsSlice handles the case where the initial request to the
// database returned an empty slice of events. It does so by checking whether
// the set is empty because we've reached a backward extremity, and if that is
// the case, by retrieving as much events as requested by backfilling from
// another homeserver.
// Returns an error if there was an issue talking with the database or
// backfilling.
func (r *messagesReq) handleEmptyEventsSlice() (
events []gomatrixserverlib.Event, err error,
) {
backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID)
// Check if we have backward extremities for this room.
if len(backwardExtremities) > 0 {
// If so, retrieve as much events as needed through backfilling.
events, err = r.backfill(backwardExtremities, r.limit)
if err != nil {
return
}
} else {
// If not, it means the slice was empty because we reached the room's
// creation, so return an empty slice.
events = []gomatrixserverlib.Event{}
}
return
}
// handleNonEmptyEventsSlice handles the case where the initial request to the
// database returned a non-empty slice of events. It does so by checking whether
// events are missing from the expected result, and retrieve missing events
// through backfilling if needed.
// Returns an error if there was an issue while backfilling.
func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent) (
events []gomatrixserverlib.Event, err error,
) {
// Check if we have enough events.
isSetLargeEnough := true
if len(streamEvents) < r.limit {
if r.backwardOrdering {
if r.wasToProvided {
// The condition in the SQL query is a strict "greater than" so
// we need to check against to-1.
streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition)
isSetLargeEnough = (r.to.PDUPosition-1 == streamPos)
}
} else {
streamPos := types.StreamPosition(streamEvents[0].StreamPosition)
isSetLargeEnough = (r.from.PDUPosition-1 == streamPos)
}
}
// Check if the slice contains a backward extremity.
backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID)
if err != nil {
return
}
// Backfill is needed if we've reached a backward extremity and need more
// events. It's only needed if the direction is backward.
if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering {
var pdus []gomatrixserverlib.Event
// Only ask the remote server for enough events to reach the limit.
pdus, err = r.backfill(backwardExtremities, r.limit-len(streamEvents))
if err != nil {
return
}
// Append the PDUs to the list to send back to the client.
events = append(events, pdus...)
}
// Append the events ve previously retrieved locally.
events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...)
return
}
// containsBackwardExtremity checks if a slice of StreamEvent contains a
// backward extremity. It does so by selecting the earliest event in the slice
// and by checking the presence in the database of all of its parent events, and
// considers the event itself a backward extremity if at least one of the parent
// events doesn't exist in the database.
// Returns an error if there was an issue with talking to the database.
func (r *messagesReq) containsBackwardExtremity(events []types.StreamEvent) (bool, error) {
// Select the earliest retrieved event.
var ev *types.StreamEvent
if r.backwardOrdering {
ev = &(events[len(events)-1])
} else {
ev = &(events[0])
}
// Get the earliest retrieved event's parents.
prevIDs := ev.PrevEventIDs()
prevs, err := r.db.Events(r.ctx, prevIDs)
if err != nil {
return false, nil
}
// Check if we have all of the events we requested. If not, it means we've
// reached a backward extremity.
var eventInDB bool
var id string
// Iterate over the IDs we used in the request.
for _, id = range prevIDs {
eventInDB = false
// Iterate over the events we got in response.
for _, ev := range prevs {
if ev.EventID() == id {
eventInDB = true
}
}
// One occurrence of one the event's parents not being present in the
// database is enough to say that the event is a backward extremity.
if !eventInDB {
return true, nil
}
}
return false, nil
}
// backfill performs a backfill request over the federation on another
// homeserver in the room.
// See: https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
// It also stores the PDUs retrieved from the remote homeserver's response to
// the database.
// Returns with an empty string if the remote homeserver didn't return with any
// event, or if there is no remote homeserver to contact.
// Returns an error if there was an issue with retrieving the list of servers in
// the room or sending the request.
func (r *messagesReq) backfill(fromEventIDs []string, limit int) ([]gomatrixserverlib.Event, error) {
// Query the list of servers in the room when one of the backward extremities
// was sent.
var serversResponse api.QueryServersInRoomAtEventResponse
serversRequest := api.QueryServersInRoomAtEventRequest{
RoomID: r.roomID,
EventID: fromEventIDs[0],
}
if err := r.queryAPI.QueryServersInRoomAtEvent(r.ctx, &serversRequest, &serversResponse); err != nil {
return nil, err
}
// Use the first server from the response, except if that server is us.
// In that case, use the second one if the roomserver responded with
// enough servers. If not, use an empty string to prevent the backfill
// from happening as there's no server to direct the request towards.
// TODO: Be smarter at selecting the server to direct the request
// towards.
srvToBackfillFrom := serversResponse.Servers[0]
if srvToBackfillFrom == r.cfg.Matrix.ServerName {
if len(serversResponse.Servers) > 1 {
srvToBackfillFrom = serversResponse.Servers[1]
} else {
srvToBackfillFrom = gomatrixserverlib.ServerName("")
log.Warn("Not enough servers to backfill from")
}
}
pdus := make([]gomatrixserverlib.Event, 0)
// If the roomserver responded with at least one server that isn't us,
// send it a request for backfill.
if len(srvToBackfillFrom) > 0 {
txn, err := r.federation.Backfill(
r.ctx, srvToBackfillFrom, r.roomID, limit, fromEventIDs,
)
if err != nil {
return nil, err
}
pdus = txn.PDUs
// Store the events in the database, while marking them as unfit to show
// up in responses to sync requests.
for _, pdu := range pdus {
if _, err = r.db.WriteEvent(
r.ctx, &pdu, []gomatrixserverlib.Event{}, []string{}, []string{},
nil, true,
); err != nil {
return nil, err
}
}
}
return pdus, nil
}
// setToDefault returns the default value for the "to" query parameter of a
// request to /messages if not provided. It defaults to either the earliest
// topological position (if we're going backward) or to the latest one (if we're
// going forward).
// Returns an error if there was an issue with retrieving the latest position
// from the database
func setToDefault(
ctx context.Context, db storage.Database, backwardOrdering bool,
roomID string,
) (to *types.PaginationToken, err error) {
if backwardOrdering {
to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 1, 0)
} else {
var pos types.StreamPosition
pos, err = db.MaxTopologicalPosition(ctx, roomID)
if err != nil {
return
}
to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, 0)
}
return
}
// sortEvents is a function to give to sort.SliceStable, and compares the
// timestamp of two Matrix events.
// Returns true if the first event happened before the second one, false
// otherwise.
func sortEvents(e1 *gomatrixserverlib.Event, e2 *gomatrixserverlib.Event) bool {
t := e1.OriginServerTS().Time()
return e2.OriginServerTS().Time().After(t)
}

View File

@ -22,8 +22,11 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -34,7 +37,12 @@ const pathPrefixR0 = "/_matrix/client/r0"
// Due to Setup being used to call many other functions, a gocyclo nolint is // Due to Setup being used to call many other functions, a gocyclo nolint is
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, deviceDB *devices.Database) { func Setup(
apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database,
deviceDB *devices.Database, federation *gomatrixserverlib.FederationClient,
queryAPI api.RoomserverQueryAPI,
cfg *config.Dendrite,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
authData := auth.Data{ authData := auth.Data{
@ -71,4 +79,12 @@ func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, d
} }
return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], vars["stateKey"]) return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], vars["stateKey"])
})).Methods(http.MethodGet, http.MethodOptions) })).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/messages", common.MakeAuthAPI("room_messages", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, queryAPI, cfg)
})).Methods(http.MethodGet, http.MethodOptions)
} }

View File

@ -21,6 +21,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
) )
@ -89,7 +90,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) insertAccountData(
ctx context.Context, ctx context.Context,
userID, roomID, dataType string, userID, roomID, dataType string,
) (pos int64, err error) { ) (pos types.StreamPosition, err error) {
err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos) err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos)
return return
} }
@ -97,7 +98,7 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountDataInRange( func (s *accountDataStatements) selectAccountDataInRange(
ctx context.Context, ctx context.Context,
userID string, userID string,
oldPos, newPos int64, oldPos, newPos types.StreamPosition,
accountDataFilterPart *gomatrix.FilterPart, accountDataFilterPart *gomatrix.FilterPart,
) (data map[string][]string, err error) { ) (data map[string][]string, err error) {
data = make(map[string][]string) data = make(map[string][]string)

View File

@ -0,0 +1,118 @@
// Copyright 2018 New Vector Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
)
const backwardExtremitiesSchema = `
-- Stores output room events received from the roomserver.
CREATE TABLE IF NOT EXISTS syncapi_backward_extremities (
-- The 'room_id' key for the event.
room_id TEXT NOT NULL,
-- The event ID for the event.
event_id TEXT NOT NULL,
PRIMARY KEY(room_id, event_id)
);
`
const insertBackwardExtremitySQL = "" +
"INSERT INTO syncapi_backward_extremities (room_id, event_id)" +
" VALUES ($1, $2)"
const selectBackwardExtremitiesForRoomSQL = "" +
"SELECT event_id FROM syncapi_backward_extremities WHERE room_id = $1"
const isBackwardExtremitySQL = "" +
"SELECT EXISTS (" +
" SELECT TRUE FROM syncapi_backward_extremities" +
" WHERE room_id = $1 AND event_id = $2" +
")"
const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND event_id = $2"
type backwardExtremitiesStatements struct {
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
isBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
}
func (s *backwardExtremitiesStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(backwardExtremitiesSchema)
if err != nil {
return
}
if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil {
return
}
if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil {
return
}
if s.isBackwardExtremityStmt, err = db.Prepare(isBackwardExtremitySQL); err != nil {
return
}
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return
}
return
}
func (s *backwardExtremitiesStatements) insertsBackwardExtremity(
ctx context.Context, roomID, eventID string,
) (err error) {
_, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID)
return
}
func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom(
ctx context.Context, roomID string,
) (eventIDs []string, err error) {
eventIDs = make([]string, 0)
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
if err != nil {
return
}
for rows.Next() {
var eID string
if err = rows.Scan(&eID); err != nil {
return
}
eventIDs = append(eventIDs, eID)
}
return
}
func (s *backwardExtremitiesStatements) isBackwardExtremity(
ctx context.Context, roomID, eventID string,
) (isBE bool, err error) {
err = s.isBackwardExtremityStmt.QueryRowContext(ctx, roomID, eventID).Scan(&isBE)
return
}
func (s *backwardExtremitiesStatements) deleteBackwardExtremity(
ctx context.Context, roomID, eventID string,
) (err error) {
_, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID)
return
}

View File

@ -22,6 +22,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -87,10 +88,10 @@ const selectStateEventSQL = "" +
const selectEventsWithEventIDsSQL = "" + const selectEventsWithEventIDsSQL = "" +
// TODO: The session_id and transaction_id blanks are here because otherwise // TODO: The session_id and transaction_id blanks are here because otherwise
// the rowsToStreamEvents expects there to be exactly four columns. We need to // the rowsToStreamEvents expects there to be exactly five columns. We need to
// figure out if these really need to be in the DB, and if so, we need a // figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020 // better permanent fix for this. - neilalexander, 2 Jan 2020
"SELECT added_at, event_json, 0 AS session_id, '' AS transaction_id" + "SELECT added_at, event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
" FROM syncapi_current_room_state WHERE event_id = ANY($1)" " FROM syncapi_current_room_state WHERE event_id = ANY($1)"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
@ -213,7 +214,7 @@ func (s *currentRoomStateStatements) deleteRoomStateByEventID(
func (s *currentRoomStateStatements) upsertRoomState( func (s *currentRoomStateStatements) upsertRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event gomatrixserverlib.Event, membership *string, addedAt int64, event gomatrixserverlib.Event, membership *string, addedAt types.StreamPosition,
) error { ) error {
// Parse content as JSON and search for an "url" key // Parse content as JSON and search for an "url" key
containsURL := false containsURL := false
@ -242,7 +243,7 @@ func (s *currentRoomStateStatements) upsertRoomState(
func (s *currentRoomStateStatements) selectEventsWithEventIDs( func (s *currentRoomStateStatements) selectEventsWithEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]streamEvent, error) { ) ([]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt) stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {

View File

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -86,7 +87,7 @@ func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) {
func (s *inviteEventsStatements) insertInviteEvent( func (s *inviteEventsStatements) insertInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event, ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (streamPos int64, err error) { ) (streamPos types.StreamPosition, err error) {
err = s.insertInviteEventStmt.QueryRowContext( err = s.insertInviteEventStmt.QueryRowContext(
ctx, ctx,
inviteEvent.RoomID(), inviteEvent.RoomID(),
@ -107,7 +108,7 @@ func (s *inviteEventsStatements) deleteInviteEvent(
// selectInviteEventsInRange returns a map of room ID to invite event for the // selectInviteEventsInRange returns a map of room ID to invite event for the
// active invites for the target user ID in the supplied range. // active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) selectInviteEventsInRange( func (s *inviteEventsStatements) selectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos int64, ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition,
) (map[string]gomatrixserverlib.Event, error) { ) (map[string]gomatrixserverlib.Event, error) {
stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt) stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos) rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos)

View File

@ -22,6 +22,7 @@ import (
"sort" "sort"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/lib/pq" "github.com/lib/pq"
@ -56,8 +57,15 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
-- if there is no delta. -- if there is no delta.
add_state_ids TEXT[], add_state_ids TEXT[],
remove_state_ids TEXT[], remove_state_ids TEXT[],
session_id BIGINT, -- The client session that sent the event, if any -- The client session that sent the event, if any
transaction_id TEXT -- The transaction id used to send the event, if any session_id BIGINT,
-- The transaction id used to send the event, if any
transaction_id TEXT,
-- Should the event be excluded from responses to /sync requests. Useful for
-- events retrieved through backfilling that have a position in the stream
-- that relates to the moment these were retrieved rather than the moment these
-- were emitted.
exclude_from_sync BOOL DEFAULT FALSE
); );
-- for event selection -- for event selection
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_events(event_id); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_events(event_id);
@ -65,23 +73,33 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" + "INSERT INTO syncapi_output_room_events (" +
"room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" + "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id" ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4" " ORDER BY id DESC LIMIT $4"
const selectRecentEventsForSyncSQL = "" +
"SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" ORDER BY id DESC LIMIT $4"
const selectEarlyEventsSQL = "" +
"SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC LIMIT $4"
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events" "SELECT MAX(id) FROM syncapi_output_room_events"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" + const selectStateInRangeSQL = "" +
"SELECT id, event_json, add_state_ids, remove_state_ids" + "SELECT id, event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
" FROM syncapi_output_room_events" + " FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" AND ( $3::text[] IS NULL OR sender = ANY($3) )" + " AND ( $3::text[] IS NULL OR sender = ANY($3) )" +
@ -97,6 +115,8 @@ type outputRoomEventsStatements struct {
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt
} }
@ -117,6 +137,12 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil { if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil {
return return
} }
if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil {
return
}
if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil {
return
}
if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil {
return return
} }
@ -127,9 +153,9 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) selectStateInRange( func (s *outputRoomEventsStatements) selectStateInRange(
ctx context.Context, txn *sql.Tx, oldPos, newPos int64, ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition,
stateFilterPart *gomatrix.FilterPart, stateFilterPart *gomatrix.FilterPart,
) (map[string]map[string]bool, map[string]streamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectStateInRangeStmt) stmt := common.TxStmt(txn, s.selectStateInRangeStmt)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
@ -149,19 +175,20 @@ func (s *outputRoomEventsStatements) selectStateInRange(
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes // - For each room ID, build up an array of event IDs which represents cumulative adds/removes
// For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID // For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID
// if they aren't in the event ID cache. We don't handle state deletion yet. // if they aren't in the event ID cache. We don't handle state deletion yet.
eventIDToEvent := make(map[string]streamEvent) eventIDToEvent := make(map[string]types.StreamEvent)
// RoomID => A set (map[string]bool) of state event IDs which are between the two positions // RoomID => A set (map[string]bool) of state event IDs which are between the two positions
stateNeeded := make(map[string]map[string]bool) stateNeeded := make(map[string]map[string]bool)
for rows.Next() { for rows.Next() {
var ( var (
streamPos int64 streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool
addIDs pq.StringArray addIDs pq.StringArray
delIDs pq.StringArray delIDs pq.StringArray
) )
if err := rows.Scan(&streamPos, &eventBytes, &addIDs, &delIDs); err != nil { if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs); err != nil {
return nil, nil, err return nil, nil, err
} }
// Sanity check for deleted state and whine if we see it. We don't need to do anything // Sanity check for deleted state and whine if we see it. We don't need to do anything
@ -192,9 +219,10 @@ func (s *outputRoomEventsStatements) selectStateInRange(
} }
stateNeeded[ev.RoomID()] = needSet stateNeeded[ev.RoomID()] = needSet
eventIDToEvent[ev.EventID()] = streamEvent{ eventIDToEvent[ev.EventID()] = types.StreamEvent{
Event: ev, Event: ev,
streamPosition: streamPos, StreamPosition: streamPos,
ExcludeFromSync: excludeFromSync,
} }
} }
@ -221,8 +249,8 @@ func (s *outputRoomEventsStatements) selectMaxEventID(
func (s *outputRoomEventsStatements) insertEvent( func (s *outputRoomEventsStatements) insertEvent(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event *gomatrixserverlib.Event, addState, removeState []string, event *gomatrixserverlib.Event, addState, removeState []string,
transactionID *api.TransactionID, transactionID *api.TransactionID, excludeFromSync bool,
) (streamPos int64, err error) { ) (streamPos types.StreamPosition, err error) {
var txnID *string var txnID *string
var sessionID *int64 var sessionID *int64
if transactionID != nil { if transactionID != nil {
@ -251,16 +279,53 @@ func (s *outputRoomEventsStatements) insertEvent(
pq.StringArray(removeState), pq.StringArray(removeState),
sessionID, sessionID,
txnID, txnID,
excludeFromSync,
).Scan(&streamPos) ).Scan(&streamPos)
return return
} }
// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. // selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'.
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude
// from sync.
func (s *outputRoomEventsStatements) selectRecentEvents( func (s *outputRoomEventsStatements) selectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, fromPos, toPos int64, limit int, roomID string, fromPos, toPos types.StreamPosition, limit int,
) ([]streamEvent, error) { chronologicalOrder bool, onlySyncEvents bool,
stmt := common.TxStmt(txn, s.selectRecentEventsStmt) ) ([]types.StreamEvent, error) {
var stmt *sql.Stmt
if onlySyncEvents {
stmt = common.TxStmt(txn, s.selectRecentEventsForSyncStmt)
} else {
stmt = common.TxStmt(txn, s.selectRecentEventsStmt)
}
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
events, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if chronologicalOrder {
// The events need to be returned from oldest to latest, which isn't
// necessary the way the SQL query returns them, so a sort is necessary to
// ensure the events are in the right order in the slice.
sort.SliceStable(events, func(i int, j int) bool {
return events[i].StreamPosition < events[j].StreamPosition
})
}
return events, nil
}
// selectEarlyEvents returns the earliest events in the given room, starting
// from a given position, up to a maximum of 'limit'.
func (s *outputRoomEventsStatements) selectEarlyEvents(
ctx context.Context, txn *sql.Tx,
roomID string, fromPos, toPos types.StreamPosition, limit int,
) ([]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
if err != nil { if err != nil {
return nil, err return nil, err
@ -274,16 +339,16 @@ func (s *outputRoomEventsStatements) selectRecentEvents(
// necessarily the way the SQL query returns them, so a sort is necessary to // necessarily the way the SQL query returns them, so a sort is necessary to
// ensure the events are in the right order in the slice. // ensure the events are in the right order in the slice.
sort.SliceStable(events, func(i int, j int) bool { sort.SliceStable(events, func(i int, j int) bool {
return events[i].streamPosition < events[j].streamPosition return events[i].StreamPosition < events[j].StreamPosition
}) })
return events, nil return events, nil
} }
// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing // selectEvents returns the events for the given event IDs. If an event is
// from the database. // missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) selectEvents( func (s *outputRoomEventsStatements) selectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]streamEvent, error) { ) ([]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectEventsStmt) stmt := common.TxStmt(txn, s.selectEventsStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil { if err != nil {
@ -293,17 +358,18 @@ func (s *outputRoomEventsStatements) selectEvents(
return rowsToStreamEvents(rows) return rowsToStreamEvents(rows)
} }
func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []streamEvent var result []types.StreamEvent
for rows.Next() { for rows.Next() {
var ( var (
streamPos int64 streamPos types.StreamPosition
eventBytes []byte eventBytes []byte
excludeFromSync bool
sessionID *int64 sessionID *int64
txnID *string txnID *string
transactionID *api.TransactionID transactionID *api.TransactionID
) )
if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil { if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
return nil, err return nil, err
} }
// TODO: Handle redacted events // TODO: Handle redacted events
@ -319,10 +385,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
} }
} }
result = append(result, streamEvent{ result = append(result, types.StreamEvent{
Event: ev, Event: ev,
streamPosition: streamPos, StreamPosition: streamPos,
transactionID: transactionID, TransactionID: transactionID,
ExcludeFromSync: excludeFromSync,
}) })
} }
return result, nil return result, nil

View File

@ -0,0 +1,188 @@
// Copyright 2018 New Vector Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
const outputRoomEventsTopologySchema = `
-- Stores output room events received from the roomserver.
CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology (
-- The event ID for the event.
event_id TEXT PRIMARY KEY,
-- The place of the event in the room's topology. This can usually be determined
-- from the event's depth.
topological_position BIGINT NOT NULL,
-- The 'room_id' key for the event.
room_id TEXT NOT NULL
);
-- The topological order will be used in events selection and ordering
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id);
`
const insertEventInTopologySQL = "" +
"INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING"
const selectEventIDsInRangeASCSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" +
" ORDER BY topological_position ASC LIMIT $4"
const selectEventIDsInRangeDESCSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" +
" ORDER BY topological_position DESC LIMIT $4"
const selectPositionInTopologySQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology" +
" WHERE event_id = $1"
const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1"
const selectEventIDsFromPositionSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND topological_position = $2"
type outputRoomEventsTopologyStatements struct {
insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt
selectEventIDsFromPositionStmt *sql.Stmt
}
func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(outputRoomEventsTopologySchema)
if err != nil {
return
}
if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil {
return
}
if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil {
return
}
if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil {
return
}
if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil {
return
}
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return
}
if s.selectEventIDsFromPositionStmt, err = db.Prepare(selectEventIDsFromPositionSQL); err != nil {
return
}
return
}
// insertEventInTopology inserts the given event in the room's topology, based
// on the event's depth.
func (s *outputRoomEventsTopologyStatements) insertEventInTopology(
ctx context.Context, event *gomatrixserverlib.Event,
) (err error) {
_, err = s.insertEventInTopologyStmt.ExecContext(
ctx, event.EventID(), event.Depth(), event.RoomID(),
)
return
}
// selectEventIDsInRange selects the IDs of events which positions are within a
// given range in a given room's topological order.
// Returns an empty slice if no events match the given range.
func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
ctx context.Context, roomID string, fromPos, toPos types.StreamPosition,
limit int, chronologicalOrder bool,
) (eventIDs []string, err error) {
// Decide on the selection's order according to whether chronological order
// is requested or not.
var stmt *sql.Stmt
if chronologicalOrder {
stmt = s.selectEventIDsInRangeASCStmt
} else {
stmt = s.selectEventIDsInRangeDESCStmt
}
// Query the event IDs.
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
if err == sql.ErrNoRows {
// If no event matched the request, return an empty slice.
return []string{}, nil
} else if err != nil {
return
}
// Return the IDs.
var eventID string
for rows.Next() {
if err = rows.Scan(&eventID); err != nil {
return
}
eventIDs = append(eventIDs, eventID)
}
return
}
// selectPositionInTopology returns the position of a given event in the
// topology of the room it belongs to.
func (s *outputRoomEventsTopologyStatements) selectPositionInTopology(
ctx context.Context, eventID string,
) (pos types.StreamPosition, err error) {
err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos)
return
}
func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology(
ctx context.Context, roomID string,
) (pos types.StreamPosition, err error) {
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos)
return
}
// selectEventIDsFromPosition returns the IDs of all events that have a given
// position in the topology of a given room.
func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition(
ctx context.Context, roomID string, pos types.StreamPosition,
) (eventIDs []string, err error) {
// Query the event IDs.
rows, err := s.selectEventIDsFromPositionStmt.QueryContext(ctx, roomID, pos)
if err == sql.ErrNoRows {
// If no event matched the request, return an empty slice.
return []string{}, nil
} else if err != nil {
return
}
// Return the IDs.
var eventID string
for rows.Next() {
if err = rows.Scan(&eventID); err != nil {
return
}
eventIDs = append(eventIDs, eventID)
}
return
}

View File

@ -20,7 +20,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -43,17 +42,10 @@ type stateDelta struct {
membership string membership string
// The PDU stream position of the latest membership event for this user, if applicable. // The PDU stream position of the latest membership event for this user, if applicable.
// Can be 0 if there is no membership event in this delta. // Can be 0 if there is no membership event in this delta.
membershipPos int64 membershipPos types.StreamPosition
} }
// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. // SyncServerDatasource represents a sync server datasource which manages
type streamEvent struct {
gomatrixserverlib.Event
streamPosition int64
transactionID *api.TransactionID
}
// SyncServerDatabase represents a sync server datasource which manages
// both the database for PDUs and caches for EDUs. // both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct { type SyncServerDatasource struct {
db *sql.DB db *sql.DB
@ -63,9 +55,11 @@ type SyncServerDatasource struct {
roomstate currentRoomStateStatements roomstate currentRoomStateStatements
invites inviteEventsStatements invites inviteEventsStatements
typingCache *cache.TypingCache typingCache *cache.TypingCache
topology outputRoomEventsTopologyStatements
backwardExtremities backwardExtremitiesStatements
} }
// NewSyncServerDatabase creates a new sync server database // NewSyncServerDatasource creates a new sync server database
func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) { func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) {
var d SyncServerDatasource var d SyncServerDatasource
var err error var err error
@ -87,6 +81,12 @@ func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, er
if err := d.invites.prepare(d.db); err != nil { if err := d.invites.prepare(d.db); err != nil {
return nil, err return nil, err
} }
if err := d.topology.prepare(d.db); err != nil {
return nil, err
}
if err := d.backwardExtremities.prepare(d.db); err != nil {
return nil, err
}
d.typingCache = cache.NewTypingCache() d.typingCache = cache.NewTypingCache()
return &d, nil return &d, nil
} }
@ -109,7 +109,46 @@ func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([
// We don't include a device here as we only include transaction IDs in // We don't include a device here as we only include transaction IDs in
// incremental syncs. // incremental syncs.
return streamEventsToEvents(nil, streamEvents), nil return d.StreamEventsToEvents(nil, streamEvents), nil
}
func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, ev *gomatrixserverlib.Event) error {
// If the event is already known as a backward extremity, don't consider
// it as such anymore now that we have it.
isBackwardExtremity, err := d.backwardExtremities.isBackwardExtremity(ctx, ev.RoomID(), ev.EventID())
if err != nil {
return err
}
if isBackwardExtremity {
if err = d.backwardExtremities.deleteBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil {
return err
}
}
// Check if we have all of the event's previous events. If an event is
// missing, add it to the room's backward extremities.
prevEvents, err := d.events.selectEvents(ctx, nil, ev.PrevEventIDs())
if err != nil {
return err
}
var found bool
for _, eID := range ev.PrevEventIDs() {
found = false
for _, prevEv := range prevEvents {
if eID == prevEv.EventID() {
found = true
}
}
// If the event is missing, consider it a backward extremity.
if !found {
if err = d.backwardExtremities.insertsBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil {
return err
}
}
}
return nil
} }
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
@ -120,16 +159,26 @@ func (d *SyncServerDatasource) WriteEvent(
ev *gomatrixserverlib.Event, ev *gomatrixserverlib.Event,
addStateEvents []gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event,
addStateEventIDs, removeStateEventIDs []string, addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, transactionID *api.TransactionID, excludeFromSync bool,
) (pduPosition int64, returnErr error) { ) (pduPosition types.StreamPosition, returnErr error) {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID) pos, err := d.events.insertEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
)
if err != nil { if err != nil {
return err return err
} }
pduPosition = pos pduPosition = pos
if err = d.topology.insertEventInTopology(ctx, ev); err != nil {
return err
}
if err = d.handleBackwardExtremities(ctx, ev); err != nil {
return err
}
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
// Nothing to do, the event may have just been a message event. // Nothing to do, the event may have just been a message event.
return nil return nil
@ -137,14 +186,15 @@ func (d *SyncServerDatasource) WriteEvent(
return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition) return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition)
}) })
return
return pduPosition, returnErr
} }
func (d *SyncServerDatasource) updateRoomState( func (d *SyncServerDatasource) updateRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
removedEventIDs []string, removedEventIDs []string,
addedEvents []gomatrixserverlib.Event, addedEvents []gomatrixserverlib.Event,
pduPosition int64, pduPosition types.StreamPosition,
) error { ) error {
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removedEventIDs { for _, eventID := range removedEventIDs {
@ -196,14 +246,141 @@ func (d *SyncServerDatasource) GetStateEventsForRoom(
return return
} }
// GetEventsInRange retrieves all of the events on a given ordering using the
// given extremities and limit.
func (d *SyncServerDatasource) GetEventsInRange(
ctx context.Context,
from, to *types.PaginationToken,
roomID string, limit int,
backwardOrdering bool,
) (events []types.StreamEvent, err error) {
// If the pagination token's type is types.PaginationTokenTypeTopology, the
// events must be retrieved from the rooms' topology table rather than the
// table contaning the syncapi server's whole stream of events.
if from.Type == types.PaginationTokenTypeTopology {
// Determine the backward and forward limit, i.e. the upper and lower
// limits to the selection in the room's topology, from the direction.
var backwardLimit, forwardLimit types.StreamPosition
if backwardOrdering {
// Backward ordering is antichronological (latest event to oldest
// one).
backwardLimit = to.PDUPosition
forwardLimit = from.PDUPosition
} else {
// Forward ordering is chronological (oldest event to latest one).
backwardLimit = from.PDUPosition
forwardLimit = to.PDUPosition
}
// Select the event IDs from the defined range.
var eIDs []string
eIDs, err = d.topology.selectEventIDsInRange(
ctx, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering,
)
if err != nil {
return
}
// Retrieve the events' contents using their IDs.
events, err = d.events.selectEvents(ctx, nil, eIDs)
return
}
// If the pagination token's type is types.PaginationTokenTypeStream, the
// events must be retrieved from the table contaning the syncapi server's
// whole stream of events.
if backwardOrdering {
// When using backward ordering, we want the most recent events first.
if events, err = d.events.selectRecentEvents(
ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false,
); err != nil {
return
}
} else {
// When using forward ordering, we want the least recent events first.
if events, err = d.events.selectEarlyEvents(
ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit,
); err != nil {
return
}
}
return
}
// SyncPosition returns the latest positions for syncing. // SyncPosition returns the latest positions for syncing.
func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) { func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) {
return d.syncPositionTx(ctx, nil) return d.syncPositionTx(ctx, nil)
} }
// BackwardExtremitiesForRoom returns the event IDs of all of the backward
// extremities we know of for a given room.
func (d *SyncServerDatasource) BackwardExtremitiesForRoom(
ctx context.Context, roomID string,
) (backwardExtremities []string, err error) {
return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, roomID)
}
// MaxTopologicalPosition returns the highest topological position for a given
// room.
func (d *SyncServerDatasource) MaxTopologicalPosition(
ctx context.Context, roomID string,
) (types.StreamPosition, error) {
return d.topology.selectMaxPositionInTopology(ctx, roomID)
}
// EventsAtTopologicalPosition returns all of the events matching a given
// position in the topology of a given room.
func (d *SyncServerDatasource) EventsAtTopologicalPosition(
ctx context.Context, roomID string, pos types.StreamPosition,
) ([]types.StreamEvent, error) {
eIDs, err := d.topology.selectEventIDsFromPosition(ctx, roomID, pos)
if err != nil {
return nil, err
}
return d.events.selectEvents(ctx, nil, eIDs)
}
func (d *SyncServerDatasource) EventPositionInTopology(
ctx context.Context, eventID string,
) (types.StreamPosition, error) {
return d.topology.selectPositionInTopology(ctx, eventID)
}
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) {
return d.syncStreamPositionTx(ctx, nil)
}
func (d *SyncServerDatasource) syncStreamPositionTx(
ctx context.Context, txn *sql.Tx,
) (types.StreamPosition, error) {
maxID, err := d.events.selectMaxEventID(ctx, txn)
if err != nil {
return 0, err
}
maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn)
if err != nil {
return 0, err
}
if maxAccountDataID > maxID {
maxID = maxAccountDataID
}
maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
if err != nil {
return 0, err
}
if maxInviteID > maxID {
maxID = maxInviteID
}
return types.StreamPosition(maxID), nil
}
func (d *SyncServerDatasource) syncPositionTx( func (d *SyncServerDatasource) syncPositionTx(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (sp types.SyncPosition, err error) { ) (sp types.PaginationToken, err error) {
maxEventID, err := d.events.selectMaxEventID(ctx, txn) maxEventID, err := d.events.selectMaxEventID(ctx, txn)
if err != nil { if err != nil {
@ -223,10 +400,8 @@ func (d *SyncServerDatasource) syncPositionTx(
if maxInviteID > maxEventID { if maxInviteID > maxEventID {
maxEventID = maxInviteID maxEventID = maxInviteID
} }
sp.PDUPosition = maxEventID sp.PDUPosition = types.StreamPosition(maxEventID)
sp.EDUTypingPosition = types.StreamPosition(d.typingCache.GetLatestSyncPosition())
sp.TypingPosition = d.typingCache.GetLatestSyncPosition()
return return
} }
@ -235,7 +410,7 @@ func (d *SyncServerDatasource) syncPositionTx(
func (d *SyncServerDatasource) addPDUDeltaToResponse( func (d *SyncServerDatasource) addPDUDeltaToResponse(
ctx context.Context, ctx context.Context,
device authtypes.Device, device authtypes.Device,
fromPos, toPos int64, fromPos, toPos types.StreamPosition,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool, wantFullState bool,
res *types.Response, res *types.Response,
@ -287,7 +462,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
// addTypingDeltaToResponse adds all typing notifications to a sync response // addTypingDeltaToResponse adds all typing notifications to a sync response
// since the specified position. // since the specified position.
func (d *SyncServerDatasource) addTypingDeltaToResponse( func (d *SyncServerDatasource) addTypingDeltaToResponse(
since int64, since types.PaginationToken,
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) error { ) error {
@ -296,7 +471,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
var err error var err error
for _, roomID := range joinedRoomIDs { for _, roomID := range joinedRoomIDs {
if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter( if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter(
roomID, since, roomID, int64(since.EDUTypingPosition),
); updated { ); updated {
ev := gomatrixserverlib.ClientEvent{ ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MTyping, Type: gomatrixserverlib.MTyping,
@ -321,14 +496,14 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if
// the positions of that type are not equal in fromPos and toPos. // the positions of that type are not equal in fromPos and toPos.
func (d *SyncServerDatasource) addEDUDeltaToResponse( func (d *SyncServerDatasource) addEDUDeltaToResponse(
fromPos, toPos types.SyncPosition, fromPos, toPos types.PaginationToken,
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) (err error) { ) (err error) {
if fromPos.TypingPosition != toPos.TypingPosition { if fromPos.EDUTypingPosition != toPos.EDUTypingPosition {
err = d.addTypingDeltaToResponse( err = d.addTypingDeltaToResponse(
fromPos.TypingPosition, joinedRoomIDs, res, fromPos, joinedRoomIDs, res,
) )
} }
@ -343,7 +518,7 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse(
func (d *SyncServerDatasource) IncrementalSync( func (d *SyncServerDatasource) IncrementalSync(
ctx context.Context, ctx context.Context,
device authtypes.Device, device authtypes.Device,
fromPos, toPos types.SyncPosition, fromPos, toPos types.PaginationToken,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool, wantFullState bool,
) (*types.Response, error) { ) (*types.Response, error) {
@ -383,7 +558,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) ( ) (
res *types.Response, res *types.Response,
toPos types.SyncPosition, toPos types.PaginationToken,
joinedRoomIDs []string, joinedRoomIDs []string,
err error, err error,
) { ) {
@ -423,27 +598,37 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
} }
// TODO: When filters are added, we may need to call this multiple times to get enough events. // TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
var recentStreamEvents []streamEvent var recentStreamEvents []types.StreamEvent
recentStreamEvents, err = d.events.selectRecentEvents( recentStreamEvents, err = d.events.selectRecentEvents(
ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom, ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition,
numRecentEventsPerRoom, true, true,
//ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom,
) )
if err != nil { if err != nil {
return return
} }
// Retrieve the backward topology position, i.e. the position of the
// oldest event in the room's topology.
var backwardTopologyPos types.StreamPosition
backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID())
if err != nil {
return nil, types.PaginationToken{}, []string{}, err
}
if backwardTopologyPos-1 <= 0 {
backwardTopologyPos = types.StreamPosition(1)
} else {
backwardTopologyPos = backwardTopologyPos - 1
}
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs // transaction IDs for complete syncs
recentEvents := streamEventsToEvents(nil, recentStreamEvents) recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 { jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
// Use the short form of batch token for prev_batch types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) ).String()
} else {
// Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = "1"
}
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = true jr.Timeline.Limited = true
jr.State.Events = gomatrixserverlib.ToClientEvents(stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.ToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
@ -471,7 +656,7 @@ func (d *SyncServerDatasource) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added. // Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse( err = d.addEDUDeltaToResponse(
types.SyncPosition{}, toPos, joinedRoomIDs, res, types.PaginationToken{}, toPos, joinedRoomIDs, res,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -496,7 +681,7 @@ var txReadOnlySnapshot = sql.TxOptions{
// If no data is retrieved, returns an empty map // If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error // If there was an issue with the retrieval, returns an error
func (d *SyncServerDatasource) GetAccountDataInRange( func (d *SyncServerDatasource) GetAccountDataInRange(
ctx context.Context, userID string, oldPos, newPos int64, ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
accountDataFilterPart *gomatrix.FilterPart, accountDataFilterPart *gomatrix.FilterPart,
) (map[string][]string, error) { ) (map[string][]string, error) {
return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
@ -510,7 +695,7 @@ func (d *SyncServerDatasource) GetAccountDataInRange(
// Returns an error if there was an issue with the upsert // Returns an error if there was an issue with the upsert
func (d *SyncServerDatasource) UpsertAccountData( func (d *SyncServerDatasource) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string, ctx context.Context, userID, roomID, dataType string,
) (int64, error) { ) (types.StreamPosition, error) {
return d.accountData.insertAccountData(ctx, userID, roomID, dataType) return d.accountData.insertAccountData(ctx, userID, roomID, dataType)
} }
@ -519,7 +704,7 @@ func (d *SyncServerDatasource) UpsertAccountData(
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatasource) AddInviteEvent( func (d *SyncServerDatasource) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event, ctx context.Context, inviteEvent gomatrixserverlib.Event,
) (int64, error) { ) (types.StreamPosition, error) {
return d.invites.insertInviteEvent(ctx, inviteEvent) return d.invites.insertInviteEvent(ctx, inviteEvent)
} }
@ -542,26 +727,26 @@ func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallback
// Returns the newly calculated sync position for typing notifications. // Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) AddTypingUser( func (d *SyncServerDatasource) AddTypingUser(
userID, roomID string, expireTime *time.Time, userID, roomID string, expireTime *time.Time,
) int64 { ) types.StreamPosition {
return d.typingCache.AddTypingUser(userID, roomID, expireTime) return types.StreamPosition(d.typingCache.AddTypingUser(userID, roomID, expireTime))
} }
// RemoveTypingUser removes a typing user from the typing cache. // RemoveTypingUser removes a typing user from the typing cache.
// Returns the newly calculated sync position for typing notifications. // Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) RemoveTypingUser( func (d *SyncServerDatasource) RemoveTypingUser(
userID, roomID string, userID, roomID string,
) int64 { ) types.StreamPosition {
return d.typingCache.RemoveUser(userID, roomID) return types.StreamPosition(d.typingCache.RemoveUser(userID, roomID))
} }
func (d *SyncServerDatasource) addInvitesToResponse( func (d *SyncServerDatasource) addInvitesToResponse(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID string, userID string,
fromPos, toPos int64, fromPos, toPos types.StreamPosition,
res *types.Response, res *types.Response,
) error { ) error {
invites, err := d.invites.selectInviteEventsInRange( invites, err := d.invites.selectInviteEventsInRange(
ctx, txn, userID, int64(fromPos), int64(toPos), ctx, txn, userID, fromPos, toPos,
) )
if err != nil { if err != nil {
return err return err
@ -577,12 +762,32 @@ func (d *SyncServerDatasource) addInvitesToResponse(
return nil return nil
} }
// Retrieve the backward topology position, i.e. the position of the
// oldest event in the room's topology.
func (d *SyncServerDatasource) getBackwardTopologyPos(
ctx context.Context,
events []types.StreamEvent,
) (pos types.StreamPosition, err error) {
if len(events) > 0 {
pos, err = d.topology.selectPositionInTopology(ctx, events[0].EventID())
if err != nil {
return
}
}
if pos-1 <= 0 {
pos = types.StreamPosition(1)
} else {
pos = pos - 1
}
return
}
// addRoomDeltaToResponse adds a room state delta to a sync response // addRoomDeltaToResponse adds a room state delta to a sync response
func (d *SyncServerDatasource) addRoomDeltaToResponse( func (d *SyncServerDatasource) addRoomDeltaToResponse(
ctx context.Context, ctx context.Context,
device *authtypes.Device, device *authtypes.Device,
txn *sql.Tx, txn *sql.Tx,
fromPos, toPos int64, fromPos, toPos types.StreamPosition,
delta stateDelta, delta stateDelta,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
res *types.Response, res *types.Response,
@ -598,38 +803,28 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
endPos = delta.membershipPos endPos = delta.membershipPos
} }
recentStreamEvents, err := d.events.selectRecentEvents( recentStreamEvents, err := d.events.selectRecentEvents(
ctx, txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom, ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos),
numRecentEventsPerRoom, true, true,
) )
if err != nil { if err != nil {
return err return err
} }
recentEvents := streamEventsToEvents(device, recentStreamEvents) recentEvents := d.StreamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
var prevPDUPos int64 var backwardTopologyPos types.StreamPosition
backwardTopologyPos, err = d.getBackwardTopologyPos(ctx, recentStreamEvents)
if len(recentEvents) == 0 { if err != nil {
if len(delta.stateEvents) == 0 { return err
// Don't bother appending empty room entries
return nil
}
// If full_state=true and since is already up to date, then we'll have
// state events but no recent events.
prevPDUPos = toPos - 1
} else {
prevPDUPos = recentStreamEvents[0].streamPosition - 1
}
if prevPDUPos <= 0 {
prevPDUPos = 1
} }
switch delta.membership { switch delta.membership {
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
// Use the short form of batch token for prev_batch
jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
).String()
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -640,8 +835,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
// TODO: recentEvents may contain events that this user is not allowed to see because they are // TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
// Use the short form of batch token for prev_batch lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
).String()
lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -656,9 +852,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
func (d *SyncServerDatasource) fetchStateEvents( func (d *SyncServerDatasource) fetchStateEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomIDToEventIDSet map[string]map[string]bool, roomIDToEventIDSet map[string]map[string]bool,
eventIDToEvent map[string]streamEvent, eventIDToEvent map[string]types.StreamEvent,
) (map[string][]streamEvent, error) { ) (map[string][]types.StreamEvent, error) {
stateBetween := make(map[string][]streamEvent) stateBetween := make(map[string][]types.StreamEvent)
missingEvents := make(map[string][]string) missingEvents := make(map[string][]string)
for roomID, ids := range roomIDToEventIDSet { for roomID, ids := range roomIDToEventIDSet {
events := stateBetween[roomID] events := stateBetween[roomID]
@ -700,7 +896,7 @@ func (d *SyncServerDatasource) fetchStateEvents(
func (d *SyncServerDatasource) fetchMissingStateEvents( func (d *SyncServerDatasource) fetchMissingStateEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]streamEvent, error) { ) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the // Fetch from the events table first so we pick up the stream ID for the
// event. // event.
events, err := d.events.selectEvents(ctx, txn, eventIDs) events, err := d.events.selectEvents(ctx, txn, eventIDs)
@ -743,7 +939,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents(
// A list of joined room IDs is also returned in case the caller needs it. // A list of joined room IDs is also returned in case the caller needs it.
func (d *SyncServerDatasource) getStateDeltas( func (d *SyncServerDatasource) getStateDeltas(
ctx context.Context, device *authtypes.Device, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos int64, userID string, fromPos, toPos types.StreamPosition, userID string,
stateFilterPart *gomatrix.FilterPart, stateFilterPart *gomatrix.FilterPart,
) ([]stateDelta, []string, error) { ) ([]stateDelta, []string, error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
@ -776,7 +972,7 @@ func (d *SyncServerDatasource) getStateDeltas(
if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
if membership == gomatrixserverlib.Join { if membership == gomatrixserverlib.Join {
// send full room state down instead of a delta // send full room state down instead of a delta
var s []streamEvent var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart) s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -787,8 +983,8 @@ func (d *SyncServerDatasource) getStateDeltas(
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: membership, membership: membership,
membershipPos: ev.streamPosition, membershipPos: ev.StreamPosition,
stateEvents: streamEventsToEvents(device, stateStreamEvents), stateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
roomID: roomID, roomID: roomID,
}) })
break break
@ -804,7 +1000,7 @@ func (d *SyncServerDatasource) getStateDeltas(
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: gomatrixserverlib.Join, membership: gomatrixserverlib.Join,
stateEvents: streamEventsToEvents(device, state[joinedRoomID]), stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
roomID: joinedRoomID, roomID: joinedRoomID,
}) })
} }
@ -818,7 +1014,7 @@ func (d *SyncServerDatasource) getStateDeltas(
// updates for other rooms. // updates for other rooms.
func (d *SyncServerDatasource) getStateDeltasForFullStateSync( func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
ctx context.Context, device *authtypes.Device, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos int64, userID string, fromPos, toPos types.StreamPosition, userID string,
stateFilterPart *gomatrix.FilterPart, stateFilterPart *gomatrix.FilterPart,
) ([]stateDelta, []string, error) { ) ([]stateDelta, []string, error) {
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@ -837,7 +1033,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
} }
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: gomatrixserverlib.Join, membership: gomatrixserverlib.Join,
stateEvents: streamEventsToEvents(device, s), stateEvents: d.StreamEventsToEvents(device, s),
roomID: joinedRoomID, roomID: joinedRoomID,
}) })
} }
@ -858,8 +1054,8 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: membership, membership: membership,
membershipPos: ev.streamPosition, membershipPos: ev.StreamPosition,
stateEvents: streamEventsToEvents(device, stateStreamEvents), stateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
roomID: roomID, roomID: roomID,
}) })
} }
@ -875,29 +1071,29 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
func (d *SyncServerDatasource) currentStateStreamEventsForRoom( func (d *SyncServerDatasource) currentStateStreamEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
stateFilterPart *gomatrix.FilterPart, stateFilterPart *gomatrix.FilterPart,
) ([]streamEvent, error) { ) ([]types.StreamEvent, error) {
allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s := make([]streamEvent, len(allState)) s := make([]types.StreamEvent, len(allState))
for i := 0; i < len(s); i++ { for i := 0; i < len(s); i++ {
s[i] = streamEvent{Event: allState[i], streamPosition: 0} s[i] = types.StreamEvent{Event: allState[i], StreamPosition: 0}
} }
return s, nil return s, nil
} }
// streamEventsToEvents converts streamEvent to Event. If device is non-nil and // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets // matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event. // added to the unsigned section of the output event.
func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event { func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event {
out := make([]gomatrixserverlib.Event, len(in)) out := make([]gomatrixserverlib.Event, len(in))
for i := 0; i < len(in); i++ { for i := 0; i < len(in); i++ {
out[i] = in[i].Event out[i] = in[i].Event
if device != nil && in[i].transactionID != nil { if device != nil && in[i].TransactionID != nil {
if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID { if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField( err := out[i].SetUnsignedField(
"transaction_id", in[i].transactionID.TransactionID, "transaction_id", in[i].TransactionID.TransactionID,
) )
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{

View File

@ -33,19 +33,26 @@ type Database interface {
common.PartitionStorer common.PartitionStorer
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error)
WriteEvent(ctx context.Context, ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID) (pduPosition int64, returnErr error) WriteEvent(context.Context, *gomatrixserverlib.Event, []gomatrixserverlib.Event, []string, []string, *api.TransactionID, bool) (types.StreamPosition, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.Event, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.Event, error)
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart) (stateEvents []gomatrixserverlib.Event, err error) GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart) (stateEvents []gomatrixserverlib.Event, err error)
SyncPosition(ctx context.Context) (types.SyncPosition, error) SyncPosition(ctx context.Context) (types.PaginationToken, error)
IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.SyncPosition, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error)
GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos int64, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error) GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error)
UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (int64, error) UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (types.StreamPosition, error)
AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (int64, error) AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (types.StreamPosition, error)
RetireInviteEvent(ctx context.Context, inviteEventID string) error RetireInviteEvent(ctx context.Context, inviteEventID string) error
SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn)
AddTypingUser(userID, roomID string, expireTime *time.Time) int64 AddTypingUser(userID, roomID string, expireTime *time.Time) types.StreamPosition
RemoveTypingUser(userID, roomID string) int64 RemoveTypingUser(userID, roomID string) types.StreamPosition
GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
EventPositionInTopology(ctx context.Context, eventID string) (types.StreamPosition, error)
EventsAtTopologicalPosition(ctx context.Context, roomID string, pos types.StreamPosition) ([]types.StreamEvent, error)
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error)
MaxTopologicalPosition(ctx context.Context, roomID string) (types.StreamPosition, error)
StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event
SyncStreamPosition(ctx context.Context) (types.StreamPosition, error)
} }
// NewPublicRoomsServerDatabase opens a database connection. // NewPublicRoomsServerDatabase opens a database connection.

View File

@ -36,7 +36,7 @@ type Notifier struct {
// Protects currPos and userStreams. // Protects currPos and userStreams.
streamLock *sync.Mutex streamLock *sync.Mutex
// The latest sync position // The latest sync position
currPos types.SyncPosition currPos types.PaginationToken
// A map of user_id => UserStream which can be used to wake a given user's /sync request. // A map of user_id => UserStream which can be used to wake a given user's /sync request.
userStreams map[string]*UserStream userStreams map[string]*UserStream
// The last time we cleaned out stale entries from the userStreams map // The last time we cleaned out stale entries from the userStreams map
@ -46,7 +46,7 @@ type Notifier struct {
// NewNotifier creates a new notifier set to the given sync position. // NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and // In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier(pos types.SyncPosition) *Notifier { func NewNotifier(pos types.PaginationToken) *Notifier {
return &Notifier{ return &Notifier{
currPos: pos, currPos: pos,
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]userIDSet),
@ -68,7 +68,7 @@ func NewNotifier(pos types.SyncPosition) *Notifier {
// event type it handles, leaving other fields as 0. // event type it handles, leaving other fields as 0.
func (n *Notifier) OnNewEvent( func (n *Notifier) OnNewEvent(
ev *gomatrixserverlib.Event, roomID string, userIDs []string, ev *gomatrixserverlib.Event, roomID string, userIDs []string,
posUpdate types.SyncPosition, posUpdate types.PaginationToken,
) { ) {
// update the current position then notify relevant /sync streams. // update the current position then notify relevant /sync streams.
// This needs to be done PRIOR to waking up users as they will read this value. // This needs to be done PRIOR to waking up users as they will read this value.
@ -151,7 +151,7 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
} }
// CurrentPosition returns the current sync position // CurrentPosition returns the current sync position
func (n *Notifier) CurrentPosition() types.SyncPosition { func (n *Notifier) CurrentPosition() types.PaginationToken {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
@ -173,7 +173,7 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
} }
} }
func (n *Notifier) wakeupUsers(userIDs []string, newPos types.SyncPosition) { func (n *Notifier) wakeupUsers(userIDs []string, newPos types.PaginationToken) {
for _, userID := range userIDs { for _, userID := range userIDs {
stream := n.fetchUserStream(userID, false) stream := n.fetchUserStream(userID, false)
if stream != nil { if stream != nil {

View File

@ -32,11 +32,11 @@ var (
randomMessageEvent gomatrixserverlib.Event randomMessageEvent gomatrixserverlib.Event
aliceInviteBobEvent gomatrixserverlib.Event aliceInviteBobEvent gomatrixserverlib.Event
bobLeaveEvent gomatrixserverlib.Event bobLeaveEvent gomatrixserverlib.Event
syncPositionVeryOld types.SyncPosition syncPositionVeryOld types.PaginationToken
syncPositionBefore types.SyncPosition syncPositionBefore types.PaginationToken
syncPositionAfter types.SyncPosition syncPositionAfter types.PaginationToken
syncPositionNewEDU types.SyncPosition syncPositionNewEDU types.PaginationToken
syncPositionAfter2 types.SyncPosition syncPositionAfter2 types.PaginationToken
) )
var ( var (
@ -46,9 +46,9 @@ var (
) )
func init() { func init() {
baseSyncPos := types.SyncPosition{ baseSyncPos := types.PaginationToken{
PDUPosition: 0, PDUPosition: 0,
TypingPosition: 0, EDUTypingPosition: 0,
} }
syncPositionVeryOld = baseSyncPos syncPositionVeryOld = baseSyncPos
@ -61,7 +61,7 @@ func init() {
syncPositionAfter.PDUPosition = 12 syncPositionAfter.PDUPosition = 12
syncPositionNewEDU = syncPositionAfter syncPositionNewEDU = syncPositionAfter
syncPositionNewEDU.TypingPosition = 1 syncPositionNewEDU.EDUTypingPosition = 1
syncPositionAfter2 = baseSyncPos syncPositionAfter2 = baseSyncPos
syncPositionAfter2.PDUPosition = 13 syncPositionAfter2.PDUPosition = 13
@ -119,7 +119,7 @@ func TestImmediateNotification(t *testing.T) {
t.Fatalf("TestImmediateNotification error: %s", err) t.Fatalf("TestImmediateNotification error: %s", err)
} }
if pos != syncPositionBefore { if pos != syncPositionBefore {
t.Fatalf("TestImmediateNotification want %d, got %d", syncPositionBefore, pos) t.Fatalf("TestImmediateNotification want %v, got %v", syncPositionBefore, pos)
} }
} }
@ -138,7 +138,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
t.Errorf("TestNewEventAndJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndJoinedToRoom error: %s", err)
} }
if pos != syncPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", syncPositionAfter, pos) t.Errorf("TestNewEventAndJoinedToRoom want %v, got %v", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
}() }()
@ -166,7 +166,7 @@ func TestNewInviteEventForUser(t *testing.T) {
t.Errorf("TestNewInviteEventForUser error: %s", err) t.Errorf("TestNewInviteEventForUser error: %s", err)
} }
if pos != syncPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionAfter, pos) t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
}() }()
@ -194,7 +194,7 @@ func TestEDUWakeup(t *testing.T) {
t.Errorf("TestNewInviteEventForUser error: %s", err) t.Errorf("TestNewInviteEventForUser error: %s", err)
} }
if pos != syncPositionNewEDU { if pos != syncPositionNewEDU {
t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionNewEDU, pos) t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionNewEDU, pos)
} }
wg.Done() wg.Done()
}() }()
@ -222,7 +222,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
t.Errorf("TestMultipleRequestWakeup error: %s", err) t.Errorf("TestMultipleRequestWakeup error: %s", err)
} }
if pos != syncPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestMultipleRequestWakeup want %d, got %d", syncPositionAfter, pos) t.Errorf("TestMultipleRequestWakeup want %v, got %v", syncPositionAfter, pos)
} }
wg.Done() wg.Done()
} }
@ -262,7 +262,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
} }
if pos != syncPositionAfter { if pos != syncPositionAfter {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter, pos) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter, pos)
} }
leaveWG.Done() leaveWG.Done()
}() }()
@ -281,7 +281,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
} }
if pos != syncPositionAfter2 { if pos != syncPositionAfter2 {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter2, pos) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter2, pos)
} }
aliceWG.Done() aliceWG.Done()
}() }()
@ -305,14 +305,14 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
func waitForEvents(n *Notifier, req syncRequest) (types.SyncPosition, error) { func waitForEvents(n *Notifier, req syncRequest) (types.PaginationToken, error) {
listener := n.GetListener(req) listener := n.GetListener(req)
defer listener.Close() defer listener.Close()
select { select {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
return types.SyncPosition{}, fmt.Errorf( return types.PaginationToken{}, fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since, "waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since,
) )
case <-listener.GetNotifyChannel(*req.since): case <-listener.GetNotifyChannel(*req.since):
p := listener.GetSyncPosition() p := listener.GetSyncPosition()
@ -337,7 +337,7 @@ func lockedFetchUserStream(n *Notifier, userID string) *UserStream {
return n.fetchUserStream(userID, true) return n.fetchUserStream(userID, true)
} }
func newTestSyncRequest(userID string, since types.SyncPosition) syncRequest { func newTestSyncRequest(userID string, since types.PaginationToken) syncRequest {
return syncRequest{ return syncRequest{
device: authtypes.Device{UserID: userID}, device: authtypes.Device{UserID: userID},
timeout: 1 * time.Minute, timeout: 1 * time.Minute,

View File

@ -16,10 +16,8 @@ package sync
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -38,7 +36,7 @@ type syncRequest struct {
device authtypes.Device device authtypes.Device
limit int limit int
timeout time.Duration timeout time.Duration
since *types.SyncPosition // nil means that no since token was supplied since *types.PaginationToken // nil means that no since token was supplied
wantFullState bool wantFullState bool
log *log.Entry log *log.Entry
} }
@ -47,7 +45,7 @@ func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, e
timeout := getTimeout(req.URL.Query().Get("timeout")) timeout := getTimeout(req.URL.Query().Get("timeout"))
fullState := req.URL.Query().Get("full_state") fullState := req.URL.Query().Get("full_state")
wantFullState := fullState != "" && fullState != "false" wantFullState := fullState != "" && fullState != "false"
since, err := getSyncStreamPosition(req.URL.Query().Get("since")) since, err := getPaginationToken(req.URL.Query().Get("since"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -75,41 +73,14 @@ func getTimeout(timeoutMS string) time.Duration {
} }
// getSyncStreamPosition tries to parse a 'since' token taken from the API to a // getSyncStreamPosition tries to parse a 'since' token taken from the API to a
// types.SyncPosition. If the string is empty then (nil, nil) is returned. // types.PaginationToken. If the string is empty then (nil, nil) is returned.
// There are two forms of tokens: The full length form containing all PDU and EDU // There are two forms of tokens: The full length form containing all PDU and EDU
// positions separated by "_", and the short form containing only the PDU // positions separated by "_", and the short form containing only the PDU
// position. Short form can be used for, e.g., `prev_batch` tokens. // position. Short form can be used for, e.g., `prev_batch` tokens.
func getSyncStreamPosition(since string) (*types.SyncPosition, error) { func getPaginationToken(since string) (*types.PaginationToken, error) {
if since == "" { if since == "" {
return nil, nil return nil, nil
} }
posStrings := strings.Split(since, "_") return types.NewPaginationTokenFromString(since)
if len(posStrings) != 2 && len(posStrings) != 1 {
// A token can either be full length or short (PDU-only).
return nil, errors.New("malformed batch token")
}
positions := make([]int64, len(posStrings))
for i, posString := range posStrings {
pos, err := strconv.ParseInt(posString, 10, 64)
if err != nil {
return nil, err
}
positions[i] = pos
}
if len(positions) == 2 {
// Full length token; construct SyncPosition with every entry in
// `positions`. These entries must have the same order with the fields
// in struct SyncPosition, so we disable the govet check below.
return &types.SyncPosition{ //nolint:govet
positions[0], positions[1],
}, nil
} else {
// Token with PDU position only
return &types.SyncPosition{
PDUPosition: positions[0],
}, nil
}
} }

View File

@ -130,7 +130,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
} }
} }
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncPosition) (res *types.Response, err error) { func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.PaginationToken) (res *types.Response, err error) {
// TODO: handle ignored users // TODO: handle ignored users
if req.since == nil { if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
@ -143,7 +143,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncP
} }
accountDataFilter := gomatrix.DefaultFilterPart() // TODO: use filter provided in req instead accountDataFilter := gomatrix.DefaultFilterPart() // TODO: use filter provided in req instead
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter) res, err = rp.appendAccountData(res, req.device.UserID, req, int64(latestPos.PDUPosition), &accountDataFilter)
return return
} }
@ -183,7 +183,11 @@ func (rp *RequestPool) appendAccountData(
} }
// Sync is not initial, get all account data since the latest sync // Sync is not initial, get all account data since the latest sync
dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, req.since.PDUPosition, currentPos, accountDataFilter) dataTypes, err := rp.db.GetAccountDataInRange(
req.ctx, userID,
types.StreamPosition(req.since.PDUPosition), types.StreamPosition(currentPos),
accountDataFilter,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -35,7 +35,7 @@ type UserStream struct {
// Closed when there is an update. // Closed when there is an update.
signalChannel chan struct{} signalChannel chan struct{}
// The last sync position that there may have been an update for the user // The last sync position that there may have been an update for the user
pos types.SyncPosition pos types.PaginationToken
// The last time when we had some listeners waiting // The last time when we had some listeners waiting
timeOfLastChannel time.Time timeOfLastChannel time.Time
// The number of listeners waiting // The number of listeners waiting
@ -51,7 +51,7 @@ type UserStreamListener struct {
} }
// NewUserStream creates a new user stream // NewUserStream creates a new user stream
func NewUserStream(userID string, currPos types.SyncPosition) *UserStream { func NewUserStream(userID string, currPos types.PaginationToken) *UserStream {
return &UserStream{ return &UserStream{
UserID: userID, UserID: userID,
timeOfLastChannel: time.Now(), timeOfLastChannel: time.Now(),
@ -85,7 +85,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener {
} }
// Broadcast a new sync position for this user. // Broadcast a new sync position for this user.
func (s *UserStream) Broadcast(pos types.SyncPosition) { func (s *UserStream) Broadcast(pos types.PaginationToken) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -120,7 +120,7 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time {
// GetStreamPosition returns last sync position which the UserStream was // GetStreamPosition returns last sync position which the UserStream was
// notified about // notified about
func (s *UserStreamListener) GetSyncPosition() types.SyncPosition { func (s *UserStreamListener) GetSyncPosition() types.PaginationToken {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()
@ -132,7 +132,7 @@ func (s *UserStreamListener) GetSyncPosition() types.SyncPosition {
// sincePos specifies from which point we want to be notified about. If there // sincePos specifies from which point we want to be notified about. If there
// has already been an update after sincePos we'll return a closed channel // has already been an update after sincePos we'll return a closed channel
// immediately. // immediately.
func (s *UserStreamListener) GetNotifyChannel(sincePos types.SyncPosition) <-chan struct{} { func (s *UserStreamListener) GetNotifyChannel(sincePos types.PaginationToken) <-chan struct{} {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()

View File

@ -21,7 +21,9 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/common/basecomponent" "github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/syncapi/consumers" "github.com/matrix-org/dendrite/syncapi/consumers"
@ -37,6 +39,8 @@ func SetupSyncAPIComponent(
deviceDB *devices.Database, deviceDB *devices.Database,
accountsDB *accounts.Database, accountsDB *accounts.Database,
queryAPI api.RoomserverQueryAPI, queryAPI api.RoomserverQueryAPI,
federation *gomatrixserverlib.FederationClient,
cfg *config.Dendrite,
) { ) {
syncDB, err := storage.NewSyncServerDatasource(string(base.Cfg.Database.SyncAPI)) syncDB, err := storage.NewSyncServerDatasource(string(base.Cfg.Database.SyncAPI))
if err != nil { if err != nil {
@ -77,5 +81,5 @@ func SetupSyncAPIComponent(
logrus.WithError(err).Panicf("failed to start typing server consumer") logrus.WithError(err).Panicf("failed to start typing server consumer")
} }
routing.Setup(base.APIMux, requestPool, syncDB, deviceDB) routing.Setup(base.APIMux, requestPool, syncDB, deviceDB, federation, queryAPI, cfg)
} }

View File

@ -16,45 +16,144 @@ package types
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt"
"strconv" "strconv"
"strings"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// SyncPosition contains the PDU and EDU stream sync positions for a client. var (
type SyncPosition struct { // ErrInvalidPaginationTokenType is returned when an attempt at creating a
// PDUPosition is the stream position for PDUs the client is at. // new instance of PaginationToken with an invalid type (i.e. neither "s"
PDUPosition int64 // nor "t").
// TypingPosition is the client's position for typing notifications. ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)")
TypingPosition int64 // ErrInvalidPaginationTokenLen is returned when the pagination token is an
// invalid length
ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length")
)
// StreamPosition represents the offset in the sync stream a client is at.
type StreamPosition int64
// Same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamEvent struct {
gomatrixserverlib.Event
StreamPosition StreamPosition
TransactionID *api.TransactionID
ExcludeFromSync bool
} }
// String implements the Stringer interface. // PaginationTokenType represents the type of a pagination token.
func (sp SyncPosition) String() string { // It can be either "s" (representing a position in the whole stream of events)
return strconv.FormatInt(sp.PDUPosition, 10) + "_" + // or "t" (representing a position in a room's topology/depth).
strconv.FormatInt(sp.TypingPosition, 10) type PaginationTokenType string
const (
// PaginationTokenTypeStream represents a position in the server's whole
// stream of events
PaginationTokenTypeStream PaginationTokenType = "s"
// PaginationTokenTypeTopology represents a position in a room's topology.
PaginationTokenTypeTopology PaginationTokenType = "t"
)
// PaginationToken represents a pagination token, used for interactions with
// /sync or /messages, for example.
type PaginationToken struct {
//Position StreamPosition
Type PaginationTokenType
PDUPosition StreamPosition
EDUTypingPosition StreamPosition
} }
// IsAfter returns whether one SyncPosition refers to states newer than another SyncPosition. // NewPaginationTokenFromString takes a string of the form "xyyyy..." where "x"
func (sp SyncPosition) IsAfter(other SyncPosition) bool { // represents the type of a pagination token and "yyyy..." the token itself, and
return sp.PDUPosition > other.PDUPosition || // parses it in order to create a new instance of PaginationToken. Returns an
sp.TypingPosition > other.TypingPosition // error if the token couldn't be parsed into an int64, or if the token type
// isn't a known type (returns ErrInvalidPaginationTokenType in the latter
// case).
func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) {
if len(s) == 0 {
return nil, ErrInvalidPaginationTokenLen
}
token = new(PaginationToken)
var positions []string
switch t := PaginationTokenType(s[:1]); t {
case PaginationTokenTypeStream, PaginationTokenTypeTopology:
token.Type = t
positions = strings.Split(s[1:], "_")
default:
token.Type = PaginationTokenTypeStream
positions = strings.Split(s, "_")
}
// Try to get the PDU position.
if len(positions) >= 1 {
if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil {
return nil, err
} else if pduPos < 0 {
return nil, errors.New("negative PDU position not allowed")
} else {
token.PDUPosition = StreamPosition(pduPos)
}
}
// Try to get the typing position.
if len(positions) >= 2 {
if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil {
return nil, err
} else if typPos < 0 {
return nil, errors.New("negative EDU typing position not allowed")
} else {
token.EDUTypingPosition = StreamPosition(typPos)
}
}
return
} }
// WithUpdates returns a copy of the SyncPosition with updates applied from another SyncPosition. // NewPaginationTokenFromTypeAndPosition takes a PaginationTokenType and a
// If the latter SyncPosition contains a field that is not 0, it is considered an update, // StreamPosition and returns an instance of PaginationToken.
// and its value will replace the corresponding value in the SyncPosition on which WithUpdates is called. func NewPaginationTokenFromTypeAndPosition(
func (sp SyncPosition) WithUpdates(other SyncPosition) SyncPosition { t PaginationTokenType, pdupos StreamPosition, typpos StreamPosition,
ret := sp ) (p *PaginationToken) {
return &PaginationToken{
Type: t,
PDUPosition: pdupos,
EDUTypingPosition: typpos,
}
}
// String translates a PaginationToken to a string of the "xyyyy..." (see
// NewPaginationToken to know what it represents).
func (p *PaginationToken) String() string {
return fmt.Sprintf("%s%d_%d", p.Type, p.PDUPosition, p.EDUTypingPosition)
}
// WithUpdates returns a copy of the PaginationToken with updates applied from another PaginationToken.
// If the latter PaginationToken contains a field that is not 0, it is considered an update,
// and its value will replace the corresponding value in the PaginationToken on which WithUpdates is called.
func (pt *PaginationToken) WithUpdates(other PaginationToken) PaginationToken {
ret := *pt
if other.PDUPosition != 0 { if other.PDUPosition != 0 {
ret.PDUPosition = other.PDUPosition ret.PDUPosition = other.PDUPosition
} }
if other.TypingPosition != 0 { if other.EDUTypingPosition != 0 {
ret.TypingPosition = other.TypingPosition ret.EDUTypingPosition = other.EDUTypingPosition
} }
return ret return ret
} }
// IsAfter returns whether one PaginationToken refers to states newer than another PaginationToken.
func (sp *PaginationToken) IsAfter(other PaginationToken) bool {
return sp.PDUPosition > other.PDUPosition ||
sp.EDUTypingPosition > other.EDUTypingPosition
}
// PrevEventRef represents a reference to a previous event in a state event upgrade // PrevEventRef represents a reference to a previous event in a state event upgrade
type PrevEventRef struct { type PrevEventRef struct {
PrevContent json.RawMessage `json:"prev_content"` PrevContent json.RawMessage `json:"prev_content"`
@ -79,9 +178,9 @@ type Response struct {
} }
// NewResponse creates an empty response with initialised maps. // NewResponse creates an empty response with initialised maps.
func NewResponse(pos SyncPosition) *Response { func NewResponse(token PaginationToken) *Response {
res := Response{ res := Response{
NextBatch: pos.String(), NextBatch: token.String(),
} }
// Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section,
// so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors.
@ -96,6 +195,14 @@ func NewResponse(pos SyncPosition) *Response {
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
// Fill next_batch with a pagination token. Since this is a response to a sync request, we can assume
// we'll always return a stream token.
res.NextBatch = NewPaginationTokenFromTypeAndPosition(
PaginationTokenTypeStream,
StreamPosition(token.PDUPosition),
StreamPosition(token.EDUTypingPosition),
).String()
return &res return &res
} }

View File

@ -0,0 +1,52 @@
package types
import "testing"
func TestNewPaginationTokenFromString(t *testing.T) {
shouldPass := map[string]PaginationToken{
"2": PaginationToken{
Type: PaginationTokenTypeStream,
PDUPosition: 2,
},
"s4": PaginationToken{
Type: PaginationTokenTypeStream,
PDUPosition: 4,
},
"s3_1": PaginationToken{
Type: PaginationTokenTypeStream,
PDUPosition: 3,
EDUTypingPosition: 1,
},
"t3_1_4": PaginationToken{
Type: PaginationTokenTypeTopology,
PDUPosition: 3,
EDUTypingPosition: 1,
},
}
shouldFail := []string{
"",
"s_1",
"s_",
"a3_4",
"b",
"b-1",
"-4",
}
for test, expected := range shouldPass {
result, err := NewPaginationTokenFromString(test)
if err != nil {
t.Error(err)
}
if *result != expected {
t.Errorf("expected %v but got %v", expected.String(), result.String())
}
}
for _, test := range shouldFail {
if _, err := NewPaginationTokenFromString(test); err == nil {
t.Errorf("input '%v' should have errored but didn't", test)
}
}
}

View File

@ -12,3 +12,7 @@ Room members can override their displayname on a room-specific basis
# Blacklisted due to flakiness # Blacklisted due to flakiness
Alias creators can delete alias with no ops Alias creators can delete alias with no ops
# Blacklisted because matrix-org/dendrite#847 might have broken it but we're not
# really sure and we need it pretty badly anyway
Real non-joined users can get individual state for world_readable rooms after leaving

View File

@ -74,7 +74,7 @@ Real non-joined user cannot call /events on joined room
Real non-joined user cannot call /events on default room Real non-joined user cannot call /events on default room
Real non-joined users can get state for world_readable rooms Real non-joined users can get state for world_readable rooms
Real non-joined users can get individual state for world_readable rooms Real non-joined users can get individual state for world_readable rooms
Real non-joined users can get individual state for world_readable rooms after leaving #Real non-joined users can get individual state for world_readable rooms after leaving
Real non-joined users cannot send messages to guest_access rooms if not joined Real non-joined users cannot send messages to guest_access rooms if not joined
Real users can sync from world_readable guest_access rooms if joined Real users can sync from world_readable guest_access rooms if joined
Real users can sync from default guest_access rooms if joined Real users can sync from default guest_access rooms if joined
@ -206,3 +206,7 @@ remote user can join room with version 5
Inbound federation can query room alias directory Inbound federation can query room alias directory
Outbound federation can query v2 /send_join Outbound federation can query v2 /send_join
Inbound federation can receive v2 /send_join Inbound federation can receive v2 /send_join
Message history can be paginated
Getting messages going forward is limited for a departed room (SPEC-216)
m.room.history_visibility == "world_readable" allows/forbids appropriately for Real users
Backfill works correctly with history visibility set to joined