Add API for querying events by ID. (#127)

* Add API for querying events by ID.

* Fix tense

* Start implementing federation ingress

* More stuff

* Hook up federation event receiving

* Fix comments

* Comment on the order of the arrays
This commit is contained in:
Mark Haines 2017-06-02 11:19:34 +01:00 committed by GitHub
parent 3b498c8074
commit ce82158abb
8 changed files with 501 additions and 6 deletions

View File

@ -18,11 +18,14 @@ import (
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"os" "os"
"strings"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationapi/config" "github.com/matrix-org/dendrite/federationapi/config"
"github.com/matrix-org/dendrite/federationapi/routing" "github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
@ -40,7 +43,10 @@ var (
// openssl x509 -noout -fingerprint -sha256 -inform pem -in server.crt |\ // openssl x509 -noout -fingerprint -sha256 -inform pem -in server.crt |\
// python -c 'print raw_input()[19:].replace(":","").decode("hex").encode("base64").rstrip("=\n")' // python -c 'print raw_input()[19:].replace(":","").decode("hex").encode("base64").rstrip("=\n")'
// //
tlsFingerprint = os.Getenv("TLS_FINGERPRINT") tlsFingerprint = os.Getenv("TLS_FINGERPRINT")
kafkaURIs = strings.Split(os.Getenv("KAFKA_URIS"), ",")
roomserverURL = os.Getenv("ROOMSERVER_URL")
roomserverInputTopic = os.Getenv("TOPIC_INPUT_ROOM_EVENT")
) )
func main() { func main() {
@ -57,6 +63,18 @@ func main() {
log.Panic("No TLS_FINGERPRINT environment variable found.") log.Panic("No TLS_FINGERPRINT environment variable found.")
} }
if len(kafkaURIs) == 0 {
// the kafka default is :9092
kafkaURIs = []string{"localhost:9092"}
}
if roomserverURL == "" {
log.Panic("No ROOMSERVER_URL environment variable found.")
}
if roomserverInputTopic == "" {
log.Panic("No TOPIC_INPUT_ROOM_EVENT environment variable found. This should match the roomserver input topic.")
}
cfg := config.FederationAPI{ cfg := config.FederationAPI{
ServerName: serverName, ServerName: serverName,
// TODO: make the validity period configurable. // TODO: make the validity period configurable.
@ -75,6 +93,37 @@ func main() {
} }
cfg.TLSFingerPrints = []gomatrixserverlib.TLSFingerprint{{fingerprintSHA256}} cfg.TLSFingerPrints = []gomatrixserverlib.TLSFingerprint{{fingerprintSHA256}}
routing.Setup(http.DefaultServeMux, cfg) federation := gomatrixserverlib.NewFederationClient(cfg.ServerName, cfg.KeyID, cfg.PrivateKey)
keyRing := gomatrixserverlib.KeyRing{
KeyFetchers: []gomatrixserverlib.KeyFetcher{
// TODO: Use perspective key fetchers for production.
&gomatrixserverlib.DirectKeyFetcher{federation.Client},
},
KeyDatabase: &dummyKeyDatabase{},
}
queryAPI := api.NewRoomserverQueryAPIHTTP(roomserverURL, nil)
roomserverProducer, err := producers.NewRoomserverProducer(kafkaURIs, roomserverInputTopic)
if err != nil {
log.Panicf("Failed to setup kafka producers(%s): %s", kafkaURIs, err)
}
routing.Setup(http.DefaultServeMux, cfg, queryAPI, roomserverProducer, keyRing)
log.Fatal(http.ListenAndServe(bindAddr, nil)) log.Fatal(http.ListenAndServe(bindAddr, nil))
} }
// TODO: Implement a proper key database.
type dummyKeyDatabase struct{}
func (d *dummyKeyDatabase) FetchKeys(
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
return nil, nil
}
func (d *dummyKeyDatabase) StoreKeys(
map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys,
) error {
return nil
}

View File

@ -0,0 +1,124 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"flag"
"fmt"
log "github.com/Sirupsen/logrus"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"time"
)
const usage = `Usage: %s
Create a single endpoint URL which remote matrix servers can be pointed at.
The server-server API in Dendrite is split across multiple processes
which listen on multiple ports. You cannot point a Matrix server at
any of those ports, as there will be unimplemented functionality.
In addition, all server-server API processes start with the additional
path prefix '/api', which Matrix servers will be unaware of.
This tool will proxy requests for all server-server URLs and forward
them to their respective process. It will also add the '/api' path
prefix to incoming requests.
THIS TOOL IS FOR TESTING AND NOT INTENDED FOR PRODUCTION USE.
Arguments:
`
var (
federationAPIURL = flag.String("federation-api-url", "", "The base URL of the listening 'dendrite-federation-api-server' process. E.g. 'http://localhost:4200'")
bindAddress = flag.String("bind-address", ":8448", "The listening port for the proxy.")
certFile = flag.String("tls-cert", "server.crt", "The X509 certificate to use for TLS")
keyFile = flag.String("tls-key", "server.key", "The PEM private key to use for TLS")
)
func makeProxy(targetURL string) (*httputil.ReverseProxy, error) {
if !strings.HasSuffix(targetURL, "/") {
targetURL += "/"
}
// Check that we can parse the URL.
_, err := url.Parse(targetURL)
if err != nil {
return nil, err
}
return &httputil.ReverseProxy{
Director: func(req *http.Request) {
// URL.Path() removes the % escaping from the path.
// The % encoding will be added back when the url is encoded
// when the request is forwarded.
// This means that we will lose any unessecary escaping from the URL.
// Pratically this means that any distinction between '%2F' and '/'
// in the URL will be lost by the time it reaches the target.
path := req.URL.Path
path = "api" + path
log.WithFields(log.Fields{
"path": path,
"url": targetURL,
"method": req.Method,
}).Print("proxying request")
newURL, err := url.Parse(targetURL + path)
if err != nil {
// We already checked that we can parse the URL
// So this shouldn't ever get hit.
panic(err)
}
// Copy the query parameters from the request.
newURL.RawQuery = req.URL.RawQuery
req.URL = newURL
},
}, nil
}
func main() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, usage, os.Args[0])
flag.PrintDefaults()
}
flag.Parse()
if *federationAPIURL == "" {
flag.Usage()
fmt.Fprintln(os.Stderr, "no --federation-api-url specified.")
os.Exit(1)
}
federationProxy, err := makeProxy(*federationAPIURL)
if err != nil {
panic(err)
}
http.Handle("/", federationProxy)
srv := &http.Server{
Addr: *bindAddress,
ReadTimeout: 1 * time.Minute, // how long we wait for the client to send the entire request (after connection accept)
WriteTimeout: 5 * time.Minute, // how long the proxy has to write the full response
}
fmt.Println("Proxying requests to:")
fmt.Println(" /* => ", *federationAPIURL+"/api/*")
fmt.Println("Listening on ", *bindAddress)
panic(srv.ListenAndServeTLS(*certFile, *keyFile))
}

View File

@ -16,21 +16,34 @@ package routing
import ( import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/federationapi/config" "github.com/matrix-org/dendrite/federationapi/config"
"github.com/matrix-org/dendrite/federationapi/readers" "github.com/matrix-org/dendrite/federationapi/readers"
"github.com/matrix-org/dendrite/federationapi/writers"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"net/http" "net/http"
"time"
) )
const ( const (
pathPrefixV2Keys = "/_matrix/key/v2" pathPrefixV2Keys = "/_matrix/key/v2"
pathPrefixV1Federation = "/_matrix/federation/v1"
) )
// Setup registers HTTP handlers with the given ServeMux. // Setup registers HTTP handlers with the given ServeMux.
func Setup(servMux *http.ServeMux, cfg config.FederationAPI) { func Setup(
servMux *http.ServeMux,
cfg config.FederationAPI,
query api.RoomserverQueryAPI,
producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing,
) {
apiMux := mux.NewRouter() apiMux := mux.NewRouter()
v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter() v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter()
v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter()
localKeys := makeAPI("localkeys", func(req *http.Request) util.JSONResponse { localKeys := makeAPI("localkeys", func(req *http.Request) util.JSONResponse {
return readers.LocalKeys(req, cfg) return readers.LocalKeys(req, cfg)
@ -43,6 +56,17 @@ func Setup(servMux *http.ServeMux, cfg config.FederationAPI) {
v2keysmux.Handle("/server/{keyID}", localKeys) v2keysmux.Handle("/server/{keyID}", localKeys)
v2keysmux.Handle("/server/", localKeys) v2keysmux.Handle("/server/", localKeys)
v1fedmux.Handle("/send/{txnID}/", makeAPI("send",
func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return writers.Send(
req, gomatrixserverlib.TransactionID(vars["txnID"]),
time.Now(),
cfg, query, producer, keys,
)
},
))
servMux.Handle("/metrics", prometheus.Handler()) servMux.Handle("/metrics", prometheus.Handler())
servMux.Handle("/api/", http.StripPrefix("/api", apiMux)) servMux.Handle("/api/", http.StripPrefix("/api", apiMux))
} }

View File

@ -0,0 +1,182 @@
package writers
import (
"encoding/json"
"fmt"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/federationapi/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"net/http"
"time"
)
// Send implements /_matrix/federation/v1/send/{txnID}
func Send(
req *http.Request,
txnID gomatrixserverlib.TransactionID,
now time.Time,
cfg config.FederationAPI,
query api.RoomserverQueryAPI,
producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing,
) util.JSONResponse {
request, errResp := gomatrixserverlib.VerifyHTTPRequest(req, now, cfg.ServerName, keys)
if request == nil {
return errResp
}
var content gomatrixserverlib.Transaction
if err := json.Unmarshal(request.Content(), &content); err != nil {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
}
}
content.Origin = request.Origin()
content.TransactionID = txnID
content.Destination = cfg.ServerName
resp, err := processTransaction(content, query, producer, keys)
if err != nil {
return httputil.LogThenError(req, err)
}
return util.JSONResponse{
Code: 200,
JSON: resp,
}
}
func processTransaction(
t gomatrixserverlib.Transaction,
query api.RoomserverQueryAPI,
producer *producers.RoomserverProducer,
keys gomatrixserverlib.KeyRing,
) (*gomatrixserverlib.RespSend, error) {
// Check the event signatures
if err := gomatrixserverlib.VerifyEventSignatures(t.PDUs, keys); err != nil {
return nil, err
}
// Process the events.
results := map[string]gomatrixserverlib.PDUResult{}
for _, e := range t.PDUs {
err := processEvent(e, query, producer)
if err != nil {
// If the error is due to the event itself being bad then we skip
// it and move onto the next event. We report an error so that the
// sender knows that we have skipped processing it.
//
// However if the event is due to a temporary failure in our server
// such as a database being unavailable then we should bail, and
// hope that the sender will retry when we are feeling better.
//
// It is uncertain what we should do if an event fails because
// we failed to fetch more information from the sending server.
// For example if a request to /state fails.
// If we skip the event then we risk missing the event until we
// receive another event referencing it.
// If we bail and stop processing then we risk wedging incoming
// transactions from that server forever.
switch err.(type) {
case unknownRoomError:
case *gomatrixserverlib.NotAllowed:
default:
// Any other error should be the result of a temporary error in
// our server so we should bail processing the transaction entirely.
return nil, err
}
results[e.EventID()] = gomatrixserverlib.PDUResult{err.Error()}
} else {
results[e.EventID()] = gomatrixserverlib.PDUResult{}
}
}
// TODO: Process the EDUs.
return &gomatrixserverlib.RespSend{PDUs: results}, nil
}
type unknownRoomError string
func (e unknownRoomError) Error() string { return fmt.Sprintf("unknown room %q", e) }
func processEvent(
e gomatrixserverlib.Event,
query api.RoomserverQueryAPI,
producer *producers.RoomserverProducer,
) error {
refs := e.PrevEvents()
prevEventIDs := make([]string, len(refs))
for i := range refs {
prevEventIDs[i] = refs[i].EventID
}
// Fetch the state needed to authenticate the event.
needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e})
stateReq := api.QueryStateAfterEventsRequest{
RoomID: e.RoomID(),
PrevEventIDs: prevEventIDs,
StateToFetch: needed.Tuples(),
}
var stateResp api.QueryStateAfterEventsResponse
if err := query.QueryStateAfterEvents(&stateReq, &stateResp); err != nil {
return err
}
if !stateResp.RoomExists {
// TODO: When synapse receives a message for a room it is not in it
// asked the remote server for the state of the room so that it can
// check if the remote server knows of a join "m.room.member" event
// that this server is unaware of.
// However generally speaking we should reject events for rooms we
// aren't a member of.
return unknownRoomError(e.RoomID())
}
if !stateResp.PrevEventsExist {
// We are missing the previous events for this events.
// This means that there is a gap in our view of the history of the
// room. There two ways that we can handle such a gap:
// 1) We can fill in the gap using /get_missing_events
// 2) We can leave the gap and request the state of the room at
// this event from the remote server using either /state_ids
// or /state.
// Synapse will attempt to do 1 and if that fails or if the gap is
// too large then it will attempt 2.
// Synapse will use /state_ids if possible since ususally the state
// is largely unchanged and it is more efficient to fetch a list of
// event ids and then use /event to fetch the individual events.
// However not all version of synapse support /state_ids so you may
// need to fallback to /state.
// TODO: Attempt to fill in the gap using /get_missing_events
// TODO: Attempt to fetch the state using /state_ids and /events
// TODO: Attempt to fetch the state using /state
panic(fmt.Errorf("Receiving events with missing prev_events is no implemented"))
}
// Check that the event is allowed by the state at the event.
authUsingState := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateResp.StateEvents {
authUsingState.AddEvent(&stateResp.StateEvents[i])
}
err := gomatrixserverlib.Allowed(e, &authUsingState)
if err != nil {
return err
}
// TODO: Check that the roomserver has a copy of all of the auth_events.
// TODO: Check that the event is allowed by its auth_events.
// pass the event to the roomserver
if err := producer.SendEvents([]gomatrixserverlib.Event{e}); err != nil {
return err
}
return nil
}

View File

@ -41,6 +41,7 @@ type QueryLatestEventsAndStateResponse struct {
// The latest events in the room. // The latest events in the room.
LatestEvents []gomatrixserverlib.EventReference LatestEvents []gomatrixserverlib.EventReference
// The state events requested. // The state events requested.
// This list will be in an arbitrary order.
StateEvents []gomatrixserverlib.Event StateEvents []gomatrixserverlib.Event
} }
@ -65,9 +66,30 @@ type QueryStateAfterEventsResponse struct {
// If some of previous events do not exist this will be false and StateEvents will be empty. // If some of previous events do not exist this will be false and StateEvents will be empty.
PrevEventsExist bool PrevEventsExist bool
// The state events requested. // The state events requested.
// This list will be in an arbitrary order.
StateEvents []gomatrixserverlib.Event StateEvents []gomatrixserverlib.Event
} }
// QueryEventsByIDRequest is a request to QueryEventsByID
type QueryEventsByIDRequest struct {
// The event IDs to look up.
EventIDs []string
}
// QueryEventsByIDResponse is a response to QueryEventsByID
type QueryEventsByIDResponse struct {
// Copy of the request for debugging.
QueryEventsByIDRequest
// A list of events with the requested IDs.
// If the roomserver does not have a copy of a requested event
// then it will omit that event from the list.
// If the roomserver thinks it has a copy of the event, but
// fails to read it from the database then it will fail
// the entire request.
// This list will be in an arbitrary order.
Events []gomatrixserverlib.Event
}
// RoomserverQueryAPI is used to query information from the room server. // RoomserverQueryAPI is used to query information from the room server.
type RoomserverQueryAPI interface { type RoomserverQueryAPI interface {
// Query the latest events and state for a room from the room server. // Query the latest events and state for a room from the room server.
@ -81,6 +103,12 @@ type RoomserverQueryAPI interface {
request *QueryStateAfterEventsRequest, request *QueryStateAfterEventsRequest,
response *QueryStateAfterEventsResponse, response *QueryStateAfterEventsResponse,
) error ) error
// Query a list of events by event ID.
QueryEventsByID(
request *QueryEventsByIDRequest,
response *QueryEventsByIDResponse,
) error
} }
// RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. // RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API.
@ -89,6 +117,9 @@ const RoomserverQueryLatestEventsAndStatePath = "/api/roomserver/QueryLatestEven
// RoomserverQueryStateAfterEventsPath is the HTTP path for the QueryStateAfterEvents API. // RoomserverQueryStateAfterEventsPath is the HTTP path for the QueryStateAfterEvents API.
const RoomserverQueryStateAfterEventsPath = "/api/roomserver/QueryStateAfterEvents" const RoomserverQueryStateAfterEventsPath = "/api/roomserver/QueryStateAfterEvents"
// RoomserverQueryEventsByIDPath is the HTTP path for the QueryEventsByID API.
const RoomserverQueryEventsByIDPath = "/api/roomserver/QueryEventsByID"
// NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. // NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API.
// If httpClient is nil then it uses the http.DefaultClient // If httpClient is nil then it uses the http.DefaultClient
func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI { func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI {
@ -121,6 +152,15 @@ func (h *httpRoomserverQueryAPI) QueryStateAfterEvents(
return postJSON(h.httpClient, apiURL, request, response) return postJSON(h.httpClient, apiURL, request, response)
} }
// QueryEventsByID implements RoomserverQueryAPI
func (h *httpRoomserverQueryAPI) QueryEventsByID(
request *QueryEventsByIDRequest,
response *QueryEventsByIDResponse,
) error {
apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath
return postJSON(h.httpClient, apiURL, request, response)
}
func postJSON(httpClient http.Client, apiURL string, request, response interface{}) error { func postJSON(httpClient http.Client, apiURL string, request, response interface{}) error {
jsonBytes, err := json.Marshal(request) jsonBytes, err := json.Marshal(request)
if err != nil { if err != nil {

View File

@ -35,6 +35,9 @@ type RoomserverQueryAPIDatabase interface {
// Lookup event references for the latest events in the room and the current state snapshot. // Lookup event references for the latest events in the room and the current state snapshot.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, error) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, error)
// Lookup the numeric IDs for a list of events.
// Returns an error if there was a problem talking to the database.
EventNIDs(eventIDs []string) (map[string]types.EventNID, error)
} }
// RoomserverQueryAPI is an implementation of RoomserverQueryAPI // RoomserverQueryAPI is an implementation of RoomserverQueryAPI
@ -46,7 +49,7 @@ type RoomserverQueryAPI struct {
func (r *RoomserverQueryAPI) QueryLatestEventsAndState( func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
request *api.QueryLatestEventsAndStateRequest, request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse, response *api.QueryLatestEventsAndStateResponse,
) (err error) { ) error {
response.QueryLatestEventsAndStateRequest = *request response.QueryLatestEventsAndStateRequest = *request
roomNID, err := r.DB.RoomNID(request.RoomID) roomNID, err := r.DB.RoomNID(request.RoomID)
if err != nil { if err != nil {
@ -81,7 +84,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
func (r *RoomserverQueryAPI) QueryStateAfterEvents( func (r *RoomserverQueryAPI) QueryStateAfterEvents(
request *api.QueryStateAfterEventsRequest, request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse, response *api.QueryStateAfterEventsResponse,
) (err error) { ) error {
response.QueryStateAfterEventsRequest = *request response.QueryStateAfterEventsRequest = *request
roomNID, err := r.DB.RoomNID(request.RoomID) roomNID, err := r.DB.RoomNID(request.RoomID)
if err != nil { if err != nil {
@ -115,12 +118,41 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
return nil return nil
} }
// QueryEventsByID implements api.RoomserverQueryAPI
func (r *RoomserverQueryAPI) QueryEventsByID(
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
response.QueryEventsByIDRequest = *request
eventNIDMap, err := r.DB.EventNIDs(request.EventIDs)
if err != nil {
return err
}
var eventNIDs []types.EventNID
for _, nid := range eventNIDMap {
eventNIDs = append(eventNIDs, nid)
}
events, err := r.loadEvents(eventNIDs)
if err != nil {
return err
}
response.Events = events
return nil
}
func (r *RoomserverQueryAPI) loadStateEvents(stateEntries []types.StateEntry) ([]gomatrixserverlib.Event, error) { func (r *RoomserverQueryAPI) loadStateEvents(stateEntries []types.StateEntry) ([]gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries)) eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries { for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID eventNIDs[i] = stateEntries[i].EventNID
} }
return r.loadEvents(eventNIDs)
}
func (r *RoomserverQueryAPI) loadEvents(eventNIDs []types.EventNID) ([]gomatrixserverlib.Event, error) {
stateEvents, err := r.DB.Events(eventNIDs) stateEvents, err := r.DB.Events(eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -163,4 +195,18 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
return util.JSONResponse{Code: 200, JSON: &response} return util.JSONResponse{Code: 200, JSON: &response}
}), }),
) )
servMux.Handle(
api.RoomserverQueryEventsByIDPath,
common.MakeAPI("query_events_by_id", func(req *http.Request) util.JSONResponse {
var request api.QueryEventsByIDRequest
var response api.QueryEventsByIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryEventsByID(&request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: 200, JSON: &response}
}),
)
} }

View File

@ -104,6 +104,9 @@ const bulkSelectEventReferenceSQL = "" +
const bulkSelectEventIDSQL = "" + const bulkSelectEventIDSQL = "" +
"SELECT event_nid, event_id FROM events WHERE event_nid = ANY($1)" "SELECT event_nid, event_id FROM events WHERE event_nid = ANY($1)"
const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM events WHERE event_id = ANY($1)"
type eventStatements struct { type eventStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt selectEventStmt *sql.Stmt
@ -116,6 +119,7 @@ type eventStatements struct {
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt
} }
func (s *eventStatements) prepare(db *sql.DB) (err error) { func (s *eventStatements) prepare(db *sql.DB) (err error) {
@ -136,6 +140,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
}.prepare(db) }.prepare(db)
} }
@ -321,6 +326,26 @@ func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[typ
return results, nil return results, nil
} }
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]types.EventNID, error) {
rows, err := s.bulkSelectEventNIDStmt.Query(pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
defer rows.Close()
results := make(map[string]types.EventNID, len(eventIDs))
for rows.Next() {
var eventID string
var eventNID int64
if err = rows.Scan(&eventID, &eventNID); err != nil {
return nil, err
}
results[eventID] = types.EventNID(eventNID)
}
return results, nil
}
func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
nids := make([]int64, len(eventNIDs)) nids := make([]int64, len(eventNIDs))
for i := range eventNIDs { for i := range eventNIDs {

View File

@ -170,6 +170,11 @@ func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys) return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
} }
// EventNIDs implements query.RoomQueryDatabase
func (d *Database) EventNIDs(eventIDs []string) (map[string]types.EventNID, error) {
return d.statements.bulkSelectEventNID(eventIDs)
}
// Events implements input.EventDatabase // Events implements input.EventDatabase
func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) { func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) {
eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs) eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs)