dendrite/roomserver/internal/helpers/helpers.go

327 lines
10 KiB
Go
Raw Normal View History

package helpers
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/auth"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// TODO: temporary package which has helper functions used by both internal/perform packages.
// Move these to a more sensible place.
func UpdateToInviteMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
roomVersion gomatrixserverlib.RoomVersion,
) ([]api.OutputEvent, error) {
// We may have already sent the invite to the user, either because we are
// reprocessing this event, or because the we received this invite from a
// remote server via the federation invite API. In those cases we don't need
// to send the event.
needsSending, err := mu.SetToInvite(*add)
if err != nil {
return nil, err
}
if needsSending {
// We notify the consumers using a special event even though we will
// notify them about the change in current state as part of the normal
// room event stream. This ensures that the consumers only have to
// consider a single stream of events when determining whether a user
// is invited, rather than having to combine multiple streams themselves.
onie := api.OutputNewInviteEvent{
Event: add.Headered(roomVersion),
RoomVersion: roomVersion,
}
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeNewInviteEvent,
NewInviteEvent: &onie,
})
}
return updates, nil
}
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
info, err := db.RoomInfo(ctx, roomID)
if err != nil {
return false, err
}
if info == nil {
return false, fmt.Errorf("unknown room %s", roomID)
}
eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
if err != nil {
return false, err
}
events, err := db.Events(ctx, eventNIDs)
if err != nil {
return false, err
}
gmslEvents := make([]gomatrixserverlib.Event, len(events))
for i := range events {
gmslEvents[i] = events[i].Event
}
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
}
func IsInvitePending(
ctx context.Context, db storage.Database,
roomID, userID string,
) (bool, string, string, error) {
// Look up the room NID for the supplied room ID.
info, err := db.RoomInfo(ctx, roomID)
if err != nil {
return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err)
}
if info == nil {
return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
}
// Look up the state key NID for the supplied user ID.
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
if err != nil {
return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
}
targetUserNID, targetUserFound := targetUserNIDs[userID]
if !targetUserFound {
return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
}
// Let's see if we have an event active for the user in the room. If
// we do then it will contain a server name that we can direct the
// send_leave to.
senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
if err != nil {
return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
}
if len(senderUserNIDs) == 0 {
return false, "", "", nil
}
userNIDToEventID := make(map[types.EventStateKeyNID]string)
for i, nid := range senderUserNIDs {
userNIDToEventID[nid] = eventIDs[i]
}
// Look up the user ID from the NID.
senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
if err != nil {
return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
}
if len(senderUsers) == 0 {
return false, "", "", fmt.Errorf("no senderUsers")
}
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
if !senderUserFound {
return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
}
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
}
// GetMembershipsAtState filters the state events to
// only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events.
func GetMembershipsAtState(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) {
var eventNIDs []types.EventNID
for _, entry := range stateEntries {
// Filter the events to retrieve to only keep the membership events
if entry.EventTypeNID == types.MRoomMemberNID {
eventNIDs = append(eventNIDs, entry.EventNID)
}
}
// Get all of the events in this state
stateEvents, err := db.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
if !joinedOnly {
return stateEvents, nil
}
// Filter the events to only keep the "join" membership events
var events []types.Event
for _, event := range stateEvents {
membership, err := event.Membership()
if err != nil {
return nil, err
}
if membership == gomatrixserverlib.Join {
events = append(events, event)
}
}
return events, nil
}
func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info)
// Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil {
return nil, err
}
eventIDs := []string{eIDs[eventNID]}
prevState, err := db.StateAtEventIDs(ctx, eventIDs)
if err != nil {
return nil, err
}
// Fetch the state as it was when this event was fired
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
func LoadEvents(
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.Event, error) {
stateEvents, err := db.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
result := make([]gomatrixserverlib.Event, len(stateEvents))
for i := range stateEvents {
result[i] = stateEvents[i].Event
}
return result, nil
}
func LoadStateEvents(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
) ([]gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID
}
return LoadEvents(ctx, db, eventNIDs)
}
func CheckServerAllowedToSeeEvent(
ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
) (bool, error) {
roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil {
return false, err
}
// TODO: We probably want to make it so that we don't have to pull
// out all the state if possible.
stateAtEvent, err := LoadStateEvents(ctx, db, stateEntries)
if err != nil {
return false, err
}
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
}
// TODO: Remove this when we have tests to assert correctness of this function
// nolint:gocyclo
func ScanEventTree(
ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int,
serverName gomatrixserverlib.ServerName,
) ([]types.EventNID, error) {
var resultNIDs []types.EventNID
var err error
var allowed bool
var events []types.Event
var next []string
var pre string
// TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
// Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
// so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
// duplicate events being sent in response to /backfill requests.
initialIgnoreList := make(map[string]bool, len(visited))
for k, v := range visited {
initialIgnoreList[k] = v
}
resultNIDs = make([]types.EventNID, 0, limit)
var checkedServerInRoom bool
var isServerInRoom bool
// Loop through the event IDs to retrieve the requested events and go
// through the whole tree (up to the provided limit) using the events'
// "prev_event" key.
BFSLoop:
for len(front) > 0 {
// Prevent unnecessary allocations: reset the slice only when not empty.
if len(next) > 0 {
next = make([]string, 0)
}
// Retrieve the events to process from the database.
events, err = db.EventsFromIDs(ctx, front)
if err != nil {
return resultNIDs, err
}
if !checkedServerInRoom && len(events) > 0 {
// It's nasty that we have to extract the room ID from an event, but many federation requests
// only talk in event IDs, no room IDs at all (!!!)
ev := events[0]
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
}
checkedServerInRoom = true
}
for _, ev := range events {
// Break out of the loop if the provided limit is reached.
if len(resultNIDs) == limit {
break BFSLoop
}
if !initialIgnoreList[ev.EventID()] {
// Update the list of events to retrieve.
resultNIDs = append(resultNIDs, ev.EventNID)
}
// Loop through the event's parents.
for _, pre = range ev.PrevEventIDs() {
// Only add an event to the list of next events to process if it
// hasn't been seen before.
if !visited[pre] {
visited[pre] = true
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event",
)
return resultNIDs, err
}
// If the event hasn't been seen before and the HS
// requesting to retrieve it is allowed to do so, add it to
// the list of events to retrieve.
if allowed {
next = append(next, pre)
} else {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
}
}
}
}
// Repeat the same process with the parent events we just processed.
front = next
}
return resultNIDs, err
}