mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-13 23:31:34 +00:00
Add context.Context to the federation client (#225)
* Add context.Context to the federation client * gb vendor update github.com/matrix-org/gomatrixserverlib
This commit is contained in:
parent
086683459f
commit
029e71828a
@ -66,7 +66,7 @@ func DirectoryRoom(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
resp, err = federation.LookupRoomAlias(domain, roomAlias)
|
||||
resp, err = federation.LookupRoomAlias(req.Context(), domain, roomAlias)
|
||||
if err != nil {
|
||||
switch x := err.(type) {
|
||||
case gomatrix.HTTPError:
|
||||
|
@ -136,7 +136,7 @@ func (r joinRoomReq) joinRoomByAlias(roomAlias string) util.JSONResponse {
|
||||
func (r joinRoomReq) joinRoomByRemoteAlias(
|
||||
domain gomatrixserverlib.ServerName, roomAlias string,
|
||||
) util.JSONResponse {
|
||||
resp, err := r.federation.LookupRoomAlias(domain, roomAlias)
|
||||
resp, err := r.federation.LookupRoomAlias(r.req.Context(), domain, roomAlias)
|
||||
if err != nil {
|
||||
switch x := err.(type) {
|
||||
case gomatrix.HTTPError:
|
||||
@ -226,7 +226,7 @@ func (r joinRoomReq) joinRoomUsingServers(
|
||||
// server was invalid this returns an error.
|
||||
// Otherwise this returns a JSONResponse.
|
||||
func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib.ServerName) (*util.JSONResponse, error) {
|
||||
respMakeJoin, err := r.federation.MakeJoin(server, roomID, r.userID)
|
||||
respMakeJoin, err := r.federation.MakeJoin(r.req.Context(), server, roomID, r.userID)
|
||||
if err != nil {
|
||||
// TODO: Check if the user was not allowed to join the room.
|
||||
return nil, err
|
||||
@ -246,12 +246,12 @@ func (r joinRoomReq) joinRoomUsingServer(roomID string, server gomatrixserverlib
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
respSendJoin, err := r.federation.SendJoin(server, event)
|
||||
respSendJoin, err := r.federation.SendJoin(r.req.Context(), server, event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = respSendJoin.Check(r.keyRing, event); err != nil {
|
||||
if err = respSendJoin.Check(r.req.Context(), r.keyRing, event); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,9 @@
|
||||
package keydb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
@ -41,6 +43,7 @@ func NewDatabase(dataSourceName string) (*Database, error) {
|
||||
|
||||
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
||||
func (d *Database) FetchKeys(
|
||||
ctx context.Context,
|
||||
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
||||
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
|
||||
return d.statements.bulkSelectServerKeys(requests)
|
||||
@ -48,6 +51,7 @@ func (d *Database) FetchKeys(
|
||||
|
||||
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
||||
func (d *Database) StoreKeys(
|
||||
ctx context.Context,
|
||||
keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys,
|
||||
) error {
|
||||
// TODO: Inserting all the keys within a single transaction may
|
||||
|
@ -76,7 +76,7 @@ func Invite(
|
||||
Message: event.Redact().JSON(),
|
||||
AtTS: event.OriginServerTS(),
|
||||
}}
|
||||
verifyResults, err := keys.VerifyJSONs(verifyRequests)
|
||||
verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests)
|
||||
if err != nil {
|
||||
return httputil.LogThenError(httpReq, err)
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package writers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -41,6 +42,7 @@ func Send(
|
||||
) util.JSONResponse {
|
||||
|
||||
t := txnReq{
|
||||
context: httpReq.Context(),
|
||||
query: query,
|
||||
producer: producer,
|
||||
keys: keys,
|
||||
@ -70,6 +72,7 @@ func Send(
|
||||
|
||||
type txnReq struct {
|
||||
gomatrixserverlib.Transaction
|
||||
context context.Context
|
||||
query api.RoomserverQueryAPI
|
||||
producer *producers.RoomserverProducer
|
||||
keys gomatrixserverlib.KeyRing
|
||||
@ -78,7 +81,7 @@ type txnReq struct {
|
||||
|
||||
func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) {
|
||||
// Check the event signatures
|
||||
if err := gomatrixserverlib.VerifyEventSignatures(t.PDUs, t.keys); err != nil {
|
||||
if err := gomatrixserverlib.VerifyEventSignatures(t.context, t.PDUs, t.keys); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -110,7 +113,9 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) {
|
||||
// our server so we should bail processing the transaction entirely.
|
||||
return nil, err
|
||||
}
|
||||
results[e.EventID()] = gomatrixserverlib.PDUResult{err.Error()}
|
||||
results[e.EventID()] = gomatrixserverlib.PDUResult{
|
||||
Error: err.Error(),
|
||||
}
|
||||
} else {
|
||||
results[e.EventID()] = gomatrixserverlib.PDUResult{}
|
||||
}
|
||||
@ -197,12 +202,12 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event) error {
|
||||
// need to fallback to /state.
|
||||
// TODO: Attempt to fill in the gap using /get_missing_events
|
||||
// TODO: Attempt to fetch the state using /state_ids and /events
|
||||
state, err := t.federation.LookupState(t.Origin, e.RoomID(), e.EventID())
|
||||
state, err := t.federation.LookupState(t.context, t.Origin, e.RoomID(), e.EventID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check that the returned state is valid.
|
||||
if err := state.Check(t.keys); err != nil {
|
||||
if err := state.Check(t.context, t.keys); err != nil {
|
||||
return err
|
||||
}
|
||||
// Check that the event is allowed by the state.
|
||||
|
@ -15,6 +15,7 @@
|
||||
package writers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -63,7 +64,9 @@ func CreateInvitesFrom3PIDInvites(
|
||||
|
||||
evs := []gomatrixserverlib.Event{}
|
||||
for _, inv := range body.Invites {
|
||||
event, err := createInviteFrom3PIDInvite(queryAPI, cfg, inv, federation)
|
||||
event, err := createInviteFrom3PIDInvite(
|
||||
req.Context(), queryAPI, cfg, inv, federation,
|
||||
)
|
||||
if err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
@ -139,7 +142,7 @@ func ExchangeThirdPartyInvite(
|
||||
|
||||
// Ask the requesting server to sign the newly created event so we know it
|
||||
// acknowledged it
|
||||
signedEvent, err := federation.SendInvite(request.Origin(), *event)
|
||||
signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), *event)
|
||||
if err != nil {
|
||||
return httputil.LogThenError(httpReq, err)
|
||||
}
|
||||
@ -160,8 +163,8 @@ func ExchangeThirdPartyInvite(
|
||||
// Returns an error if there was a problem building the event or fetching the
|
||||
// necessary data to do so.
|
||||
func createInviteFrom3PIDInvite(
|
||||
queryAPI api.RoomserverQueryAPI, cfg config.Dendrite, inv invite,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
ctx context.Context, queryAPI api.RoomserverQueryAPI, cfg config.Dendrite,
|
||||
inv invite, federation *gomatrixserverlib.FederationClient,
|
||||
) (*gomatrixserverlib.Event, error) {
|
||||
// Build the event
|
||||
builder := &gomatrixserverlib.EventBuilder{
|
||||
@ -185,7 +188,10 @@ func createInviteFrom3PIDInvite(
|
||||
|
||||
event, err := buildMembershipEvent(builder, queryAPI, cfg)
|
||||
if err == errNotInRoom {
|
||||
return nil, sendToRemoteServer(inv, federation, cfg, *builder)
|
||||
return nil, sendToRemoteServer(ctx, inv, federation, cfg, *builder)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return event, nil
|
||||
@ -253,7 +259,8 @@ func buildMembershipEvent(
|
||||
// Returns an error if it couldn't get the server names to reach or if all of
|
||||
// them responded with an error.
|
||||
func sendToRemoteServer(
|
||||
inv invite, federation *gomatrixserverlib.FederationClient, cfg config.Dendrite,
|
||||
ctx context.Context, inv invite,
|
||||
federation *gomatrixserverlib.FederationClient, cfg config.Dendrite,
|
||||
builder gomatrixserverlib.EventBuilder,
|
||||
) (err error) {
|
||||
remoteServers := make([]gomatrixserverlib.ServerName, 2)
|
||||
@ -269,7 +276,7 @@ func sendToRemoteServer(
|
||||
}
|
||||
|
||||
for _, server := range remoteServers {
|
||||
err = federation.ExchangeThirdPartyInvite(server, builder)
|
||||
err = federation.ExchangeThirdPartyInvite(ctx, server, builder)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
@ -65,7 +66,7 @@ func (oq *destinationQueue) backgroundSend() {
|
||||
// TODO: handle retries.
|
||||
// TODO: blacklist uncooperative servers.
|
||||
|
||||
_, err := oq.client.SendTransaction(*t)
|
||||
_, err := oq.client.SendTransaction(context.TODO(), *t)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"destination": oq.destination,
|
||||
|
2
vendor/manifest
vendored
2
vendor/manifest
vendored
@ -116,7 +116,7 @@
|
||||
{
|
||||
"importpath": "github.com/matrix-org/gomatrixserverlib",
|
||||
"repository": "https://github.com/matrix-org/gomatrixserverlib",
|
||||
"revision": "790f02e8f465552dab4317ffe7ca047ccb594cbf",
|
||||
"revision": "ec5a0d21b03ed4d3bd955ecc9f7a69936f64391e",
|
||||
"branch": "master"
|
||||
},
|
||||
{
|
||||
|
@ -17,6 +17,7 @@ package gomatrixserverlib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -103,7 +104,9 @@ func (f *federationTripper) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
|
||||
// LookupUserInfo gets information about a user from a given matrix homeserver
|
||||
// using a bearer access token.
|
||||
func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserInfo, err error) {
|
||||
func (fc *Client) LookupUserInfo(
|
||||
ctx context.Context, matrixServer ServerName, token string,
|
||||
) (u UserInfo, err error) {
|
||||
url := url.URL{
|
||||
Scheme: "matrix",
|
||||
Host: string(matrixServer),
|
||||
@ -111,8 +114,13 @@ func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserI
|
||||
RawQuery: url.Values{"access_token": []string{token}}.Encode(),
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", url.String(), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var response *http.Response
|
||||
response, err = fc.client.Get(url.String())
|
||||
response, err = fc.client.Do(req.WithContext(ctx))
|
||||
if response != nil {
|
||||
defer response.Body.Close() // nolint: errcheck
|
||||
}
|
||||
@ -153,7 +161,7 @@ func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserI
|
||||
// copy of the keys.
|
||||
// Returns the keys or an error if there was a problem talking to the server.
|
||||
func (fc *Client) LookupServerKeys( // nolint: gocyclo
|
||||
matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
|
||||
ctx context.Context, matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
|
||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||
url := url.URL{
|
||||
Scheme: "matrix",
|
||||
@ -183,7 +191,13 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := fc.client.Post(url.String(), "application/json", bytes.NewBuffer(requestBytes))
|
||||
req, err := http.NewRequest("POST", url.String(), bytes.NewBuffer(requestBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
response, err := fc.client.Do(req.WithContext(ctx))
|
||||
if response != nil {
|
||||
defer response.Body.Close() // nolint: errcheck
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ package gomatrixserverlib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -188,7 +189,7 @@ func verifyEventSignature(signingName string, keyID KeyID, publicKey ed25519.Pub
|
||||
|
||||
// VerifyEventSignatures checks that each event in a list of events has valid
|
||||
// signatures from the server that sent it.
|
||||
func VerifyEventSignatures(events []Event, keyRing KeyRing) error { // nolint: gocyclo
|
||||
func VerifyEventSignatures(ctx context.Context, events []Event, keyRing KeyRing) error { // nolint: gocyclo
|
||||
var toVerify []VerifyJSONRequest
|
||||
for _, event := range events {
|
||||
redactedJSON, err := redactEvent(event.eventJSON)
|
||||
@ -222,7 +223,7 @@ func VerifyEventSignatures(events []Event, keyRing KeyRing) error { // nolint: g
|
||||
}
|
||||
}
|
||||
|
||||
results, err := keyRing.VerifyJSONs(toVerify)
|
||||
results, err := keyRing.VerifyJSONs(ctx, toVerify)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package gomatrixserverlib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@ -31,7 +32,7 @@ func NewFederationClient(
|
||||
}
|
||||
}
|
||||
|
||||
func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{}) error {
|
||||
func (ac *FederationClient) doRequest(ctx context.Context, r FederationRequest, resBody interface{}) error {
|
||||
if err := r.Sign(ac.serverName, ac.serverKeyID, ac.serverPrivateKey); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -41,7 +42,7 @@ func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{})
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := ac.client.Do(req)
|
||||
res, err := ac.client.Do(req.WithContext(ctx))
|
||||
if res != nil {
|
||||
defer res.Body.Close() // nolint: errcheck
|
||||
}
|
||||
@ -87,13 +88,15 @@ func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{})
|
||||
var federationPathPrefix = "/_matrix/federation/v1"
|
||||
|
||||
// SendTransaction sends a transaction
|
||||
func (ac *FederationClient) SendTransaction(t Transaction) (res RespSend, err error) {
|
||||
func (ac *FederationClient) SendTransaction(
|
||||
ctx context.Context, t Transaction,
|
||||
) (res RespSend, err error) {
|
||||
path := federationPathPrefix + "/send/" + string(t.TransactionID) + "/"
|
||||
req := NewFederationRequest("PUT", t.Destination, path)
|
||||
if err = req.SetContent(t); err != nil {
|
||||
return
|
||||
}
|
||||
err = ac.doRequest(req, &res)
|
||||
err = ac.doRequest(ctx, req, &res)
|
||||
return
|
||||
}
|
||||
|
||||
@ -106,12 +109,14 @@ func (ac *FederationClient) SendTransaction(t Transaction) (res RespSend, err er
|
||||
// If this successfully returns an acceptable event we will sign it with our
|
||||
// server's key and pass it to SendJoin.
|
||||
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
|
||||
func (ac *FederationClient) MakeJoin(s ServerName, roomID, userID string) (res RespMakeJoin, err error) {
|
||||
func (ac *FederationClient) MakeJoin(
|
||||
ctx context.Context, s ServerName, roomID, userID string,
|
||||
) (res RespMakeJoin, err error) {
|
||||
path := federationPathPrefix + "/make_join/" +
|
||||
url.PathEscape(roomID) + "/" +
|
||||
url.PathEscape(userID)
|
||||
req := NewFederationRequest("GET", s, path)
|
||||
err = ac.doRequest(req, &res)
|
||||
err = ac.doRequest(ctx, req, &res)
|
||||
return
|
||||
}
|
||||
|
||||
@ -119,7 +124,9 @@ func (ac *FederationClient) MakeJoin(s ServerName, roomID, userID string) (res R
|
||||
// remote matrix server.
|
||||
// This is used to join a room the local server isn't a member of.
|
||||
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
|
||||
func (ac *FederationClient) SendJoin(s ServerName, event Event) (res RespSendJoin, err error) {
|
||||
func (ac *FederationClient) SendJoin(
|
||||
ctx context.Context, s ServerName, event Event,
|
||||
) (res RespSendJoin, err error) {
|
||||
path := federationPathPrefix + "/send_join/" +
|
||||
url.PathEscape(event.RoomID()) + "/" +
|
||||
url.PathEscape(event.EventID())
|
||||
@ -127,13 +134,15 @@ func (ac *FederationClient) SendJoin(s ServerName, event Event) (res RespSendJoi
|
||||
if err = req.SetContent(event); err != nil {
|
||||
return
|
||||
}
|
||||
err = ac.doRequest(req, &res)
|
||||
err = ac.doRequest(ctx, req, &res)
|
||||
return
|
||||
}
|
||||
|
||||
// SendInvite sends an invite m.room.member event to an invited server to be
|
||||
// signed by it. This is used to invite a user that is not on the local server.
|
||||
func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvite, err error) {
|
||||
func (ac *FederationClient) SendInvite(
|
||||
ctx context.Context, s ServerName, event Event,
|
||||
) (res RespInvite, err error) {
|
||||
path := federationPathPrefix + "/invite/" +
|
||||
url.PathEscape(event.RoomID()) + "/" +
|
||||
url.PathEscape(event.EventID())
|
||||
@ -141,7 +150,7 @@ func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvit
|
||||
if err = req.SetContent(event); err != nil {
|
||||
return
|
||||
}
|
||||
err = ac.doRequest(req, &res)
|
||||
err = ac.doRequest(ctx, req, &res)
|
||||
return
|
||||
}
|
||||
|
||||
@ -150,38 +159,44 @@ func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvit
|
||||
// server.
|
||||
// This is used to exchange a m.room.third_party_invite event for a m.room.member
|
||||
// one in a room the local server isn't a member of.
|
||||
func (ac *FederationClient) ExchangeThirdPartyInvite(s ServerName, builder EventBuilder) (err error) {
|
||||
func (ac *FederationClient) ExchangeThirdPartyInvite(
|
||||
ctx context.Context, s ServerName, builder EventBuilder,
|
||||
) (err error) {
|
||||
path := federationPathPrefix + "/exchange_third_party_invite/" +
|
||||
url.PathEscape(builder.RoomID)
|
||||
req := NewFederationRequest("PUT", s, path)
|
||||
if err = req.SetContent(builder); err != nil {
|
||||
return
|
||||
}
|
||||
err = ac.doRequest(req, nil)
|
||||
err = ac.doRequest(ctx, req, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// LookupState retrieves the room state for a room at an event from a
|
||||
// remote matrix server as full matrix events.
|
||||
func (ac *FederationClient) LookupState(s ServerName, roomID, eventID string) (res RespState, err error) {
|
||||
func (ac *FederationClient) LookupState(
|
||||
ctx context.Context, s ServerName, roomID, eventID string,
|
||||
) (res RespState, err error) {
|
||||
path := federationPathPrefix + "/state/" +
|
||||
url.PathEscape(roomID) +
|
||||
"/?event_id=" +
|
||||
url.QueryEscape(eventID)
|
||||
req := NewFederationRequest("GET", s, path)
|
||||
err = ac.doRequest(req, &res)
|
||||
err = ac.doRequest(ctx, req, &res)
|
||||
return
|
||||
}
|
||||
|
||||
// LookupStateIDs retrieves the room state for a room at an event from a
|
||||
// remote matrix server as lists of matrix event IDs.
|
||||
func (ac *FederationClient) LookupStateIDs(s ServerName, roomID, eventID string) (res RespStateIDs, err error) {
|
||||
func (ac *FederationClient) LookupStateIDs(
|
||||
ctx context.Context, s ServerName, roomID, eventID string,
|
||||
) (res RespStateIDs, err error) {
|
||||
path := federationPathPrefix + "/state_ids/" +
|
||||
url.PathEscape(roomID) +
|
||||
"/?event_id=" +
|
||||
url.QueryEscape(eventID)
|
||||
req := NewFederationRequest("GET", s, path)
|
||||
err = ac.doRequest(req, &res)
|
||||
err = ac.doRequest(ctx, req, &res)
|
||||
return
|
||||
}
|
||||
|
||||
@ -190,10 +205,12 @@ func (ac *FederationClient) LookupStateIDs(s ServerName, roomID, eventID string)
|
||||
// being looked up on.
|
||||
// If the room alias doesn't exist on the remote server then a 404 gomatrix.HTTPError
|
||||
// is returned.
|
||||
func (ac *FederationClient) LookupRoomAlias(s ServerName, roomAlias string) (res RespDirectory, err error) {
|
||||
func (ac *FederationClient) LookupRoomAlias(
|
||||
ctx context.Context, s ServerName, roomAlias string,
|
||||
) (res RespDirectory, err error) {
|
||||
path := federationPathPrefix + "/query/directory?room_alias=" +
|
||||
url.QueryEscape(roomAlias)
|
||||
req := NewFederationRequest("GET", s, path)
|
||||
err = ac.doRequest(req, &res)
|
||||
err = ac.doRequest(ctx, req, &res)
|
||||
return
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package gomatrixserverlib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
@ -107,7 +108,7 @@ func (r RespState) Events() ([]Event, error) {
|
||||
}
|
||||
|
||||
// Check that a response to /state is valid.
|
||||
func (r RespState) Check(keyRing KeyRing) error {
|
||||
func (r RespState) Check(ctx context.Context, keyRing KeyRing) error {
|
||||
var allEvents []Event
|
||||
for _, event := range r.AuthEvents {
|
||||
if event.StateKey() == nil {
|
||||
@ -133,7 +134,7 @@ func (r RespState) Check(keyRing KeyRing) error {
|
||||
}
|
||||
|
||||
// Check if the events pass signature checks.
|
||||
if err := VerifyEventSignatures(allEvents, keyRing); err != nil {
|
||||
if err := VerifyEventSignatures(ctx, allEvents, keyRing); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -213,11 +214,11 @@ type respSendJoinFields struct {
|
||||
// Check that a response to /send_join is valid.
|
||||
// This checks that it would be valid as a response to /state
|
||||
// This also checks that the join event is allowed by the state.
|
||||
func (r RespSendJoin) Check(keyRing KeyRing, joinEvent Event) error {
|
||||
func (r RespSendJoin) Check(ctx context.Context, keyRing KeyRing, joinEvent Event) error {
|
||||
// First check that the state is valid.
|
||||
// The response to /send_join has the same data as a response to /state
|
||||
// and the checks for a response to /state also apply.
|
||||
if err := RespState(r).Check(keyRing); err != nil {
|
||||
if err := RespState(r).Check(ctx, keyRing); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -6,13 +6,13 @@ echo "Installing lint search engine..."
|
||||
go get github.com/alecthomas/gometalinter/
|
||||
gometalinter --config=linter.json --install --update
|
||||
|
||||
echo "Testing..."
|
||||
go test
|
||||
|
||||
echo "Looking for lint..."
|
||||
gometalinter --config=linter.json
|
||||
|
||||
echo "Double checking spelling..."
|
||||
misspell -error src *.md
|
||||
|
||||
echo "Testing..."
|
||||
go test
|
||||
|
||||
echo "Done!"
|
||||
|
@ -1,6 +1,7 @@
|
||||
package gomatrixserverlib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@ -26,7 +27,7 @@ type KeyFetcher interface {
|
||||
// The result may have fewer (server name, key ID) pairs than were in the request.
|
||||
// The result may have more (server name, key ID) pairs than were in the request.
|
||||
// Returns an error if there was a problem fetching the keys.
|
||||
FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error)
|
||||
FetchKeys(ctx context.Context, requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error)
|
||||
}
|
||||
|
||||
// A KeyDatabase is a store for caching public keys.
|
||||
@ -39,7 +40,7 @@ type KeyDatabase interface {
|
||||
// to a concurrent FetchKeys(). This is acceptable since the database is
|
||||
// only used as a cache for the keys, so if a FetchKeys() races with a
|
||||
// StoreKeys() and some of the keys are missing they will be just be refetched.
|
||||
StoreKeys(map[PublicKeyRequest]ServerKeys) error
|
||||
StoreKeys(ctx context.Context, results map[PublicKeyRequest]ServerKeys) error
|
||||
}
|
||||
|
||||
// A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages.
|
||||
@ -73,7 +74,7 @@ type VerifyJSONResult struct {
|
||||
// The caller should check the Result field for each entry to see if it was valid.
|
||||
// Returns an error if there was a problem talking to the database or one of the other methods
|
||||
// of fetching the public keys.
|
||||
func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
|
||||
func (k *KeyRing) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
|
||||
results := make([]VerifyJSONResult, len(requests))
|
||||
keyIDs := make([][]KeyID, len(requests))
|
||||
|
||||
@ -109,7 +110,7 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
|
||||
// This will happen if all the objects are missing supported signatures.
|
||||
return results, nil
|
||||
}
|
||||
keysFromDatabase, err := k.KeyDatabase.FetchKeys(keyRequests)
|
||||
keysFromDatabase, err := k.KeyDatabase.FetchKeys(ctx, keyRequests)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -124,14 +125,14 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
|
||||
}
|
||||
// TODO: Coalesce in-flight requests for the same keys.
|
||||
// Otherwise we risk spamming the servers we query the keys from.
|
||||
keysFetched, err := k.KeyFetchers[i].FetchKeys(keyRequests)
|
||||
keysFetched, err := k.KeyFetchers[i].FetchKeys(ctx, keyRequests)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
k.checkUsingKeys(requests, results, keyIDs, keysFetched)
|
||||
|
||||
// Add the keys to the database so that we won't need to fetch them again.
|
||||
if err := k.KeyDatabase.StoreKeys(keysFetched); err != nil {
|
||||
if err := k.KeyDatabase.StoreKeys(ctx, keysFetched); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@ -143,7 +144,9 @@ func (k *KeyRing) isAlgorithmSupported(keyID KeyID) bool {
|
||||
return strings.HasPrefix(string(keyID), "ed25519:")
|
||||
}
|
||||
|
||||
func (k *KeyRing) publicKeyRequests(requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID) map[PublicKeyRequest]Timestamp {
|
||||
func (k *KeyRing) publicKeyRequests(
|
||||
requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
|
||||
) map[PublicKeyRequest]Timestamp {
|
||||
keyRequests := map[PublicKeyRequest]Timestamp{}
|
||||
for i := range requests {
|
||||
if results[i].Error == nil {
|
||||
@ -218,8 +221,10 @@ type PerspectiveKeyFetcher struct {
|
||||
}
|
||||
|
||||
// FetchKeys implements KeyFetcher
|
||||
func (p *PerspectiveKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
|
||||
results, err := p.Client.LookupServerKeys(p.PerspectiveServerName, requests)
|
||||
func (p *PerspectiveKeyFetcher) FetchKeys(
|
||||
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||
results, err := p.Client.LookupServerKeys(ctx, p.PerspectiveServerName, requests)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -269,7 +274,9 @@ type DirectKeyFetcher struct {
|
||||
}
|
||||
|
||||
// FetchKeys implements KeyFetcher
|
||||
func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
|
||||
func (d *DirectKeyFetcher) FetchKeys(
|
||||
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||
byServer := map[ServerName]map[PublicKeyRequest]Timestamp{}
|
||||
for req, ts := range requests {
|
||||
server := byServer[req.ServerName]
|
||||
@ -283,7 +290,7 @@ func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
|
||||
results := map[PublicKeyRequest]ServerKeys{}
|
||||
for server, reqs := range byServer {
|
||||
// TODO: make these requests in parallel
|
||||
serverResults, err := d.fetchKeysForServer(server, reqs)
|
||||
serverResults, err := d.fetchKeysForServer(ctx, server, reqs)
|
||||
if err != nil {
|
||||
// TODO: Should we actually be erroring here? or should we just drop those keys from the result map?
|
||||
return nil, err
|
||||
@ -296,9 +303,9 @@ func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
|
||||
}
|
||||
|
||||
func (d *DirectKeyFetcher) fetchKeysForServer(
|
||||
serverName ServerName, requests map[PublicKeyRequest]Timestamp,
|
||||
ctx context.Context, serverName ServerName, requests map[PublicKeyRequest]Timestamp,
|
||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||
results, err := d.Client.LookupServerKeys(serverName, requests)
|
||||
results, err := d.Client.LookupServerKeys(ctx, serverName, requests)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package gomatrixserverlib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
@ -36,7 +37,9 @@ var testKeys = `{
|
||||
|
||||
type testKeyDatabase struct{}
|
||||
|
||||
func (db *testKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
|
||||
func (db *testKeyDatabase) FetchKeys(
|
||||
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||
results := map[PublicKeyRequest]ServerKeys{}
|
||||
var keys ServerKeys
|
||||
if err := json.Unmarshal([]byte(testKeys), &keys); err != nil {
|
||||
@ -54,14 +57,16 @@ func (db *testKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (db *testKeyDatabase) StoreKeys(requests map[PublicKeyRequest]ServerKeys) error {
|
||||
func (db *testKeyDatabase) StoreKeys(
|
||||
ctx context.Context, requests map[PublicKeyRequest]ServerKeys,
|
||||
) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestVerifyJSONsSuccess(t *testing.T) {
|
||||
// Check that trying to verify the server key JSON works.
|
||||
k := KeyRing{nil, &testKeyDatabase{}}
|
||||
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
|
||||
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
|
||||
ServerName: "localhost:8800",
|
||||
Message: []byte(testKeys),
|
||||
AtTS: 1493142432964,
|
||||
@ -77,7 +82,7 @@ func TestVerifyJSONsSuccess(t *testing.T) {
|
||||
func TestVerifyJSONsUnknownServerFails(t *testing.T) {
|
||||
// Check that trying to verify JSON for an unknown server fails.
|
||||
k := KeyRing{nil, &testKeyDatabase{}}
|
||||
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
|
||||
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
|
||||
ServerName: "unknown:8800",
|
||||
Message: []byte(testKeys),
|
||||
AtTS: 1493142432964,
|
||||
@ -94,7 +99,7 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
|
||||
// Check that trying to verify JSON from the distant future fails.
|
||||
distantFuture := Timestamp(2000000000000)
|
||||
k := KeyRing{nil, &testKeyDatabase{}}
|
||||
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
|
||||
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
|
||||
ServerName: "unknown:8800",
|
||||
Message: []byte(testKeys),
|
||||
AtTS: distantFuture,
|
||||
@ -110,7 +115,7 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
|
||||
func TestVerifyJSONsFetcherError(t *testing.T) {
|
||||
// Check that if the database errors then the attempt to verify JSON fails.
|
||||
k := KeyRing{nil, &erroringKeyDatabase{}}
|
||||
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
|
||||
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
|
||||
ServerName: "localhost:8800",
|
||||
Message: []byte(testKeys),
|
||||
AtTS: 1493142432964,
|
||||
@ -129,10 +134,14 @@ func (e *erroringKeyDatabaseError) Error() string { return "An error with the ke
|
||||
var testErrorFetch = erroringKeyDatabaseError(1)
|
||||
var testErrorStore = erroringKeyDatabaseError(2)
|
||||
|
||||
func (e *erroringKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
|
||||
func (e *erroringKeyDatabase) FetchKeys(
|
||||
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
||||
return nil, &testErrorFetch
|
||||
}
|
||||
|
||||
func (e *erroringKeyDatabase) StoreKeys(keys map[PublicKeyRequest]ServerKeys) error {
|
||||
func (e *erroringKeyDatabase) StoreKeys(
|
||||
ctx context.Context, keys map[PublicKeyRequest]ServerKeys,
|
||||
) error {
|
||||
return &testErrorStore
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
{
|
||||
"Deadline": "5m",
|
||||
"Enable": [
|
||||
"vet",
|
||||
"vetshadow",
|
||||
|
@ -215,7 +215,7 @@ func VerifyHTTPRequest(
|
||||
return nil, util.MessageResponse(401, message)
|
||||
}
|
||||
|
||||
results, err := keys.VerifyJSONs([]VerifyJSONRequest{{
|
||||
results, err := keys.VerifyJSONs(req.Context(), []VerifyJSONRequest{{
|
||||
ServerName: request.Origin(),
|
||||
AtTS: AsTimestamp(now),
|
||||
Message: toVerify,
|
||||
|
Loading…
Reference in New Issue
Block a user