bugfix: E2EE device keys could sometimes not be sent to remote servers (#2466)

* Fix flakey sytest 'Local device key changes get to remote servers'

* Debug logs

* Remove internal/test and use /test only

Remove a lot of ancient code too.

* Use FederationRoomserverAPI in more places

* Use more interfaces in federationapi; begin adding regression test

* Linting

* Add regression test

* Unbreak tests

* ALL THE LOGS

* Fix a race condition which could cause events to not be sent to servers

If a new room event which rewrites state arrives, we remove all joined hosts
then re-calculate them. This wasn't done in a transaction so for a brief period
we would have no joined hosts. During this interim, key change events which arrive
would not be sent to destination servers. This would sporadically fail on sytest.

* Unbreak new tests

* Linting
This commit is contained in:
kegsay 2022-05-17 13:23:35 +01:00 committed by GitHub
parent cd82460513
commit 6de29c1cd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
48 changed files with 566 additions and 618 deletions

View File

@ -20,7 +20,7 @@ import (
"log" "log"
"os" "os"
"github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/test"
) )
const usage = `Usage: %s const usage = `Usage: %s

View File

@ -12,12 +12,16 @@ import (
// FederationInternalAPI is used to query information from the federation sender. // FederationInternalAPI is used to query information from the federation sender.
type FederationInternalAPI interface { type FederationInternalAPI interface {
FederationClient gomatrixserverlib.FederatedStateClient
KeyserverFederationAPI
gomatrixserverlib.KeyDatabase gomatrixserverlib.KeyDatabase
ClientFederationAPI ClientFederationAPI
RoomserverFederationAPI RoomserverFederationAPI
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
// Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos.
PerformBroadcastEDU( PerformBroadcastEDU(
@ -60,17 +64,43 @@ type RoomserverFederationAPI interface {
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// FederationClient is a subset of gomatrixserverlib.FederationClient functions which the fedsender // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
// this interface are of type FederationClientError // this interface are of type FederationClientError
type FederationClient interface { type KeyserverFederationAPI interface {
gomatrixserverlib.FederatedStateClient
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
}
// an interface for gmsl.FederationClient - contains functions called by federationapi only.
type FederationClient interface {
gomatrixserverlib.KeyClient
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
// Perform operations
LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error)
Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error)
MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error)
SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error)
MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error)
SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error)
SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error)
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error)
Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error)
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error)
LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
} }
// FederationClientError is returned from FederationClient methods in the event of a problem. // FederationClientError is returned from FederationClient methods in the event of a problem.

View File

@ -39,7 +39,7 @@ type KeyChangeConsumer struct {
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
rsAPI roomserverAPI.RoomserverInternalAPI rsAPI roomserverAPI.FederationRoomserverAPI
topic string topic string
} }
@ -50,7 +50,7 @@ func NewKeyChangeConsumer(
js nats.JetStreamContext, js nats.JetStreamContext,
queues *queue.OutgoingQueues, queues *queue.OutgoingQueues,
store storage.Database, store storage.Database,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
) *KeyChangeConsumer { ) *KeyChangeConsumer {
return &KeyChangeConsumer{ return &KeyChangeConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -120,6 +120,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
logger.WithError(err).Error("failed to calculate joined rooms for user") logger.WithError(err).Error("failed to calculate joined rooms for user")
return true return true
} }
logrus.Infof("DEBUG: %v joined rooms for user %v", queryRes.RoomIDs, m.UserID)
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
if err != nil { if err != nil {
@ -128,6 +129,9 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
} }
if len(destinations) == 0 { if len(destinations) == 0 {
logger.WithField("num_rooms", len(queryRes.RoomIDs)).Debug("user is in no federated rooms")
destinations, err = t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, false)
logrus.Infof("GetJoinedHostsForRooms exclude self=false -> %v %v", destinations, err)
return true return true
} }
// Pack the EDU and marshal it // Pack the EDU and marshal it

View File

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/queue"
@ -36,7 +37,7 @@ import (
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context ctx context.Context
cfg *config.FederationAPI cfg *config.FederationAPI
rsAPI api.RoomserverInternalAPI rsAPI api.FederationRoomserverAPI
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string durable string
db storage.Database db storage.Database
@ -51,7 +52,7 @@ func NewOutputRoomEventConsumer(
js nats.JetStreamContext, js nats.JetStreamContext,
queues *queue.OutgoingQueues, queues *queue.OutgoingQueues,
store storage.Database, store storage.Database,
rsAPI api.RoomserverInternalAPI, rsAPI api.FederationRoomserverAPI,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -89,15 +90,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg)
switch output.Type { switch output.Type {
case api.OutputTypeNewRoomEvent: case api.OutputTypeNewRoomEvent:
ev := output.NewRoomEvent.Event ev := output.NewRoomEvent.Event
if err := s.processMessage(*output.NewRoomEvent, output.NewRoomEvent.RewritesState); err != nil {
if output.NewRoomEvent.RewritesState {
if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil {
log.WithError(err).Errorf("roomserver output log: purge room state failure")
return false
}
}
if err := s.processMessage(*output.NewRoomEvent); err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": ev.EventID(), "event_id": ev.EventID(),
@ -145,7 +138,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee
// processMessage updates the list of currently joined hosts in the room // processMessage updates the list of currently joined hosts in the room
// and then sends the event to the hosts that were joined before the event. // and then sends the event to the hosts that were joined before the event.
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error { func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error {
addsStateEvents, missingEventIDs := ore.NeededStateEventIDs() addsStateEvents, missingEventIDs := ore.NeededStateEventIDs()
// Ask the roomserver and add in the rest of the results into the set. // Ask the roomserver and add in the rest of the results into the set.
@ -164,7 +157,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
addsStateEvents = append(addsStateEvents, eventsRes.Events...) addsStateEvents = append(addsStateEvents, eventsRes.Events...)
} }
addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents)) addsJoinedHosts, err := JoinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents))
if err != nil { if err != nil {
return err return err
} }
@ -173,13 +166,13 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
// expressed as a delta against the current state. // expressed as a delta against the current state.
// TODO(#290): handle EventIDMismatchError and recover the current state by // TODO(#290): handle EventIDMismatchError and recover the current state by
// talking to the roomserver // talking to the roomserver
logrus.Infof("room %s adds joined hosts: %v removes %v", ore.Event.RoomID(), addsJoinedHosts, ore.RemovesStateEventIDs)
oldJoinedHosts, err := s.db.UpdateRoom( oldJoinedHosts, err := s.db.UpdateRoom(
s.ctx, s.ctx,
ore.Event.RoomID(), ore.Event.RoomID(),
ore.LastSentEventID,
ore.Event.EventID(),
addsJoinedHosts, addsJoinedHosts,
ore.RemovesStateEventIDs, ore.RemovesStateEventIDs,
rewritesState, // if we're re-writing state, nuke all joined hosts before adding
) )
if err != nil { if err != nil {
return err return err
@ -238,7 +231,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
return nil, err return nil, err
} }
combinedAddsJoinedHosts, err := joinedHostsFromEvents(combinedAddsEvents) combinedAddsJoinedHosts, err := JoinedHostsFromEvents(combinedAddsEvents)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -284,10 +277,10 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
return result, nil return result, nil
} }
// joinedHostsFromEvents turns a list of state events into a list of joined hosts. // JoinedHostsFromEvents turns a list of state events into a list of joined hosts.
// This errors if one of the events was invalid. // This errors if one of the events was invalid.
// It should be impossible for an invalid event to get this far in the pipeline. // It should be impossible for an invalid event to get this far in the pipeline.
func joinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) { func JoinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) {
var joinedHosts []types.JoinedHost var joinedHosts []types.JoinedHost
for _, ev := range evs { for _, ev := range evs {
if ev.Type() != "m.room.member" || ev.StateKey() == nil { if ev.Type() != "m.room.member" || ev.StateKey() == nil {

View File

@ -93,8 +93,8 @@ func AddPublicRoutes(
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
base *base.BaseDendrite, base *base.BaseDendrite,
federation *gomatrixserverlib.FederationClient, federation api.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
caches *caching.Caches, caches *caching.Caches,
keyRing *gomatrixserverlib.KeyRing, keyRing *gomatrixserverlib.KeyRing,
resetBlacklist bool, resetBlacklist bool,

View File

@ -3,18 +3,250 @@ package federationapi_test
import ( import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"encoding/json"
"fmt"
"strings" "strings"
"testing" "testing"
"time"
"github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/internal" "github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/internal/test" keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
) )
type fedRoomserverAPI struct {
rsapi.FederationRoomserverAPI
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
}
// PerformJoin will call this function
func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if f.inputRoomEvents == nil {
return
}
f.inputRoomEvents(ctx, req, res)
}
// keychange consumer calls this
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
if f.queryRoomsForUser == nil {
return nil
}
return f.queryRoomsForUser(ctx, req, res)
}
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
type fedClient struct {
api.FederationClient
allowJoins []*test.Room
keys map[gomatrixserverlib.ServerName]struct {
key ed25519.PrivateKey
keyID gomatrixserverlib.KeyID
}
t *testing.T
sentTxn bool
}
func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) {
fmt.Println("GetServerKeys:", matrixServer)
var keys gomatrixserverlib.ServerKeys
var keyID gomatrixserverlib.KeyID
var pkey ed25519.PrivateKey
for srv, data := range f.keys {
if srv == matrixServer {
pkey = data.key
keyID = data.keyID
break
}
}
if pkey == nil {
return keys, nil
}
keys.ServerName = matrixServer
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(10 * time.Hour))
publicKey := pkey.Public().(ed25519.PublicKey)
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
keyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey),
},
}
toSign, err := json.Marshal(keys.ServerKeyFields)
if err != nil {
return keys, err
}
keys.Raw, err = gomatrixserverlib.SignJSON(
string(matrixServer), keyID, pkey, toSign,
)
if err != nil {
return keys, err
}
return keys, nil
}
func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) {
for _, r := range f.allowJoins {
if r.ID == roomID {
res.RoomVersion = r.Version
res.JoinEvent = gomatrixserverlib.EventBuilder{
Sender: userID,
RoomID: roomID,
Type: "m.room.member",
StateKey: &userID,
Content: gomatrixserverlib.RawJSON([]byte(`{"membership":"join"}`)),
PrevEvents: r.ForwardExtremities(),
}
var needed gomatrixserverlib.StateNeeded
needed, err = gomatrixserverlib.StateNeededForEventBuilder(&res.JoinEvent)
if err != nil {
f.t.Errorf("StateNeededForEventBuilder: %v", err)
return
}
res.JoinEvent.AuthEvents = r.MustGetAuthEventRefsForEvent(f.t, needed)
return
}
}
return
}
func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) {
for _, r := range f.allowJoins {
if r.ID == event.RoomID() {
r.InsertEvent(f.t, event.Headered(r.Version))
f.t.Logf("Join event: %v", event.EventID())
res.StateEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.CurrentState())
res.AuthEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.Events())
}
}
return
}
func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
for _, edu := range t.EDUs {
if edu.Type == gomatrixserverlib.MDeviceListUpdate {
f.sentTxn = true
}
}
f.t.Logf("got /send")
return
}
// Regression test to make sure that /send_join is updating the destination hosts synchronously and
// isn't relying on the roomserver.
func TestFederationAPIJoinThenKeyUpdate(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testFederationAPIJoinThenKeyUpdate(t, dbType)
})
}
func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
base.Cfg.FederationAPI.PreferDirectFetch = true
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
serverA := gomatrixserverlib.ServerName("server.a")
serverAKeyID := gomatrixserverlib.KeyID("ed25519:servera")
serverAPrivKey := test.PrivateKeyA
creator := test.NewUser(t, test.WithSigningServer(serverA, serverAKeyID, serverAPrivKey))
myServer := base.Cfg.Global.ServerName
myServerKeyID := base.Cfg.Global.KeyID
myServerPrivKey := base.Cfg.Global.PrivateKey
joiningUser := test.NewUser(t, test.WithSigningServer(myServer, myServerKeyID, myServerPrivKey))
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
room := test.NewRoom(t, creator)
rsapi := &fedRoomserverAPI{
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if req.Asynchronous {
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
}
},
queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
if req.UserID == joiningUser.ID && req.WantMembership == "join" {
res.RoomIDs = []string{room.ID}
return nil
}
return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req)
},
}
fc := &fedClient{
allowJoins: []*test.Room{room},
t: t,
keys: map[gomatrixserverlib.ServerName]struct {
key ed25519.PrivateKey
keyID gomatrixserverlib.KeyID
}{
serverA: {
key: serverAPrivKey,
keyID: serverAKeyID,
},
myServer: {
key: myServerPrivKey,
keyID: myServerKeyID,
},
},
}
fsapi := federationapi.NewInternalAPI(base, fc, rsapi, base.Caches, nil, false)
var resp api.PerformJoinResponse
fsapi.PerformJoin(context.Background(), &api.PerformJoinRequest{
RoomID: room.ID,
UserID: joiningUser.ID,
ServerNames: []gomatrixserverlib.ServerName{serverA},
}, &resp)
if resp.JoinedVia != serverA {
t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, serverA)
}
if resp.LastError != nil {
t.Fatalf("PerformJoin: returned error: %+v", *resp.LastError)
}
// Inject a keyserver key change event and ensure we try to send it out. If we don't, then the
// federationapi is incorrectly waiting for an output room event to arrive to update the joined
// hosts table.
key := keyapi.DeviceMessage{
Type: keyapi.TypeDeviceKeyUpdate,
DeviceKeys: &keyapi.DeviceKeys{
UserID: joiningUser.ID,
DeviceID: "MY_DEVICE",
DisplayName: "BLARGLE",
KeyJSON: []byte(`{}`),
},
}
b, err := json.Marshal(key)
if err != nil {
t.Fatalf("Failed to marshal device message: %s", err)
}
msg := &nats.Msg{
Subject: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
Header: nats.Header{},
Data: b,
}
msg.Header.Set(jetstream.UserID, key.UserID)
testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(500 * time.Millisecond)
if !fc.sentTxn {
t.Fatalf("did not send device list update")
}
}
// Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404. // Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404.
// Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated. // Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated.
func TestRoomsV3URLEscapeDoNot404(t *testing.T) { func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
@ -86,7 +318,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
} }
gerr, ok := err.(gomatrix.HTTPError) gerr, ok := err.(gomatrix.HTTPError)
if !ok { if !ok {
t.Errorf("failed to cast response error as gomatrix.HTTPError") t.Errorf("failed to cast response error as gomatrix.HTTPError: %s", err)
continue continue
} }
t.Logf("Error: %+v", gerr) t.Logf("Error: %+v", gerr)

View File

@ -25,8 +25,8 @@ type FederationInternalAPI struct {
db storage.Database db storage.Database
cfg *config.FederationAPI cfg *config.FederationAPI
statistics *statistics.Statistics statistics *statistics.Statistics
rsAPI roomserverAPI.RoomserverInternalAPI rsAPI roomserverAPI.FederationRoomserverAPI
federation *gomatrixserverlib.FederationClient federation api.FederationClient
keyRing *gomatrixserverlib.KeyRing keyRing *gomatrixserverlib.KeyRing
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
joins sync.Map // joins currently in progress joins sync.Map // joins currently in progress
@ -34,8 +34,8 @@ type FederationInternalAPI struct {
func NewFederationInternalAPI( func NewFederationInternalAPI(
db storage.Database, cfg *config.FederationAPI, db storage.Database, cfg *config.FederationAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
federation *gomatrixserverlib.FederationClient, federation api.FederationClient,
statistics *statistics.Statistics, statistics *statistics.Statistics,
caches *caching.Caches, caches *caching.Caches,
queues *queue.OutgoingQueues, queues *queue.OutgoingQueues,

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
@ -235,6 +236,21 @@ func (r *FederationInternalAPI) performJoinUsingServer(
return fmt.Errorf("respSendJoin.Check: %w", err) return fmt.Errorf("respSendJoin.Check: %w", err)
} }
// We need to immediately update our list of joined hosts for this room now as we are technically
// joined. We must do this synchronously: we cannot rely on the roomserver output events as they
// will happen asyncly. If we don't update this table, you can end up with bad failure modes like
// joining a room, waiting for 200 OK then changing device keys and have those keys not be sent
// to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers")
// The events are trusted now as we performed auth checks above.
joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false))
if err != nil {
return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err)
}
logrus.WithField("hosts", joinedHosts).WithField("room", roomID).Info("Joined federated room with hosts")
if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil {
return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err)
}
// If we successfully performed a send_join above then the other // If we successfully performed a send_join above then the other
// server now thinks we're a part of the room. Send the newly // server now thinks we're a part of the room. Send the newly
// returned state to the roomserver to update our local view. // returned state to the roomserver to update our local view.
@ -650,7 +666,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided // FederatedAuthProvider is an auth chain provider which fetches events from the server provided
func federatedAuthProvider( func federatedAuthProvider(
ctx context.Context, federation *gomatrixserverlib.FederationClient, ctx context.Context, federation api.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName,
) gomatrixserverlib.AuthChainProvider { ) gomatrixserverlib.AuthChainProvider {
// A list of events that we have retried, if they were not included in // A list of events that we have retried, if they were not included in

View File

@ -21,6 +21,7 @@ import (
"sync" "sync"
"time" "time"
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/storage/shared"
@ -49,21 +50,21 @@ type destinationQueue struct {
db storage.Database db storage.Database
process *process.ProcessContext process *process.ProcessContext
signing *SigningInfo signing *SigningInfo
rsAPI api.RoomserverInternalAPI rsAPI api.FederationRoomserverAPI
client *gomatrixserverlib.FederationClient // federation client client fedapi.FederationClient // federation client
origin gomatrixserverlib.ServerName // origin of requests origin gomatrixserverlib.ServerName // origin of requests
destination gomatrixserverlib.ServerName // destination of requests destination gomatrixserverlib.ServerName // destination of requests
running atomic.Bool // is the queue worker running? running atomic.Bool // is the queue worker running?
backingOff atomic.Bool // true if we're backing off backingOff atomic.Bool // true if we're backing off
overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more
statistics *statistics.ServerStatistics // statistics about this remote server statistics *statistics.ServerStatistics // statistics about this remote server
transactionIDMutex sync.Mutex // protects transactionID transactionIDMutex sync.Mutex // protects transactionID
transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful
notify chan struct{} // interrupts idle wait pending PDUs/EDUs notify chan struct{} // interrupts idle wait pending PDUs/EDUs
pendingPDUs []*queuedPDU // PDUs waiting to be sent pendingPDUs []*queuedPDU // PDUs waiting to be sent
pendingEDUs []*queuedEDU // EDUs waiting to be sent pendingEDUs []*queuedEDU // EDUs waiting to be sent
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
interruptBackoff chan bool // interrupts backoff interruptBackoff chan bool // interrupts backoff
} }
// Send event adds the event to the pending queue for the destination. // Send event adds the event to the pending queue for the destination.

View File

@ -26,6 +26,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/storage/shared"
@ -39,9 +40,9 @@ type OutgoingQueues struct {
db storage.Database db storage.Database
process *process.ProcessContext process *process.ProcessContext
disabled bool disabled bool
rsAPI api.RoomserverInternalAPI rsAPI api.FederationRoomserverAPI
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
client *gomatrixserverlib.FederationClient client fedapi.FederationClient
statistics *statistics.Statistics statistics *statistics.Statistics
signing *SigningInfo signing *SigningInfo
queuesMutex sync.Mutex // protects the below queuesMutex sync.Mutex // protects the below
@ -85,8 +86,8 @@ func NewOutgoingQueues(
process *process.ProcessContext, process *process.ProcessContext,
disabled bool, disabled bool,
origin gomatrixserverlib.ServerName, origin gomatrixserverlib.ServerName,
client *gomatrixserverlib.FederationClient, client fedapi.FederationClient,
rsAPI api.RoomserverInternalAPI, rsAPI api.FederationRoomserverAPI,
statistics *statistics.Statistics, statistics *statistics.Statistics,
signing *SigningInfo, signing *SigningInfo,
) *OutgoingQueues { ) *OutgoingQueues {

View File

@ -30,7 +30,7 @@ import (
// RoomAliasToID converts the queried alias into a room ID and returns it // RoomAliasToID converts the queried alias into a room ID and returns it
func RoomAliasToID( func RoomAliasToID(
httpReq *http.Request, httpReq *http.Request,
federation *gomatrixserverlib.FederationClient, federation federationAPI.FederationClient,
cfg *config.FederationAPI, cfg *config.FederationAPI,
rsAPI roomserverAPI.FederationRoomserverAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
senderAPI federationAPI.FederationInternalAPI, senderAPI federationAPI.FederationInternalAPI,

View File

@ -54,7 +54,7 @@ func Setup(
rsAPI roomserverAPI.FederationRoomserverAPI, rsAPI roomserverAPI.FederationRoomserverAPI,
fsAPI *fedInternal.FederationInternalAPI, fsAPI *fedInternal.FederationInternalAPI,
keys gomatrixserverlib.JSONVerifier, keys gomatrixserverlib.JSONVerifier,
federation *gomatrixserverlib.FederationClient, federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI, userAPI userapi.FederationUserAPI,
keyAPI keyserverAPI.FederationKeyAPI, keyAPI keyserverAPI.FederationKeyAPI,
mscCfg *config.MSCs, mscCfg *config.MSCs,

View File

@ -85,7 +85,7 @@ func Send(
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
keyAPI keyapi.FederationKeyAPI, keyAPI keyapi.FederationKeyAPI,
keys gomatrixserverlib.JSONVerifier, keys gomatrixserverlib.JSONVerifier,
federation *gomatrixserverlib.FederationClient, federation federationAPI.FederationClient,
mu *internal.MutexByRoom, mu *internal.MutexByRoom,
servers federationAPI.ServersInRoomProvider, servers federationAPI.ServersInRoomProvider,
producer *producers.SyncAPIProducer, producer *producers.SyncAPIProducer,

View File

@ -8,8 +8,8 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )

View File

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -57,7 +58,7 @@ var (
func CreateInvitesFrom3PIDInvites( func CreateInvitesFrom3PIDInvites(
req *http.Request, rsAPI api.FederationRoomserverAPI, req *http.Request, rsAPI api.FederationRoomserverAPI,
cfg *config.FederationAPI, cfg *config.FederationAPI,
federation *gomatrixserverlib.FederationClient, federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI, userAPI userapi.FederationUserAPI,
) util.JSONResponse { ) util.JSONResponse {
var body invites var body invites
@ -107,7 +108,7 @@ func ExchangeThirdPartyInvite(
roomID string, roomID string,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
cfg *config.FederationAPI, cfg *config.FederationAPI,
federation *gomatrixserverlib.FederationClient, federation federationAPI.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
var builder gomatrixserverlib.EventBuilder var builder gomatrixserverlib.EventBuilder
if err := json.Unmarshal(request.Content(), &builder); err != nil { if err := json.Unmarshal(request.Content(), &builder); err != nil {
@ -165,7 +166,12 @@ func ExchangeThirdPartyInvite(
// Ask the requesting server to sign the newly created event so we know it // Ask the requesting server to sign the newly created event so we know it
// acknowledged it // acknowledged it
signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), event) inviteReq, err := gomatrixserverlib.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request")
return jsonerror.InternalServerError()
}
signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq)
if err != nil { if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -205,7 +211,7 @@ func ExchangeThirdPartyInvite(
func createInviteFrom3PIDInvite( func createInviteFrom3PIDInvite(
ctx context.Context, rsAPI api.FederationRoomserverAPI, ctx context.Context, rsAPI api.FederationRoomserverAPI,
cfg *config.FederationAPI, cfg *config.FederationAPI,
inv invite, federation *gomatrixserverlib.FederationClient, inv invite, federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI, userAPI userapi.FederationUserAPI,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID}
@ -335,7 +341,7 @@ func buildMembershipEvent(
// them responded with an error. // them responded with an error.
func sendToRemoteServer( func sendToRemoteServer(
ctx context.Context, inv invite, ctx context.Context, inv invite,
federation *gomatrixserverlib.FederationClient, _ *config.FederationAPI, federation federationAPI.FederationClient, _ *config.FederationAPI,
builder gomatrixserverlib.EventBuilder, builder gomatrixserverlib.EventBuilder,
) (err error) { ) (err error) {
remoteServers := make([]gomatrixserverlib.ServerName, 2) remoteServers := make([]gomatrixserverlib.ServerName, 2)

View File

@ -25,13 +25,12 @@ import (
type Database interface { type Database interface {
gomatrixserverlib.KeyDatabase gomatrixserverlib.KeyDatabase
UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error)
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given. // GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
PurgeRoomState(ctx context.Context, roomID string) error
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View File

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
) )
const joinedHostsSchema = ` const joinedHostsSchema = `
@ -111,6 +112,7 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
roomID, eventID string, roomID, eventID string,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) error { ) error {
logrus.Debugf("FederationJoinedHosts: INSERT %v %v %v", roomID, eventID, serverName)
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
return err return err
@ -119,6 +121,7 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
func (s *joinedHostsStatements) DeleteJoinedHosts( func (s *joinedHostsStatements) DeleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) error { ) error {
logrus.Debugf("FederationJoinedHosts: DELETE WITH EVENTS %v", eventIDs)
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs)) _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs))
return err return err
@ -127,6 +130,7 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom( func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) error { ) error {
logrus.Debugf("FederationJoinedHosts: DELETE ALL IN ROOM %v", roomID)
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
_, err := stmt.ExecContext(ctx, roomID) _, err := stmt.ExecContext(ctx, roomID)
return err return err
@ -207,6 +211,7 @@ func joinedHostsFromStmt(
ServerName: gomatrixserverlib.ServerName(serverName), ServerName: gomatrixserverlib.ServerName(serverName),
}) })
} }
logrus.Debugf("FederationJoinedHosts: SELECT %v => %+v", roomID, result)
return result, rows.Err() return result, rows.Err()
} }

View File

@ -63,11 +63,21 @@ func (r *Receipt) String() string {
// this isn't a duplicate message. // this isn't a duplicate message.
func (d *Database) UpdateRoom( func (d *Database) UpdateRoom(
ctx context.Context, ctx context.Context,
roomID, oldEventID, newEventID string, roomID string,
addHosts []types.JoinedHost, addHosts []types.JoinedHost,
removeHosts []string, removeHosts []string,
purgeRoomFirst bool,
) (joinedHosts []types.JoinedHost, err error) { ) (joinedHosts []types.JoinedHost, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if purgeRoomFirst {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err = d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
}
}
joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID) joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID)
if err != nil { if err != nil {
return err return err
@ -138,20 +148,6 @@ func (d *Database) StoreJSON(
}, nil }, nil
} }
func (d *Database) PurgeRoomState(
ctx context.Context, roomID string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
}
return nil
})
}
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)

View File

@ -20,7 +20,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/test"
) )
func TestEDUCache(t *testing.T) { func TestEDUCache(t *testing.T) {

View File

@ -1,158 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net/http"
"sync"
"time"
"github.com/matrix-org/gomatrixserverlib"
)
// Request contains the information necessary to issue a request and test its result
type Request struct {
Req *http.Request
WantedBody string
WantedStatusCode int
LastErr *LastRequestErr
}
// LastRequestErr is a synchronised error wrapper
// Useful for obtaining the last error from a set of requests
type LastRequestErr struct {
sync.Mutex
Err error
}
// Set sets the error
func (r *LastRequestErr) Set(err error) {
r.Lock()
defer r.Unlock()
r.Err = err
}
// Get gets the error
func (r *LastRequestErr) Get() error {
r.Lock()
defer r.Unlock()
return r.Err
}
// CanonicalJSONInput canonicalises a slice of JSON strings
// Useful for test input
func CanonicalJSONInput(jsonData []string) []string {
for i := range jsonData {
jsonBytes, err := gomatrixserverlib.CanonicalJSON([]byte(jsonData[i]))
if err != nil && err != io.EOF {
panic(err)
}
jsonData[i] = string(jsonBytes)
}
return jsonData
}
// Do issues a request and checks the status code and body of the response
func (r *Request) Do() (err error) {
client := &http.Client{
Timeout: 5 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
res, err := client.Do(r.Req)
if err != nil {
return err
}
defer (func() { err = res.Body.Close() })()
if res.StatusCode != r.WantedStatusCode {
return fmt.Errorf("incorrect status code. Expected: %d Got: %d", r.WantedStatusCode, res.StatusCode)
}
if r.WantedBody != "" {
resBytes, err := ioutil.ReadAll(res.Body)
if err != nil {
return err
}
jsonBytes, err := gomatrixserverlib.CanonicalJSON(resBytes)
if err != nil {
return err
}
if string(jsonBytes) != r.WantedBody {
return fmt.Errorf("returned wrong bytes. Expected:\n%s\n\nGot:\n%s", r.WantedBody, string(jsonBytes))
}
}
return nil
}
// DoUntilSuccess blocks and repeats the same request until the response returns the desired status code and body.
// It then closes the given channel and returns.
func (r *Request) DoUntilSuccess(done chan error) {
r.LastErr = &LastRequestErr{}
for {
if err := r.Do(); err != nil {
r.LastErr.Set(err)
time.Sleep(1 * time.Second) // don't tightloop
continue
}
close(done)
return
}
}
// Run repeatedly issues a request until success, error or a timeout is reached
func (r *Request) Run(label string, timeout time.Duration, serverCmdChan chan error) {
fmt.Printf("==TESTING== %v (timeout: %v)\n", label, timeout)
done := make(chan error, 1)
// We need to wait for the server to:
// - have connected to the database
// - have created the tables
// - be listening on the given port
go r.DoUntilSuccess(done)
// wait for one of:
// - the test to pass (done channel is closed)
// - the server to exit with an error (error sent on serverCmdChan)
// - our test timeout to expire
// We don't need to clean up since the main() function handles that in the event we panic
select {
case <-time.After(timeout):
fmt.Printf("==TESTING== %v TIMEOUT\n", label)
if reqErr := r.LastErr.Get(); reqErr != nil {
fmt.Println("Last /sync request error:")
fmt.Println(reqErr)
}
panic(fmt.Sprintf("%v server timed out", label))
case err := <-serverCmdChan:
if err != nil {
fmt.Println("=============================================================================================")
fmt.Printf("%v server failed to run. If failing with 'pq: password authentication failed for user' try:", label)
fmt.Println(" export PGHOST=/var/run/postgresql")
fmt.Println("=============================================================================================")
panic(err)
}
case <-done:
fmt.Printf("==TESTING== %v PASSED\n", label)
}
}

View File

@ -1,76 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"io"
"os/exec"
"path/filepath"
"strings"
)
// KafkaExecutor executes kafka scripts.
type KafkaExecutor struct {
// The location of Zookeeper. Typically this is `localhost:2181`.
ZookeeperURI string
// The directory where Kafka is installed to. Used to locate kafka scripts.
KafkaDirectory string
// The location of the Kafka logs. Typically this is `localhost:9092`.
KafkaURI string
// Where stdout and stderr should be written to. Typically this is `os.Stderr`.
OutputWriter io.Writer
}
// CreateTopic creates a new kafka topic. This is created with a single partition.
func (e *KafkaExecutor) CreateTopic(topic string) error {
cmd := exec.Command(
filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"),
"--create",
"--zookeeper", e.ZookeeperURI,
"--replication-factor", "1",
"--partitions", "1",
"--topic", topic,
)
cmd.Stdout = e.OutputWriter
cmd.Stderr = e.OutputWriter
return cmd.Run()
}
// WriteToTopic writes data to a kafka topic.
func (e *KafkaExecutor) WriteToTopic(topic string, data []string) error {
cmd := exec.Command(
filepath.Join(e.KafkaDirectory, "bin", "kafka-console-producer.sh"),
"--broker-list", e.KafkaURI,
"--topic", topic,
)
cmd.Stdout = e.OutputWriter
cmd.Stderr = e.OutputWriter
cmd.Stdin = strings.NewReader(strings.Join(data, "\n"))
return cmd.Run()
}
// DeleteTopic deletes a given kafka topic if it exists.
func (e *KafkaExecutor) DeleteTopic(topic string) error {
cmd := exec.Command(
filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"),
"--delete",
"--if-exists",
"--zookeeper", e.ZookeeperURI,
"--topic", topic,
)
cmd.Stderr = e.OutputWriter
cmd.Stdout = e.OutputWriter
return cmd.Run()
}

View File

@ -1,152 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package test
import (
"context"
"fmt"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"testing"
"github.com/matrix-org/dendrite/setup/config"
)
// Defaulting allows assignment of string variables with a fallback default value
// Useful for use with os.Getenv() for example
func Defaulting(value, defaultValue string) string {
if value == "" {
value = defaultValue
}
return value
}
// CreateDatabase creates a new database, dropping it first if it exists
func CreateDatabase(command string, args []string, database string) error {
cmd := exec.Command(command, args...)
cmd.Stdin = strings.NewReader(
fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", database, database),
)
// Send stdout and stderr to our stderr so that we see error messages from
// the psql process
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
return cmd.Run()
}
// CreateBackgroundCommand creates an executable command
// The Cmd being executed is returned. A channel is also returned,
// which will have any termination errors sent down it, followed immediately by the channel being closed.
func CreateBackgroundCommand(command string, args []string) (*exec.Cmd, chan error) {
cmd := exec.Command(command, args...)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stderr
if err := cmd.Start(); err != nil {
panic("failed to start server: " + err.Error())
}
cmdChan := make(chan error, 1)
go func() {
cmdChan <- cmd.Wait()
close(cmdChan)
}()
return cmd, cmdChan
}
// InitDatabase creates the database and config file needed for the server to run
func InitDatabase(postgresDatabase, postgresContainerName string, databases []string) {
if len(databases) > 0 {
var dbCmd string
var dbArgs []string
if postgresContainerName == "" {
dbCmd = "psql"
dbArgs = []string{postgresDatabase}
} else {
dbCmd = "docker"
dbArgs = []string{
"exec", "-i", postgresContainerName, "psql", "-U", "postgres", postgresDatabase,
}
}
for _, database := range databases {
if err := CreateDatabase(dbCmd, dbArgs, database); err != nil {
panic(err)
}
}
}
}
// StartProxy creates a reverse proxy
func StartProxy(bindAddr string, cfg *config.Dendrite) (*exec.Cmd, chan error) {
proxyArgs := []string{
"--bind-address", bindAddr,
"--sync-api-server-url", "http://" + string(cfg.SyncAPI.InternalAPI.Connect),
"--client-api-server-url", "http://" + string(cfg.ClientAPI.InternalAPI.Connect),
"--media-api-server-url", "http://" + string(cfg.MediaAPI.InternalAPI.Connect),
"--tls-cert", "server.crt",
"--tls-key", "server.key",
}
return CreateBackgroundCommand(
filepath.Join(filepath.Dir(os.Args[0]), "client-api-proxy"),
proxyArgs,
)
}
// ListenAndServe will listen on a random high-numbered port and attach the given router.
// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed.
func ListenAndServe(t *testing.T, router http.Handler, useTLS bool) (apiURL string, cancel func()) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen: %s", err)
}
port := listener.Addr().(*net.TCPAddr).Port
srv := http.Server{}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.Handler = router
var err error
if useTLS {
certFile := filepath.Join(os.TempDir(), "dendrite.cert")
keyFile := filepath.Join(os.TempDir(), "dendrite.key")
err = NewTLSKey(keyFile, certFile)
if err != nil {
t.Logf("failed to generate tls key/cert: %s", err)
return
}
err = srv.ServeTLS(listener, certFile, keyFile)
} else {
err = srv.Serve(listener)
}
if err != nil && err != http.ErrServerClosed {
t.Logf("Listen failed: %s", err)
}
}()
secure := ""
if useTLS {
secure = "s"
}
return fmt.Sprintf("http%s://localhost:%d", secure, port), func() {
_ = srv.Shutdown(context.Background())
wg.Wait()
}
}

View File

@ -84,7 +84,7 @@ type DeviceListUpdater struct {
db DeviceListUpdaterDatabase db DeviceListUpdaterDatabase
api DeviceListUpdaterAPI api DeviceListUpdaterAPI
producer KeyChangeProducer producer KeyChangeProducer
fedClient fedsenderapi.FederationClient fedClient fedsenderapi.KeyserverFederationAPI
workerChans []chan gomatrixserverlib.ServerName workerChans []chan gomatrixserverlib.ServerName
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
@ -127,7 +127,7 @@ type KeyChangeProducer interface {
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater( func NewDeviceListUpdater(
db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.FederationClient, numWorkers int, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
userIDToMutex: make(map[string]*sync.Mutex), userIDToMutex: make(map[string]*sync.Mutex),

View File

@ -37,7 +37,7 @@ import (
type KeyInternalAPI struct { type KeyInternalAPI struct {
DB storage.Database DB storage.Database
ThisServer gomatrixserverlib.ServerName ThisServer gomatrixserverlib.ServerName
FedClient fedsenderapi.FederationClient FedClient fedsenderapi.KeyserverFederationAPI
UserAPI userapi.KeyserverUserAPI UserAPI userapi.KeyserverUserAPI
Producer *producers.KeyChange Producer *producers.KeyChange
Updater *DeviceListUpdater Updater *DeviceListUpdater

View File

@ -37,7 +37,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers // NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
) api.KeyInternalAPI { ) api.KeyInternalAPI {
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)

View File

@ -183,6 +183,7 @@ type FederationRoomserverAPI interface {
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
// Query a given amount (or less) of events prior to a given set of events. // Query a given amount (or less) of events prior to a given set of events.

View File

@ -12,7 +12,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
) )
@ -22,7 +22,7 @@ var jc *nats.Conn
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
var b *base.BaseDendrite var b *base.BaseDendrite
b, js, jc = test.Base(nil) b, js, jc = testrig.Base(nil)
code := m.Run() code := m.Run()
b.ShutdownDendrite() b.ShutdownDendrite()
b.WaitForComponentsToFinish() b.WaitForComponentsToFinish()

View File

@ -19,8 +19,8 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )

View File

@ -39,7 +39,7 @@ func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, fun
} }
func Test_EventsTable(t *testing.T) { func Test_EventsTable(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {

View File

@ -38,7 +38,7 @@ func mustCreatePreviousEventsTable(t *testing.T, dbType test.DBType) (tab tables
func TestPreviousEventsTable(t *testing.T) { func TestPreviousEventsTable(t *testing.T) {
ctx := context.Background() ctx := context.Background()
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreatePreviousEventsTable(t, dbType) tab, close := mustCreatePreviousEventsTable(t, dbType)

View File

@ -38,7 +38,7 @@ func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Publ
func TestPublishedTable(t *testing.T) { func TestPublishedTable(t *testing.T) {
ctx := context.Background() ctx := context.Background()
alice := test.NewUser() alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreatePublishedTable(t, dbType) tab, close := mustCreatePublishedTable(t, dbType)

View File

@ -36,7 +36,7 @@ func mustCreateRoomAliasesTable(t *testing.T, dbType test.DBType) (tab tables.Ro
} }
func TestRoomAliasesTable(t *testing.T) { func TestRoomAliasesTable(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice)
ctx := context.Background() ctx := context.Background()

View File

@ -38,7 +38,7 @@ func mustCreateRoomsTable(t *testing.T, dbType test.DBType) (tab tables.Rooms, c
} }
func TestRoomsTable(t *testing.T) { func TestRoomsTable(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {

View File

@ -47,7 +47,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
func TestWriteEvents(t *testing.T) { func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
alice := test.NewUser() alice := test.NewUser(t)
r := test.NewRoom(t, alice) r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
defer close() defer close()
@ -60,7 +60,7 @@ func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser() alice := test.NewUser(t)
// dummy room to make sure SQL queries are filtering on room ID // dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
@ -163,7 +163,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser() alice := test.NewUser(t)
r := test.NewRoom(t, alice) r := test.NewRoom(t, alice)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)}) r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})

View File

@ -45,7 +45,7 @@ func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events,
func TestOutputRoomEventsTable(t *testing.T) { func TestOutputRoomEventsTable(t *testing.T) {
ctx := context.Background() ctx := context.Background()
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newOutputRoomEventsTable(t, dbType) tab, db, close := newOutputRoomEventsTable(t, dbType)

View File

@ -40,7 +40,7 @@ func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.D
func TestTopologyTable(t *testing.T) { func TestTopologyTable(t *testing.T) {
ctx := context.Background() ctx := context.Background()
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newTopologyTable(t, dbType) tab, db, close := newTopologyTable(t, dbType)

View File

@ -15,6 +15,7 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
@ -86,7 +87,7 @@ func TestSyncAPIAccessTokens(t *testing.T) {
} }
func testSyncAccessTokens(t *testing.T, dbType test.DBType) { func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
user := test.NewUser() user := test.NewUser(t)
room := test.NewRoom(t, user) room := test.NewRoom(t, user)
alice := userapi.Device{ alice := userapi.Device{
ID: "ALICEID", ID: "ALICEID",
@ -96,14 +97,14 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeUser,
} }
base, close := test.CreateBaseDendrite(t, dbType) base, close := testrig.CreateBaseDendrite(t, dbType)
defer close() defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
msgs := toNATSMsgs(t, base, room.Events()) msgs := toNATSMsgs(t, base, room.Events())
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
test.MustPublishMsgs(t, jsctx, msgs...) testrig.MustPublishMsgs(t, jsctx, msgs...)
testCases := []struct { testCases := []struct {
name string name string
@ -173,7 +174,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) {
} }
func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
user := test.NewUser() user := test.NewUser(t)
room := test.NewRoom(t, user) room := test.NewRoom(t, user)
alice := userapi.Device{ alice := userapi.Device{
ID: "ALICEID", ID: "ALICEID",
@ -183,7 +184,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeUser,
} }
base, close := test.CreateBaseDendrite(t, dbType) base, close := testrig.CreateBaseDendrite(t, dbType)
defer close() defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
@ -198,7 +199,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
sinceTokens := make([]string, len(msgs)) sinceTokens := make([]string, len(msgs))
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
for i, msg := range msgs { for i, msg := range msgs {
test.MustPublishMsgs(t, jsctx, msg) testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
w := httptest.NewRecorder() w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
@ -262,7 +263,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli
if ev.StateKey() != nil { if ev.StateKey() != nil {
addsStateIDs = append(addsStateIDs, ev.EventID()) addsStateIDs = append(addsStateIDs, ev.EventID())
} }
result[i] = test.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{ result[i] = testrig.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{
Type: rsapi.OutputTypeNewRoomEvent, Type: rsapi.OutputTypeNewRoomEvent,
NewRoomEvent: &rsapi.OutputNewRoomEvent{ NewRoomEvent: &rsapi.OutputNewRoomEvent{
Event: ev, Event: ev,

View File

@ -52,6 +52,24 @@ func WithUnsigned(unsigned interface{}) eventModifier {
} }
} }
func WithKeyID(keyID gomatrixserverlib.KeyID) eventModifier {
return func(e *eventMods) {
e.keyID = keyID
}
}
func WithPrivateKey(pkey ed25519.PrivateKey) eventModifier {
return func(e *eventMods) {
e.privKey = pkey
}
}
func WithOrigin(origin gomatrixserverlib.ServerName) eventModifier {
return func(e *eventMods) {
e.origin = origin
}
}
// Reverse a list of events // Reverse a list of events
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) out := make([]*gomatrixserverlib.HeaderedEvent, len(in))

View File

@ -2,10 +2,15 @@ package test
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"path/filepath"
"sync"
"testing" "testing"
) )
@ -43,3 +48,45 @@ func NewRequest(t *testing.T, method, path string, opts ...HTTPRequestOpt) *http
} }
return req return req
} }
// ListenAndServe will listen on a random high-numbered port and attach the given router.
// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed.
func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL string, cancel func()) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen: %s", err)
}
port := listener.Addr().(*net.TCPAddr).Port
srv := http.Server{}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.Handler = router
var err error
if withTLS {
certFile := filepath.Join(t.TempDir(), "dendrite.cert")
keyFile := filepath.Join(t.TempDir(), "dendrite.key")
err = NewTLSKey(keyFile, certFile)
if err != nil {
t.Errorf("failed to make TLS key: %s", err)
return
}
err = srv.ServeTLS(listener, certFile, keyFile)
} else {
err = srv.Serve(listener)
}
if err != nil && err != http.ErrServerClosed {
t.Logf("Listen failed: %s", err)
}
}()
s := ""
if withTLS {
s = "s"
}
return fmt.Sprintf("http%s://localhost:%d", s, port), func() {
_ = srv.Shutdown(context.Background())
wg.Wait()
}
}

View File

@ -25,103 +25,19 @@ import (
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"os" "os"
"path/filepath"
"strings" "strings"
"time" "time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"gopkg.in/yaml.v2"
) )
const ( const (
// ConfigFile is the name of the config file for a server.
ConfigFile = "dendrite.yaml"
// ServerKeyFile is the name of the file holding the matrix server private key. // ServerKeyFile is the name of the file holding the matrix server private key.
ServerKeyFile = "server_key.pem" ServerKeyFile = "server_key.pem"
// TLSCertFile is the name of the file holding the TLS certificate used for federation. // TLSCertFile is the name of the file holding the TLS certificate used for federation.
TLSCertFile = "tls_cert.pem" TLSCertFile = "tls_cert.pem"
// TLSKeyFile is the name of the file holding the TLS key used for federation. // TLSKeyFile is the name of the file holding the TLS key used for federation.
TLSKeyFile = "tls_key.pem" TLSKeyFile = "tls_key.pem"
// MediaDir is the name of the directory used to store media.
MediaDir = "media"
) )
// MakeConfig makes a config suitable for running integration tests.
// Generates new matrix and TLS keys for the server.
func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*config.Dendrite, int, error) {
var cfg config.Dendrite
cfg.Defaults(true)
port := startPort
assignAddress := func() config.HTTPAddress {
result := config.HTTPAddress(fmt.Sprintf("http://%s:%d", host, port))
port++
return result
}
serverKeyPath := filepath.Join(configDir, ServerKeyFile)
tlsCertPath := filepath.Join(configDir, TLSKeyFile)
tlsKeyPath := filepath.Join(configDir, TLSCertFile)
mediaBasePath := filepath.Join(configDir, MediaDir)
if err := NewMatrixKey(serverKeyPath); err != nil {
return nil, 0, err
}
if err := NewTLSKey(tlsKeyPath, tlsCertPath); err != nil {
return nil, 0, err
}
cfg.Version = config.Version
cfg.Global.ServerName = gomatrixserverlib.ServerName(assignAddress())
cfg.Global.PrivateKeyPath = config.Path(serverKeyPath)
cfg.MediaAPI.BasePath = config.Path(mediaBasePath)
cfg.Global.JetStream.Addresses = []string{kafkaURI}
// TODO: Use different databases for the different schemas.
// Using the same database for every schema currently works because
// the table names are globally unique. But we might not want to
// rely on that in the future.
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(database)
cfg.FederationAPI.Database.ConnectionString = config.DataSource(database)
cfg.KeyServer.Database.ConnectionString = config.DataSource(database)
cfg.MediaAPI.Database.ConnectionString = config.DataSource(database)
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
cfg.FederationAPI.InternalAPI.Listen = assignAddress()
cfg.KeyServer.InternalAPI.Listen = assignAddress()
cfg.MediaAPI.InternalAPI.Listen = assignAddress()
cfg.RoomServer.InternalAPI.Listen = assignAddress()
cfg.SyncAPI.InternalAPI.Listen = assignAddress()
cfg.UserAPI.InternalAPI.Listen = assignAddress()
cfg.AppServiceAPI.InternalAPI.Connect = cfg.AppServiceAPI.InternalAPI.Listen
cfg.FederationAPI.InternalAPI.Connect = cfg.FederationAPI.InternalAPI.Listen
cfg.KeyServer.InternalAPI.Connect = cfg.KeyServer.InternalAPI.Listen
cfg.MediaAPI.InternalAPI.Connect = cfg.MediaAPI.InternalAPI.Listen
cfg.RoomServer.InternalAPI.Connect = cfg.RoomServer.InternalAPI.Listen
cfg.SyncAPI.InternalAPI.Connect = cfg.SyncAPI.InternalAPI.Listen
cfg.UserAPI.InternalAPI.Connect = cfg.UserAPI.InternalAPI.Listen
return &cfg, port, nil
}
// WriteConfig writes the config file to the directory.
func WriteConfig(cfg *config.Dendrite, configDir string) error {
data, err := yaml.Marshal(cfg)
if err != nil {
return err
}
return ioutil.WriteFile(filepath.Join(configDir, ConfigFile), data, 0666)
}
// NewMatrixKey generates a new ed25519 matrix server key and writes it to a file. // NewMatrixKey generates a new ed25519 matrix server key and writes it to a file.
func NewMatrixKey(matrixKeyPath string) (err error) { func NewMatrixKey(matrixKeyPath string) (err error) {
var data [35]byte var data [35]byte

View File

@ -15,7 +15,6 @@
package test package test
import ( import (
"crypto/ed25519"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
@ -35,12 +34,6 @@ var (
PresetTrustedPrivateChat Preset = 3 PresetTrustedPrivateChat Preset = 3
roomIDCounter = int64(0) roomIDCounter = int64(0)
testKeyID = gomatrixserverlib.KeyID("ed25519:test")
testPrivateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
) )
type Room struct { type Room struct {
@ -49,22 +42,25 @@ type Room struct {
preset Preset preset Preset
creator *User creator *User
authEvents gomatrixserverlib.AuthEvents authEvents gomatrixserverlib.AuthEvents
events []*gomatrixserverlib.HeaderedEvent currentState map[string]*gomatrixserverlib.HeaderedEvent
events []*gomatrixserverlib.HeaderedEvent
} }
// Create a new test room. Automatically creates the initial create events. // Create a new test room. Automatically creates the initial create events.
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
t.Helper() t.Helper()
counter := atomic.AddInt64(&roomIDCounter, 1) counter := atomic.AddInt64(&roomIDCounter, 1)
if creator.srvName == "" {
// set defaults then let roomModifiers override t.Fatalf("NewRoom: creator doesn't belong to a server: %+v", *creator)
}
r := &Room{ r := &Room{
ID: fmt.Sprintf("!%d:localhost", counter), ID: fmt.Sprintf("!%d:%s", counter, creator.srvName),
creator: creator, creator: creator,
authEvents: gomatrixserverlib.NewAuthEvents(nil), authEvents: gomatrixserverlib.NewAuthEvents(nil),
preset: PresetPublicChat, preset: PresetPublicChat,
Version: gomatrixserverlib.RoomVersionV9, Version: gomatrixserverlib.RoomVersionV9,
currentState: make(map[string]*gomatrixserverlib.HeaderedEvent),
} }
for _, m := range modifiers { for _, m := range modifiers {
m(t, r) m(t, r)
@ -73,6 +69,24 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
return r return r
} }
func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []gomatrixserverlib.EventReference {
t.Helper()
a, err := needed.AuthEventReferences(&r.authEvents)
if err != nil {
t.Fatalf("MustGetAuthEvents: %v", err)
}
return a
}
func (r *Room) ForwardExtremities() []string {
if len(r.events) == 0 {
return nil
}
return []string{
r.events[len(r.events)-1].EventID(),
}
}
func (r *Room) insertCreateEvents(t *testing.T) { func (r *Room) insertCreateEvents(t *testing.T) {
t.Helper() t.Helper()
var joinRule gomatrixserverlib.JoinRuleContent var joinRule gomatrixserverlib.JoinRuleContent
@ -88,6 +102,7 @@ func (r *Room) insertCreateEvents(t *testing.T) {
joinRule.JoinRule = "public" joinRule.JoinRule = "public"
hisVis.HistoryVisibility = "shared" hisVis.HistoryVisibility = "shared"
} }
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{ r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
"creator": r.creator.ID, "creator": r.creator.ID,
"room_version": r.Version, "room_version": r.Version,
@ -112,16 +127,16 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
} }
if mod.privKey == nil { if mod.privKey == nil {
mod.privKey = testPrivateKey mod.privKey = creator.privKey
} }
if mod.keyID == "" { if mod.keyID == "" {
mod.keyID = testKeyID mod.keyID = creator.keyID
} }
if mod.originServerTS.IsZero() { if mod.originServerTS.IsZero() {
mod.originServerTS = time.Now() mod.originServerTS = time.Now()
} }
if mod.origin == "" { if mod.origin == "" {
mod.origin = gomatrixserverlib.ServerName("localhost") mod.origin = creator.srvName
} }
var unsigned gomatrixserverlib.RawJSON var unsigned gomatrixserverlib.RawJSON
@ -174,13 +189,14 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
// Add a new event to this room DAG. Not thread-safe. // Add a new event to this room DAG. Not thread-safe.
func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) { func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
t.Helper() t.Helper()
// Add the event to the list of auth events // Add the event to the list of auth/state events
r.events = append(r.events, he) r.events = append(r.events, he)
if he.StateKey() != nil { if he.StateKey() != nil {
err := r.authEvents.AddEvent(he.Unwrap()) err := r.authEvents.AddEvent(he.Unwrap())
if err != nil { if err != nil {
t.Fatalf("InsertEvent: failed to add event to auth events: %s", err) t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
} }
r.currentState[he.Type()+" "+*he.StateKey()] = he
} }
} }
@ -188,6 +204,16 @@ func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent {
return r.events return r.events
} }
func (r *Room) CurrentState() []*gomatrixserverlib.HeaderedEvent {
events := make([]*gomatrixserverlib.HeaderedEvent, len(r.currentState))
i := 0
for _, e := range r.currentState {
events[i] = e
i++
}
return events
}
func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent { func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
t.Helper() t.Helper()
he := r.CreateEvent(t, creator, eventType, content, mods...) he := r.CreateEvent(t, creator, eventType, content, mods...)

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package test package testrig
import ( import (
"errors" "errors"
@ -24,22 +24,23 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
) )
func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()) { func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) {
var cfg config.Dendrite var cfg config.Dendrite
cfg.Defaults(false) cfg.Defaults(false)
cfg.Global.JetStream.InMemory = true cfg.Global.JetStream.InMemory = true
switch dbType { switch dbType {
case DBTypePostgres: case test.DBTypePostgres:
cfg.Global.Defaults(true) // autogen a signing key cfg.Global.Defaults(true) // autogen a signing key
cfg.MediaAPI.Defaults(true) // autogen a media path cfg.MediaAPI.Defaults(true) // autogen a media path
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
// the file system event with InMemory=true :( // the file system event with InMemory=true :(
cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType)
connStr, close := PrepareDBConnectionString(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType)
cfg.Global.DatabaseOptions = config.DatabaseOptions{ cfg.Global.DatabaseOptions = config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
MaxOpenConnections: 10, MaxOpenConnections: 10,
@ -47,7 +48,7 @@ func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()
ConnMaxLifetimeSeconds: 60, ConnMaxLifetimeSeconds: 60,
} }
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close
case DBTypeSQLite: case test.DBTypeSQLite:
cfg.Defaults(true) // sets a sqlite db per component cfg.Defaults(true) // sets a sqlite db per component
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
// the file system event with InMemory=true :( // the file system event with InMemory=true :(

View File

@ -1,4 +1,4 @@
package test package testrig
import ( import (
"encoding/json" "encoding/json"

View File

@ -15,22 +15,64 @@
package test package test
import ( import (
"crypto/ed25519"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
"testing"
"github.com/matrix-org/gomatrixserverlib"
) )
var ( var (
userIDCounter = int64(0) userIDCounter = int64(0)
serverName = gomatrixserverlib.ServerName("test")
keyID = gomatrixserverlib.KeyID("ed25519:test")
privateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
// private keys that tests can use
PrivateKeyA = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 77,
})
PrivateKeyB = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 66,
})
) )
type User struct { type User struct {
ID string ID string
// key ID and private key of the server who has this user, if known.
keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey
srvName gomatrixserverlib.ServerName
} }
func NewUser() *User { type UserOpt func(*User)
counter := atomic.AddInt64(&userIDCounter, 1)
u := &User{ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt {
ID: fmt.Sprintf("@%d:localhost", counter), return func(u *User) {
u.keyID = keyID
u.privKey = privKey
u.srvName = srvName
} }
return u }
func NewUser(t *testing.T, opts ...UserOpt) *User {
counter := atomic.AddInt64(&userIDCounter, 1)
var u User
for _, opt := range opts {
opt(&u)
}
if u.keyID == "" || u.srvName == "" || u.privKey == nil {
t.Logf("NewUser: missing signing server credentials; using default.")
WithSigningServer(serverName, keyID, privateKey)(&u)
}
u.ID = fmt.Sprintf("@%d:%s", counter, u.srvName)
t.Logf("NewUser: created user %s", u.ID)
return &u
} }

View File

@ -43,7 +43,7 @@ func Test_AccountData(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser() alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
@ -74,7 +74,7 @@ func Test_Accounts(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser() alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
@ -128,7 +128,7 @@ func Test_Accounts(t *testing.T) {
} }
func Test_Devices(t *testing.T) { func Test_Devices(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
deviceID := util.RandomString(8) deviceID := util.RandomString(8)
@ -212,7 +212,7 @@ func Test_Devices(t *testing.T) {
} }
func Test_KeyBackup(t *testing.T) { func Test_KeyBackup(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -291,7 +291,7 @@ func Test_KeyBackup(t *testing.T) {
} }
func Test_LoginToken(t *testing.T) { func Test_LoginToken(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
@ -321,7 +321,7 @@ func Test_LoginToken(t *testing.T) {
} }
func Test_OpenID(t *testing.T) { func Test_OpenID(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
token := util.RandomString(24) token := util.RandomString(24)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -341,7 +341,7 @@ func Test_OpenID(t *testing.T) {
} }
func Test_Profile(t *testing.T) { func Test_Profile(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
@ -379,7 +379,7 @@ func Test_Profile(t *testing.T) {
} }
func Test_Pusher(t *testing.T) { func Test_Pusher(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
@ -430,7 +430,7 @@ func Test_Pusher(t *testing.T) {
} }
func Test_ThreePID(t *testing.T) { func Test_ThreePID(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
@ -467,7 +467,7 @@ func Test_ThreePID(t *testing.T) {
} }
func Test_Notification(t *testing.T) { func Test_Notification(t *testing.T) {
alice := test.NewUser() alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)

View File

@ -24,7 +24,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
internalTest "github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/inthttp"
@ -135,7 +134,7 @@ func TestQueryProfile(t *testing.T) {
t.Run("HTTP API", func(t *testing.T) { t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
userapi.AddInternalRoutes(router, userAPI) userapi.AddInternalRoutes(router, userAPI)
apiURL, cancel := internalTest.ListenAndServe(t, router, false) apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel() defer cancel()
httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{})
if err != nil { if err != nil {