mirror of
https://github.com/1f349/dendrite.git
synced 2025-01-21 23:06:32 +00:00
Move currentstateserver API to roomserver (#1387)
* Move currentstateserver API to roomserver Stub out DB functions for now, nothing uses the roomserver version yet. * Allow it to startup * Implement some current-state-server storage interface functions * Add missing package
This commit is contained in:
parent
6150de6cb3
commit
b20386123e
@ -23,17 +23,25 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/matrix-org/dendrite/currentstateserver/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type ServerACLDatabase interface {
|
||||
// GetKnownRooms returns a list of all rooms we know about.
|
||||
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||
// GetStateEvent returns the state event of a given type for a given room with a given state key
|
||||
// If no event could be found, returns nil
|
||||
// If there was an issue during the retrieval, returns an error
|
||||
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||
}
|
||||
|
||||
type ServerACLs struct {
|
||||
acls map[string]*serverACL // room ID -> ACL
|
||||
aclsMutex sync.RWMutex // protects the above
|
||||
}
|
||||
|
||||
func NewServerACLs(db storage.Database) *ServerACLs {
|
||||
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
||||
ctx := context.TODO()
|
||||
acls := &ServerACLs{
|
||||
acls: make(map[string]*serverACL),
|
||||
|
@ -296,6 +296,30 @@ func (t *testRoomserverAPI) RemoveRoomAlias(
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (t *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (t *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (t *testRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (t *testRoomserverAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (t *testRoomserverAPI) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testStateAPI struct {
|
||||
}
|
||||
|
||||
|
164
roomserver/acls/acls.go
Normal file
164
roomserver/acls/acls.go
Normal file
@ -0,0 +1,164 @@
|
||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package acls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type ServerACLDatabase interface {
|
||||
// GetKnownRooms returns a list of all rooms we know about.
|
||||
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||
// GetStateEvent returns the state event of a given type for a given room with a given state key
|
||||
// If no event could be found, returns nil
|
||||
// If there was an issue during the retrieval, returns an error
|
||||
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||
}
|
||||
|
||||
type ServerACLs struct {
|
||||
acls map[string]*serverACL // room ID -> ACL
|
||||
aclsMutex sync.RWMutex // protects the above
|
||||
}
|
||||
|
||||
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
||||
ctx := context.TODO()
|
||||
acls := &ServerACLs{
|
||||
acls: make(map[string]*serverACL),
|
||||
}
|
||||
// Look up all of the rooms that the current state server knows about.
|
||||
rooms, err := db.GetKnownRooms(ctx)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatalf("Failed to get known rooms")
|
||||
}
|
||||
// For each room, let's see if we have a server ACL state event. If we
|
||||
// do then we'll process it into memory so that we have the regexes to
|
||||
// hand.
|
||||
for _, room := range rooms {
|
||||
state, err := db.GetStateEvent(ctx, room, "m.room.server_acl", "")
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room)
|
||||
continue
|
||||
}
|
||||
if state != nil {
|
||||
acls.OnServerACLUpdate(&state.Event)
|
||||
}
|
||||
}
|
||||
return acls
|
||||
}
|
||||
|
||||
type ServerACL struct {
|
||||
Allowed []string `json:"allow"`
|
||||
Denied []string `json:"deny"`
|
||||
AllowIPLiterals bool `json:"allow_ip_literals"`
|
||||
}
|
||||
|
||||
type serverACL struct {
|
||||
ServerACL
|
||||
allowedRegexes []*regexp.Regexp
|
||||
deniedRegexes []*regexp.Regexp
|
||||
}
|
||||
|
||||
func compileACLRegex(orig string) (*regexp.Regexp, error) {
|
||||
escaped := regexp.QuoteMeta(orig)
|
||||
escaped = strings.Replace(escaped, "\\?", ".", -1)
|
||||
escaped = strings.Replace(escaped, "\\*", ".*", -1)
|
||||
return regexp.Compile(escaped)
|
||||
}
|
||||
|
||||
func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) {
|
||||
acls := &serverACL{}
|
||||
if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs")
|
||||
return
|
||||
}
|
||||
// The spec calls only for * (zero or more chars) and ? (exactly one char)
|
||||
// to be supported as wildcard components, so we will escape all of the regex
|
||||
// special characters and then replace * and ? with their regex counterparts.
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl
|
||||
for _, orig := range acls.Allowed {
|
||||
if expr, err := compileACLRegex(orig); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to compile allowed regex")
|
||||
} else {
|
||||
acls.allowedRegexes = append(acls.allowedRegexes, expr)
|
||||
}
|
||||
}
|
||||
for _, orig := range acls.Denied {
|
||||
if expr, err := compileACLRegex(orig); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to compile denied regex")
|
||||
} else {
|
||||
acls.deniedRegexes = append(acls.deniedRegexes, expr)
|
||||
}
|
||||
}
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"allow_ip_literals": acls.AllowIPLiterals,
|
||||
"num_allowed": len(acls.allowedRegexes),
|
||||
"num_denied": len(acls.deniedRegexes),
|
||||
}).Debugf("Updating server ACLs for %q", state.RoomID())
|
||||
s.aclsMutex.Lock()
|
||||
defer s.aclsMutex.Unlock()
|
||||
s.acls[state.RoomID()] = acls
|
||||
}
|
||||
|
||||
func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool {
|
||||
s.aclsMutex.RLock()
|
||||
// First of all check if we have an ACL for this room. If we don't then
|
||||
// no servers are banned from the room.
|
||||
acls, ok := s.acls[roomID]
|
||||
if !ok {
|
||||
s.aclsMutex.RUnlock()
|
||||
return false
|
||||
}
|
||||
s.aclsMutex.RUnlock()
|
||||
// Split the host and port apart. This is because the spec calls on us to
|
||||
// validate the hostname only in cases where the port is also present.
|
||||
if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil {
|
||||
serverName = gomatrixserverlib.ServerName(serverNameOnly)
|
||||
}
|
||||
// Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding
|
||||
// a /0 prefix length just to trick ParseCIDR into working. If we find that
|
||||
// the server is an IP literal and we don't allow those then stop straight
|
||||
// away.
|
||||
if _, _, err := net.ParseCIDR(fmt.Sprintf("%s/0", serverName)); err == nil {
|
||||
if !acls.AllowIPLiterals {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Check if the hostname matches one of the denied regexes. If it does then
|
||||
// the server is banned from the room.
|
||||
for _, expr := range acls.deniedRegexes {
|
||||
if expr.MatchString(string(serverName)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Check if the hostname matches one of the allowed regexes. If it does then
|
||||
// the server is NOT banned from the room.
|
||||
for _, expr := range acls.allowedRegexes {
|
||||
if expr.MatchString(string(serverName)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
// If we've got to this point then we haven't matched any regexes or an IP
|
||||
// hostname if disallowed. The spec calls for default-deny here.
|
||||
return true
|
||||
}
|
105
roomserver/acls/acls_test.go
Normal file
105
roomserver/acls/acls_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package acls
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenACLsWithBlacklist(t *testing.T) {
|
||||
roomID := "!test:test.com"
|
||||
allowRegex, err := compileACLRegex("*")
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
denyRegex, err := compileACLRegex("foo.com")
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
acls := ServerACLs{
|
||||
acls: make(map[string]*serverACL),
|
||||
}
|
||||
|
||||
acls.acls[roomID] = &serverACL{
|
||||
ServerACL: ServerACL{
|
||||
AllowIPLiterals: true,
|
||||
},
|
||||
allowedRegexes: []*regexp.Regexp{allowRegex},
|
||||
deniedRegexes: []*regexp.Regexp{denyRegex},
|
||||
}
|
||||
|
||||
if acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
|
||||
t.Fatal("Expected 1.2.3.4 to be allowed but wasn't")
|
||||
}
|
||||
if acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) {
|
||||
t.Fatal("Expected 1.2.3.4:2345 to be allowed but wasn't")
|
||||
}
|
||||
if !acls.IsServerBannedFromRoom("foo.com", roomID) {
|
||||
t.Fatal("Expected foo.com to be banned but wasn't")
|
||||
}
|
||||
if !acls.IsServerBannedFromRoom("foo.com:3456", roomID) {
|
||||
t.Fatal("Expected foo.com:3456 to be banned but wasn't")
|
||||
}
|
||||
if acls.IsServerBannedFromRoom("bar.com", roomID) {
|
||||
t.Fatal("Expected bar.com to be allowed but wasn't")
|
||||
}
|
||||
if acls.IsServerBannedFromRoom("bar.com:4567", roomID) {
|
||||
t.Fatal("Expected bar.com:4567 to be allowed but wasn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultACLsWithWhitelist(t *testing.T) {
|
||||
roomID := "!test:test.com"
|
||||
allowRegex, err := compileACLRegex("foo.com")
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
acls := ServerACLs{
|
||||
acls: make(map[string]*serverACL),
|
||||
}
|
||||
|
||||
acls.acls[roomID] = &serverACL{
|
||||
ServerACL: ServerACL{
|
||||
AllowIPLiterals: false,
|
||||
},
|
||||
allowedRegexes: []*regexp.Regexp{allowRegex},
|
||||
deniedRegexes: []*regexp.Regexp{},
|
||||
}
|
||||
|
||||
if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
|
||||
t.Fatal("Expected 1.2.3.4 to be banned but wasn't")
|
||||
}
|
||||
if !acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) {
|
||||
t.Fatal("Expected 1.2.3.4:2345 to be banned but wasn't")
|
||||
}
|
||||
if acls.IsServerBannedFromRoom("foo.com", roomID) {
|
||||
t.Fatal("Expected foo.com to be allowed but wasn't")
|
||||
}
|
||||
if acls.IsServerBannedFromRoom("foo.com:3456", roomID) {
|
||||
t.Fatal("Expected foo.com:3456 to be allowed but wasn't")
|
||||
}
|
||||
if !acls.IsServerBannedFromRoom("bar.com", roomID) {
|
||||
t.Fatal("Expected bar.com to be allowed but wasn't")
|
||||
}
|
||||
if !acls.IsServerBannedFromRoom("baz.com", roomID) {
|
||||
t.Fatal("Expected baz.com to be allowed but wasn't")
|
||||
}
|
||||
if !acls.IsServerBannedFromRoom("qux.com:4567", roomID) {
|
||||
t.Fatal("Expected qux.com:4567 to be allowed but wasn't")
|
||||
}
|
||||
}
|
@ -106,6 +106,20 @@ type RoomserverInternalAPI interface {
|
||||
response *QueryStateAndAuthChainResponse,
|
||||
) error
|
||||
|
||||
// QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from
|
||||
// the response.
|
||||
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
|
||||
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
|
||||
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
|
||||
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
|
||||
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
|
||||
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
||||
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
|
||||
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
|
||||
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
|
||||
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
|
||||
|
||||
// Query a given amount (or less) of events prior to a given set of events.
|
||||
PerformBackfill(
|
||||
ctx context.Context,
|
||||
|
@ -236,6 +236,47 @@ func (t *RoomserverInternalAPITrace) RemoveRoomAlias(
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *RoomserverInternalAPITrace) QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error {
|
||||
err := t.Impl.QueryCurrentState(ctx, req, res)
|
||||
util.GetLogger(ctx).WithError(err).Infof("QueryCurrentState req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
|
||||
func (t *RoomserverInternalAPITrace) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error {
|
||||
err := t.Impl.QueryRoomsForUser(ctx, req, res)
|
||||
util.GetLogger(ctx).WithError(err).Infof("QueryRoomsForUser req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
|
||||
func (t *RoomserverInternalAPITrace) QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error {
|
||||
err := t.Impl.QueryBulkStateContent(ctx, req, res)
|
||||
util.GetLogger(ctx).WithError(err).Infof("QueryBulkStateContent req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
||||
func (t *RoomserverInternalAPITrace) QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error {
|
||||
err := t.Impl.QuerySharedUsers(ctx, req, res)
|
||||
util.GetLogger(ctx).WithError(err).Infof("QuerySharedUsers req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
|
||||
func (t *RoomserverInternalAPITrace) QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error {
|
||||
err := t.Impl.QueryKnownUsers(ctx, req, res)
|
||||
util.GetLogger(ctx).WithError(err).Infof("QueryKnownUsers req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||
func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error {
|
||||
err := t.Impl.QueryServerBannedFromRoom(ctx, req, res)
|
||||
util.GetLogger(ctx).WithError(err).Infof("QueryServerBannedFromRoom req=%+v res=%+v", js(req), js(res))
|
||||
return err
|
||||
}
|
||||
|
||||
func js(thing interface{}) string {
|
||||
b, err := json.Marshal(thing)
|
||||
if err != nil {
|
||||
|
@ -17,6 +17,11 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
@ -225,3 +230,102 @@ type QueryPublishedRoomsResponse struct {
|
||||
// The list of published rooms.
|
||||
RoomIDs []string
|
||||
}
|
||||
|
||||
type QuerySharedUsersRequest struct {
|
||||
UserID string
|
||||
ExcludeRoomIDs []string
|
||||
IncludeRoomIDs []string
|
||||
}
|
||||
|
||||
type QuerySharedUsersResponse struct {
|
||||
UserIDsToCount map[string]int
|
||||
}
|
||||
|
||||
type QueryRoomsForUserRequest struct {
|
||||
UserID string
|
||||
// The desired membership of the user. If this is the empty string then no rooms are returned.
|
||||
WantMembership string
|
||||
}
|
||||
|
||||
type QueryRoomsForUserResponse struct {
|
||||
RoomIDs []string
|
||||
}
|
||||
|
||||
type QueryBulkStateContentRequest struct {
|
||||
// Returns state events in these rooms
|
||||
RoomIDs []string
|
||||
// If true, treats the '*' StateKey as "all state events of this type" rather than a literal value of '*'
|
||||
AllowWildcards bool
|
||||
// The state events to return. Only a small subset of tuples are allowed in this request as only certain events
|
||||
// have their content fields extracted. Specifically, the tuple Type must be one of:
|
||||
// m.room.avatar
|
||||
// m.room.create
|
||||
// m.room.canonical_alias
|
||||
// m.room.guest_access
|
||||
// m.room.history_visibility
|
||||
// m.room.join_rules
|
||||
// m.room.member
|
||||
// m.room.name
|
||||
// m.room.topic
|
||||
// Any other tuple type will result in the query failing.
|
||||
StateTuples []gomatrixserverlib.StateKeyTuple
|
||||
}
|
||||
type QueryBulkStateContentResponse struct {
|
||||
// map of room ID -> tuple -> content_value
|
||||
Rooms map[string]map[gomatrixserverlib.StateKeyTuple]string
|
||||
}
|
||||
|
||||
type QueryCurrentStateRequest struct {
|
||||
RoomID string
|
||||
StateTuples []gomatrixserverlib.StateKeyTuple
|
||||
}
|
||||
|
||||
type QueryCurrentStateResponse struct {
|
||||
StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent
|
||||
}
|
||||
|
||||
type QueryKnownUsersRequest struct {
|
||||
UserID string `json:"user_id"`
|
||||
SearchString string `json:"search_string"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
type QueryKnownUsersResponse struct {
|
||||
Users []authtypes.FullyQualifiedProfile `json:"profiles"`
|
||||
}
|
||||
|
||||
type QueryServerBannedFromRoomRequest struct {
|
||||
ServerName gomatrixserverlib.ServerName `json:"server_name"`
|
||||
RoomID string `json:"room_id"`
|
||||
}
|
||||
|
||||
type QueryServerBannedFromRoomResponse struct {
|
||||
Banned bool `json:"banned"`
|
||||
}
|
||||
|
||||
// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode.
|
||||
func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) {
|
||||
se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents))
|
||||
for k, v := range r.StateEvents {
|
||||
// use 0x1F (unit separator) as the delimiter between type/state key,
|
||||
se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v
|
||||
}
|
||||
return json.Marshal(se)
|
||||
}
|
||||
|
||||
func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
|
||||
res := make(map[string]*gomatrixserverlib.HeaderedEvent)
|
||||
err := json.Unmarshal(data, &res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(res))
|
||||
for k, v := range res {
|
||||
fields := strings.Split(k, "\x1F")
|
||||
r.StateEvents[gomatrixserverlib.StateKeyTuple{
|
||||
EventType: fields[0],
|
||||
StateKey: fields[1],
|
||||
}] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -133,3 +133,102 @@ func GetEvent(ctx context.Context, rsAPI RoomserverInternalAPI, eventID string)
|
||||
}
|
||||
return &res.Events[0]
|
||||
}
|
||||
|
||||
// GetStateEvent returns the current state event in the room or nil.
|
||||
func GetStateEvent(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent {
|
||||
var res QueryCurrentStateResponse
|
||||
err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{
|
||||
RoomID: roomID,
|
||||
StateTuples: []gomatrixserverlib.StateKeyTuple{tuple},
|
||||
}, &res)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed to QueryCurrentState")
|
||||
return nil
|
||||
}
|
||||
ev, ok := res.StateEvents[tuple]
|
||||
if ok {
|
||||
return ev
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs.
|
||||
func IsServerBannedFromRoom(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, serverName gomatrixserverlib.ServerName) bool {
|
||||
req := &QueryServerBannedFromRoomRequest{
|
||||
ServerName: serverName,
|
||||
RoomID: roomID,
|
||||
}
|
||||
res := &QueryServerBannedFromRoomResponse{}
|
||||
if err := rsAPI.QueryServerBannedFromRoom(ctx, req, res); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed to QueryServerBannedFromRoom")
|
||||
return true
|
||||
}
|
||||
return res.Banned
|
||||
}
|
||||
|
||||
// PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the
|
||||
// published room directory.
|
||||
// due to lots of switches
|
||||
// nolint:gocyclo
|
||||
func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) {
|
||||
avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""}
|
||||
nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""}
|
||||
canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""}
|
||||
topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""}
|
||||
guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""}
|
||||
visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""}
|
||||
joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""}
|
||||
|
||||
var stateRes QueryBulkStateContentResponse
|
||||
err := rsAPI.QueryBulkStateContent(ctx, &QueryBulkStateContentRequest{
|
||||
RoomIDs: roomIDs,
|
||||
AllowWildcards: true,
|
||||
StateTuples: []gomatrixserverlib.StateKeyTuple{
|
||||
nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple,
|
||||
{EventType: gomatrixserverlib.MRoomMember, StateKey: "*"},
|
||||
},
|
||||
}, &stateRes)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed")
|
||||
return nil, err
|
||||
}
|
||||
chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs))
|
||||
i := 0
|
||||
for roomID, data := range stateRes.Rooms {
|
||||
pub := gomatrixserverlib.PublicRoom{
|
||||
RoomID: roomID,
|
||||
}
|
||||
joinCount := 0
|
||||
var joinRule, guestAccess string
|
||||
for tuple, contentVal := range data {
|
||||
if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" {
|
||||
joinCount++
|
||||
continue
|
||||
}
|
||||
switch tuple {
|
||||
case avatarTuple:
|
||||
pub.AvatarURL = contentVal
|
||||
case nameTuple:
|
||||
pub.Name = contentVal
|
||||
case topicTuple:
|
||||
pub.Topic = contentVal
|
||||
case canonicalTuple:
|
||||
pub.CanonicalAlias = contentVal
|
||||
case visibilityTuple:
|
||||
pub.WorldReadable = contentVal == "world_readable"
|
||||
// need both of these to determine whether guests can join
|
||||
case joinRuleTuple:
|
||||
joinRule = contentVal
|
||||
case guestTuple:
|
||||
guestAccess = contentVal
|
||||
}
|
||||
}
|
||||
if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" {
|
||||
pub.GuestCanJoin = true
|
||||
}
|
||||
pub.JoinedMembersCount = joinCount
|
||||
chunk[i] = pub
|
||||
i++
|
||||
}
|
||||
return chunk, nil
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/config"
|
||||
"github.com/matrix-org/dendrite/roomserver/acls"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/internal/input"
|
||||
"github.com/matrix-org/dendrite/roomserver/internal/perform"
|
||||
@ -46,8 +47,9 @@ func NewRoomserverAPI(
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
KeyRing: keyRing,
|
||||
Queryer: &query.Queryer{
|
||||
DB: roomserverDB,
|
||||
Cache: caches,
|
||||
DB: roomserverDB,
|
||||
Cache: caches,
|
||||
ServerACLs: acls.NewServerACLs(roomserverDB),
|
||||
},
|
||||
Inputer: &input.Inputer{
|
||||
DB: roomserverDB,
|
||||
|
@ -16,9 +16,12 @@ package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/roomserver/acls"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||
"github.com/matrix-org/dendrite/roomserver/state"
|
||||
@ -31,8 +34,9 @@ import (
|
||||
)
|
||||
|
||||
type Queryer struct {
|
||||
DB storage.Database
|
||||
Cache caching.RoomServerCaches
|
||||
DB storage.Database
|
||||
Cache caching.RoomServerCaches
|
||||
ServerACLs *acls.ServerACLs
|
||||
}
|
||||
|
||||
// QueryLatestEventsAndState implements api.RoomserverInternalAPI
|
||||
@ -502,3 +506,97 @@ func (r *Queryer) QueryPublishedRooms(
|
||||
res.RoomIDs = rooms
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
|
||||
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
|
||||
for _, tuple := range req.StateTuples {
|
||||
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ev != nil {
|
||||
res.StateEvents[tuple] = ev
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
|
||||
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.RoomIDs = roomIDs
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
|
||||
users, err := r.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, user := range users {
|
||||
res.Users = append(res.Users, authtypes.FullyQualifiedProfile{
|
||||
UserID: user,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
|
||||
events, err := r.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
|
||||
for _, ev := range events {
|
||||
if res.Rooms[ev.RoomID] == nil {
|
||||
res.Rooms[ev.RoomID] = make(map[gomatrixserverlib.StateKeyTuple]string)
|
||||
}
|
||||
room := res.Rooms[ev.RoomID]
|
||||
room[gomatrixserverlib.StateKeyTuple{
|
||||
EventType: ev.EventType,
|
||||
StateKey: ev.StateKey,
|
||||
}] = ev.ContentValue
|
||||
res.Rooms[ev.RoomID] = room
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
||||
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roomIDs = append(roomIDs, req.IncludeRoomIDs...)
|
||||
excludeMap := make(map[string]bool)
|
||||
for _, roomID := range req.ExcludeRoomIDs {
|
||||
excludeMap[roomID] = true
|
||||
}
|
||||
// filter out excluded rooms
|
||||
j := 0
|
||||
for i := range roomIDs {
|
||||
// move elements to include to the beginning of the slice
|
||||
// then trim elements on the right
|
||||
if !excludeMap[roomIDs[i]] {
|
||||
roomIDs[j] = roomIDs[i]
|
||||
j++
|
||||
}
|
||||
}
|
||||
roomIDs = roomIDs[:j]
|
||||
|
||||
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.UserIDsToCount = users
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error {
|
||||
if r.ServerACLs == nil {
|
||||
return errors.New("no server ACL tracking")
|
||||
}
|
||||
res.Banned = r.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID)
|
||||
return nil
|
||||
}
|
||||
|
@ -43,6 +43,12 @@ const (
|
||||
RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities"
|
||||
RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom"
|
||||
RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms"
|
||||
RoomserverQueryCurrentStatePath = "/roomserver/queryCurrentState"
|
||||
RoomserverQueryRoomsForUserPath = "/roomserver/queryRoomsForUser"
|
||||
RoomserverQueryBulkStateContentPath = "/roomserver/queryBulkStateContent"
|
||||
RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers"
|
||||
RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers"
|
||||
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
|
||||
)
|
||||
|
||||
type httpRoomserverInternalAPI struct {
|
||||
@ -371,3 +377,69 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom(
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *httpRoomserverInternalAPI) QueryCurrentState(
|
||||
ctx context.Context,
|
||||
request *api.QueryCurrentStateRequest,
|
||||
response *api.QueryCurrentStateResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
|
||||
ctx context.Context,
|
||||
request *api.QueryRoomsForUserRequest,
|
||||
response *api.QueryRoomsForUserResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
|
||||
ctx context.Context,
|
||||
request *api.QueryBulkStateContentRequest,
|
||||
response *api.QueryBulkStateContentResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
func (h *httpRoomserverInternalAPI) QuerySharedUsers(
|
||||
ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpRoomserverInternalAPI) QueryKnownUsers(
|
||||
ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
|
||||
ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
@ -312,4 +312,82 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(RoomserverQueryCurrentStatePath,
|
||||
httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryCurrentStateRequest{}
|
||||
response := api.QueryCurrentStateResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(RoomserverQueryRoomsForUserPath,
|
||||
httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryRoomsForUserRequest{}
|
||||
response := api.QueryRoomsForUserResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(RoomserverQueryBulkStateContentPath,
|
||||
httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryBulkStateContentRequest{}
|
||||
response := api.QueryBulkStateContentResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
|
||||
httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QuerySharedUsersRequest{}
|
||||
response := api.QuerySharedUsersResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
|
||||
httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryKnownUsersRequest{}
|
||||
response := api.QueryKnownUsersResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath,
|
||||
httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryServerBannedFromRoomRequest{}
|
||||
response := api.QueryServerBannedFromRoomResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ package storage
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/currentstateserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
@ -138,4 +139,22 @@ type Database interface {
|
||||
PublishRoom(ctx context.Context, roomID string, publish bool) error
|
||||
// Returns a list of room IDs for rooms which are published.
|
||||
GetPublishedRooms(ctx context.Context) ([]string, error)
|
||||
|
||||
// TODO: factor out - from currentstateserver
|
||||
|
||||
// GetStateEvent returns the state event of a given type for a given room with a given state key
|
||||
// If no event could be found, returns nil
|
||||
// If there was an issue during the retrieval, returns an error
|
||||
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
|
||||
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
||||
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
||||
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
|
||||
// GetKnownUsers searches all users that userID knows about.
|
||||
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
|
||||
// GetKnownRooms returns a list of all rooms we know about.
|
||||
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
@ -99,6 +99,9 @@ const updateMembershipSQL = "" +
|
||||
"UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" +
|
||||
" WHERE room_nid = $1 AND target_nid = $2"
|
||||
|
||||
const selectRoomsWithMembershipSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
|
||||
|
||||
type membershipStatements struct {
|
||||
insertMembershipStmt *sql.Stmt
|
||||
selectMembershipForUpdateStmt *sql.Stmt
|
||||
@ -108,6 +111,7 @@ type membershipStatements struct {
|
||||
selectMembershipsFromRoomStmt *sql.Stmt
|
||||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||
updateMembershipStmt *sql.Stmt
|
||||
selectRoomsWithMembershipStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||
@ -126,6 +130,7 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
||||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
@ -222,3 +227,22 @@ func (s *membershipStatements) UpdateMembership(
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
|
||||
var roomNIDs []types.RoomNID
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
if err := rows.Scan(&roomNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomNIDs = append(roomNIDs, roomNID)
|
||||
}
|
||||
return roomNIDs, nil
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"errors"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
@ -74,6 +75,12 @@ const selectRoomVersionForRoomNIDSQL = "" +
|
||||
const selectRoomInfoSQL = "" +
|
||||
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||
|
||||
const selectRoomIDsSQL = "" +
|
||||
"SELECT room_id FROM roomserver_rooms"
|
||||
|
||||
const bulkSelectRoomIDsSQL = "" +
|
||||
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||
|
||||
type roomStatements struct {
|
||||
insertRoomNIDStmt *sql.Stmt
|
||||
selectRoomNIDStmt *sql.Stmt
|
||||
@ -82,6 +89,8 @@ type roomStatements struct {
|
||||
updateLatestEventNIDsStmt *sql.Stmt
|
||||
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||
selectRoomInfoStmt *sql.Stmt
|
||||
selectRoomIDsStmt *sql.Stmt
|
||||
bulkSelectRoomIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
@ -98,9 +107,27 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
||||
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
}
|
||||
return roomIDs, nil
|
||||
}
|
||||
func (s *roomStatements) InsertRoomNID(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||
@ -197,3 +224,24 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
|
||||
}
|
||||
return roomVersion, err
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||
var array pq.Int64Array
|
||||
for _, nid := range roomNIDs {
|
||||
array = append(array, int64(nid))
|
||||
}
|
||||
rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
}
|
||||
return roomIDs, nil
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
@ -711,3 +712,82 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
|
||||
}
|
||||
return &evs[0]
|
||||
}
|
||||
|
||||
// GetStateEvent returns the current state event of a given type for a given room with a given state key
|
||||
// If no event could be found, returns nil
|
||||
// If there was an issue during the retrieval, returns an error
|
||||
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||
/*
|
||||
roomInfo, err := d.RoomInfo(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
*/
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
|
||||
var membershipState tables.MembershipState
|
||||
switch membership {
|
||||
case "join":
|
||||
membershipState = tables.MembershipStateJoin
|
||||
case "invite":
|
||||
membershipState = tables.MembershipStateInvite
|
||||
case "leave":
|
||||
membershipState = tables.MembershipStateLeaveOrBan
|
||||
case "ban":
|
||||
membershipState = tables.MembershipStateLeaveOrBan
|
||||
default:
|
||||
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
|
||||
}
|
||||
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
|
||||
}
|
||||
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(roomIDs) != len(roomNIDs) {
|
||||
return nil, fmt.Errorf("GetRoomsByMembership: missing room IDs, got %d want %d", len(roomIDs), len(roomNIDs))
|
||||
}
|
||||
return roomIDs, nil
|
||||
}
|
||||
|
||||
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
||||
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||
func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]csstables.StrippedEvent, error) {
|
||||
return nil, fmt.Errorf("not implemented yet")
|
||||
}
|
||||
|
||||
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
||||
return nil, fmt.Errorf("not implemented yet")
|
||||
}
|
||||
|
||||
// GetKnownUsers searches all users that userID knows about.
|
||||
func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
|
||||
return nil, fmt.Errorf("not implemented yet")
|
||||
}
|
||||
|
||||
// GetKnownRooms returns a list of all rooms we know about.
|
||||
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||
return d.RoomsTable.SelectRoomIDs(ctx)
|
||||
}
|
||||
|
@ -75,6 +75,9 @@ const updateMembershipSQL = "" +
|
||||
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" +
|
||||
" WHERE room_nid = $4 AND target_nid = $5"
|
||||
|
||||
const selectRoomsWithMembershipSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
|
||||
|
||||
type membershipStatements struct {
|
||||
db *sql.DB
|
||||
insertMembershipStmt *sql.Stmt
|
||||
@ -84,6 +87,7 @@ type membershipStatements struct {
|
||||
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
|
||||
selectMembershipsFromRoomStmt *sql.Stmt
|
||||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||
selectRoomsWithMembershipStmt *sql.Stmt
|
||||
updateMembershipStmt *sql.Stmt
|
||||
}
|
||||
|
||||
@ -105,6 +109,7 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
||||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
@ -203,3 +208,22 @@ func (s *membershipStatements) UpdateMembership(
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
|
||||
var roomNIDs []types.RoomNID
|
||||
for rows.Next() {
|
||||
var roomNID types.RoomNID
|
||||
if err := rows.Scan(&roomNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomNIDs = append(roomNIDs, roomNID)
|
||||
}
|
||||
return roomNIDs, nil
|
||||
}
|
||||
|
@ -21,7 +21,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
@ -64,6 +66,12 @@ const selectRoomVersionForRoomNIDSQL = "" +
|
||||
const selectRoomInfoSQL = "" +
|
||||
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||
|
||||
const selectRoomIDsSQL = "" +
|
||||
"SELECT room_id FROM roomserver_rooms"
|
||||
|
||||
const bulkSelectRoomIDsSQL = "" +
|
||||
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||
|
||||
type roomStatements struct {
|
||||
db *sql.DB
|
||||
insertRoomNIDStmt *sql.Stmt
|
||||
@ -73,6 +81,7 @@ type roomStatements struct {
|
||||
updateLatestEventNIDsStmt *sql.Stmt
|
||||
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||
selectRoomInfoStmt *sql.Stmt
|
||||
selectRoomIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
@ -91,9 +100,27 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
}
|
||||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||
var info types.RoomInfo
|
||||
var latestNIDsJSON string
|
||||
@ -203,3 +230,25 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
|
||||
}
|
||||
return roomVersion, err
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
}
|
||||
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||
var roomIDs []string
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
if err = rows.Scan(&roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
}
|
||||
return roomIDs, nil
|
||||
}
|
||||
|
@ -65,6 +65,8 @@ type Rooms interface {
|
||||
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
||||
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
|
||||
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||
SelectRoomIDs(ctx context.Context) ([]string, error)
|
||||
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
|
||||
}
|
||||
|
||||
type Transactions interface {
|
||||
@ -120,6 +122,7 @@ type Membership interface {
|
||||
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
|
||||
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||
}
|
||||
|
||||
type Published interface {
|
||||
|
Loading…
Reference in New Issue
Block a user