mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-10 06:53:00 +00:00
Update gomatrixserverlib
This commit is contained in:
parent
3b9222e8f7
commit
e6835660b0
2
vendor/manifest
vendored
2
vendor/manifest
vendored
@ -98,7 +98,7 @@
|
|||||||
{
|
{
|
||||||
"importpath": "github.com/matrix-org/gomatrixserverlib",
|
"importpath": "github.com/matrix-org/gomatrixserverlib",
|
||||||
"repository": "https://github.com/matrix-org/gomatrixserverlib",
|
"repository": "https://github.com/matrix-org/gomatrixserverlib",
|
||||||
"revision": "785a53c41170526aa7a91a1fc534afac6ce01a9b",
|
"revision": "9cefcd6c3a00bff51e719a33e19a16edf52cdd6f",
|
||||||
"branch": "master"
|
"branch": "master"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -30,7 +30,7 @@ const (
|
|||||||
type ClientEvent struct {
|
type ClientEvent struct {
|
||||||
Content rawJSON `json:"content"`
|
Content rawJSON `json:"content"`
|
||||||
EventID string `json:"event_id"`
|
EventID string `json:"event_id"`
|
||||||
OriginServerTS int64 `json:"origin_server_ts"`
|
OriginServerTS Timestamp `json:"origin_server_ts"`
|
||||||
// RoomID is omitted on /sync responses
|
// RoomID is omitted on /sync responses
|
||||||
RoomID string `json:"room_id,omitempty"`
|
RoomID string `json:"room_id,omitempty"`
|
||||||
Sender string `json:"sender"`
|
Sender string `json:"sender"`
|
||||||
|
@ -103,7 +103,8 @@ type eventFields struct {
|
|||||||
Redacts string `json:"redacts"`
|
Redacts string `json:"redacts"`
|
||||||
Depth int64 `json:"depth"`
|
Depth int64 `json:"depth"`
|
||||||
Unsigned rawJSON `json:"unsigned"`
|
Unsigned rawJSON `json:"unsigned"`
|
||||||
OriginServerTS int64 `json:"origin_server_ts"`
|
OriginServerTS Timestamp `json:"origin_server_ts"`
|
||||||
|
Origin ServerName `json:"origin"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var emptyEventReferenceList = []EventReference{}
|
var emptyEventReferenceList = []EventReference{}
|
||||||
@ -145,8 +146,6 @@ func (eb *EventBuilder) Build(eventID string, now time.Time, origin ServerName,
|
|||||||
event.PrevState = &emptyEventReferenceList
|
event.PrevState = &emptyEventReferenceList
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Check size limits.
|
|
||||||
|
|
||||||
var eventJSON []byte
|
var eventJSON []byte
|
||||||
|
|
||||||
if eventJSON, err = json.Marshal(&event); err != nil {
|
if eventJSON, err = json.Marshal(&event); err != nil {
|
||||||
@ -166,7 +165,14 @@ func (eb *EventBuilder) Build(eventID string, now time.Time, origin ServerName,
|
|||||||
}
|
}
|
||||||
|
|
||||||
result.eventJSON = eventJSON
|
result.eventJSON = eventJSON
|
||||||
err = json.Unmarshal(eventJSON, &result.fields)
|
if err = json.Unmarshal(eventJSON, &result.fields); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = result.CheckFields(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,9 +191,6 @@ func NewEventFromUntrustedJSON(eventJSON []byte) (result Event, err error) {
|
|||||||
delete(event, "destinations")
|
delete(event, "destinations")
|
||||||
delete(event, "age_ts")
|
delete(event, "age_ts")
|
||||||
|
|
||||||
// TODO: Check that the event fields are correctly defined.
|
|
||||||
// TODO: Check size limits.
|
|
||||||
|
|
||||||
if eventJSON, err = json.Marshal(event); err != nil {
|
if eventJSON, err = json.Marshal(event); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -206,7 +209,14 @@ func NewEventFromUntrustedJSON(eventJSON []byte) (result Event, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
result.eventJSON = eventJSON
|
result.eventJSON = eventJSON
|
||||||
err = json.Unmarshal(eventJSON, &result.fields)
|
if err = json.Unmarshal(eventJSON, &result.fields); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = result.CheckFields(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -303,6 +313,125 @@ func (e Event) StateKeyEquals(stateKey string) bool {
|
|||||||
return *e.fields.StateKey == stateKey
|
return *e.fields.StateKey == stateKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// The event ID, room ID, sender, event type and state key fields cannot be
|
||||||
|
// bigger than this.
|
||||||
|
// https://github.com/matrix-org/synapse/blob/v0.21.0/synapse/event_auth.py#L173-L182
|
||||||
|
maxIDLength = 255
|
||||||
|
// The entire event JSON, including signatures cannot be bigger than this.
|
||||||
|
// https://github.com/matrix-org/synapse/blob/v0.21.0/synapse/event_auth.py#L183-184
|
||||||
|
maxEventLength = 65536
|
||||||
|
)
|
||||||
|
|
||||||
|
// CheckFields checks that the event fields are valid.
|
||||||
|
// Returns an error if the IDs have the wrong format or too long.
|
||||||
|
// Returns an error if the total length of the event JSON is too long.
|
||||||
|
// Returns an error if the event ID doesn't match the origin of the event.
|
||||||
|
// https://matrix.org/docs/spec/client_server/r0.2.0.html#size-limits
|
||||||
|
func (e Event) CheckFields() error {
|
||||||
|
if len(e.eventJSON) > maxEventLength {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"gomatrixserverlib: event is too long, length %d > maximum %d",
|
||||||
|
len(e.eventJSON), maxEventLength,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(e.fields.Type) > maxIDLength {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"gomatrixserverlib: event type is too long, length %d > maximum %d",
|
||||||
|
len(e.fields.Type), maxIDLength,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.fields.StateKey != nil && len(*e.fields.StateKey) > maxIDLength {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"gomatrixserverlib: state key is too long, length %d > maximum %d",
|
||||||
|
len(*e.fields.StateKey), maxIDLength,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := checkID(e.fields.RoomID, "room", '!')
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
senderDomain, err := checkID(e.fields.Sender, "user", '@')
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
eventDomain, err := checkID(e.fields.EventID, "event", '$')
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Synapse requires that the event ID domain has a valid signature.
|
||||||
|
// https://github.com/matrix-org/synapse/blob/v0.21.0/synapse/event_auth.py#L66-L68
|
||||||
|
// Synapse requires that the event origin has a valid signature.
|
||||||
|
// https://github.com/matrix-org/synapse/blob/v0.21.0/synapse/federation/federation_base.py#L133-L136
|
||||||
|
// Since both domains must be valid domains, and there is no good reason for them
|
||||||
|
// to be different we might as well ensure that they are the same since it
|
||||||
|
// makes the signature checks simpler.
|
||||||
|
if e.fields.Origin != ServerName(eventDomain) {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"gomatrixserverlib: event ID domain doesn't match origin: %q != %q",
|
||||||
|
eventDomain, e.fields.Origin,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventDomain != senderDomain {
|
||||||
|
// For the most part all events should be sent by a user on the
|
||||||
|
// originating server
|
||||||
|
// However "m.room.member" events created from third-party invites
|
||||||
|
// are allowed to have a different sender because they have the same
|
||||||
|
// sender as the "m.room.third_party_invite" event they derived from.
|
||||||
|
// https://github.com/matrix-org/synapse/blob/v0.21.0/synapse/event_auth.py#L58-L64
|
||||||
|
if e.fields.Type != "m.room.member" {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"gomatrixserverlib: sender domain doesn't match origin: %q != %q",
|
||||||
|
eventDomain, e.fields.Origin,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
c, err := newMemberContentFromEvent(e)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if c.Membership != invite || c.ThirdPartyInvite == nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"gomatrixserverlib: sender domain doesn't match origin: %q != %q",
|
||||||
|
eventDomain, e.fields.Origin,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkID(id, kind string, sigil byte) (domain string, err error) {
|
||||||
|
domain, err = domainFromID(id)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if id[0] != sigil {
|
||||||
|
err = fmt.Errorf(
|
||||||
|
"gomatrixserverlib: invalid %s ID, wanted first byte to be '%c' got '%c'",
|
||||||
|
kind, sigil, id[0],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(id) > maxIDLength {
|
||||||
|
err = fmt.Errorf(
|
||||||
|
"gomatrixserverlib: %s ID is too long, length %d > maximum %d",
|
||||||
|
kind, len(id), maxIDLength,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Origin returns the name of the server that sent the event
|
||||||
|
func (e Event) Origin() ServerName { return e.fields.Origin }
|
||||||
|
|
||||||
// EventID returns the event ID of the event.
|
// EventID returns the event ID of the event.
|
||||||
func (e Event) EventID() string {
|
func (e Event) EventID() string {
|
||||||
return e.fields.EventID
|
return e.fields.EventID
|
||||||
@ -319,7 +448,7 @@ func (e Event) Type() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OriginServerTS returns the unix timestamp when this event was created on the origin server, with millisecond resolution.
|
// OriginServerTS returns the unix timestamp when this event was created on the origin server, with millisecond resolution.
|
||||||
func (e Event) OriginServerTS() int64 {
|
func (e Event) OriginServerTS() Timestamp {
|
||||||
return e.fields.OriginServerTS
|
return e.fields.OriginServerTS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,7 +247,7 @@ type AuthEvents struct {
|
|||||||
// the event is replaced with the new event. Only returns an error if the event is not a state event.
|
// the event is replaced with the new event. Only returns an error if the event is not a state event.
|
||||||
func (a *AuthEvents) AddEvent(event *Event) error {
|
func (a *AuthEvents) AddEvent(event *Event) error {
|
||||||
if event.StateKey() == nil {
|
if event.StateKey() == nil {
|
||||||
return fmt.Errorf("AddEvent: event %s does not have a state key", event.Type())
|
return fmt.Errorf("AddEvent: event %q does not have a state key", event.Type())
|
||||||
}
|
}
|
||||||
a.events[StateKeyTuple{event.Type(), *event.StateKey()}] = event
|
a.events[StateKeyTuple{event.Type(), *event.StateKey()}] = event
|
||||||
return nil
|
return nil
|
||||||
|
@ -184,3 +184,55 @@ func verifyEventSignature(signingName string, keyID KeyID, publicKey ed25519.Pub
|
|||||||
|
|
||||||
return VerifyJSON(signingName, keyID, publicKey, redactedJSON)
|
return VerifyJSON(signingName, keyID, publicKey, redactedJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VerifyEventSignatures checks that each event in a list of events has valid
|
||||||
|
// signatures from the server that sent it.
|
||||||
|
func VerifyEventSignatures(events []Event, keyRing KeyRing) error {
|
||||||
|
var toVerify []VerifyJSONRequest
|
||||||
|
for _, event := range events {
|
||||||
|
redactedJSON, err := redactEvent(event.eventJSON)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
v := VerifyJSONRequest{
|
||||||
|
Message: redactedJSON,
|
||||||
|
AtTS: event.OriginServerTS(),
|
||||||
|
ServerName: event.Origin(),
|
||||||
|
}
|
||||||
|
toVerify = append(toVerify, v)
|
||||||
|
|
||||||
|
// "m.room.member" invite events are signed by both the server sending
|
||||||
|
// the invite and the server the invite is for.
|
||||||
|
if event.Type() == "m.room.member" && event.StateKey() != nil {
|
||||||
|
targetDomain, err := domainFromID(*event.StateKey())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ServerName(targetDomain) != event.Origin() {
|
||||||
|
c, err := newMemberContentFromEvent(event)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if c.Membership == invite {
|
||||||
|
v.ServerName = ServerName(targetDomain)
|
||||||
|
toVerify = append(toVerify, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := keyRing.VerifyJSONs(toVerify)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that all the event JSON was correctly signed.
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Everything was okay.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -34,6 +34,53 @@ type RespState struct {
|
|||||||
AuthEvents []Event `json:"auth_chain"`
|
AuthEvents []Event `json:"auth_chain"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that a response to /state is valid.
|
||||||
|
func (r RespState) Check(keyRing KeyRing) error {
|
||||||
|
var allEvents []Event
|
||||||
|
for _, event := range r.AuthEvents {
|
||||||
|
if event.StateKey() == nil {
|
||||||
|
return fmt.Errorf("gomatrixserverlib: event %q does not have a state key", event.EventID())
|
||||||
|
}
|
||||||
|
allEvents = append(allEvents, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateTuples := map[StateKeyTuple]bool{}
|
||||||
|
for _, event := range r.StateEvents {
|
||||||
|
if event.StateKey() == nil {
|
||||||
|
return fmt.Errorf("gomatrixserverlib: event %q does not have a state key", event.EventID())
|
||||||
|
}
|
||||||
|
stateTuple := StateKeyTuple{event.Type(), *event.StateKey()}
|
||||||
|
if stateTuples[stateTuple] {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"gomatrixserverlib: duplicate state key tuple (%q, %q)",
|
||||||
|
event.Type(), *event.StateKey(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
stateTuples[stateTuple] = true
|
||||||
|
allEvents = append(allEvents, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the events pass signature checks.
|
||||||
|
if err := VerifyEventSignatures(allEvents, keyRing); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsByID := map[string]*Event{}
|
||||||
|
// Collect a map of event reference to event
|
||||||
|
for i := range allEvents {
|
||||||
|
eventsByID[allEvents[i].EventID()] = &allEvents[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check whether the events are allowed by the auth rules.
|
||||||
|
for _, event := range allEvents {
|
||||||
|
if err := checkAllowedByAuthEvents(event, eventsByID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// A RespMakeJoin is the content of a response to GET /_matrix/federation/v1/make_join/{roomID}/{userID}
|
// A RespMakeJoin is the content of a response to GET /_matrix/federation/v1/make_join/{roomID}/{userID}
|
||||||
type RespMakeJoin struct {
|
type RespMakeJoin struct {
|
||||||
// An incomplete m.room.member event for a user on the requesting server
|
// An incomplete m.room.member event for a user on the requesting server
|
||||||
@ -43,6 +90,7 @@ type RespMakeJoin struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A RespSendJoin is the content of a response to PUT /_matrix/federation/v1/send_join/{roomID}/{eventID}
|
// A RespSendJoin is the content of a response to PUT /_matrix/federation/v1/send_join/{roomID}/{eventID}
|
||||||
|
// It has the same data as a response to /state, but in a slightly different wire format.
|
||||||
type RespSendJoin RespState
|
type RespSendJoin RespState
|
||||||
|
|
||||||
// MarshalJSON implements json.Marshaller
|
// MarshalJSON implements json.Marshaller
|
||||||
@ -93,6 +141,43 @@ type respSendJoinFields struct {
|
|||||||
AuthEvents []Event `json:"auth_chain"`
|
AuthEvents []Event `json:"auth_chain"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that a reponse to /send_join is valid.
|
||||||
|
// This checks that it would be valid as a response to /state
|
||||||
|
// This also checks that the join event is allowed by the state.
|
||||||
|
func (r RespSendJoin) Check(keyRing KeyRing, joinEvent Event) error {
|
||||||
|
// First check that the state is valid.
|
||||||
|
// The response to /send_join has the same data as a response to /state
|
||||||
|
// and the checks for a response to /state also apply.
|
||||||
|
if err := RespState(r).Check(keyRing); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stateEventsByID := map[string]*Event{}
|
||||||
|
authEvents := NewAuthEvents(nil)
|
||||||
|
for i, event := range r.StateEvents {
|
||||||
|
stateEventsByID[event.EventID()] = &r.StateEvents[i]
|
||||||
|
if err := authEvents.AddEvent(&r.StateEvents[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now check that the join event is valid against its auth events.
|
||||||
|
if err := checkAllowedByAuthEvents(joinEvent, stateEventsByID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now check that the join event is valid against the supplied state.
|
||||||
|
if err := Allowed(joinEvent, &authEvents); err != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"gomatrixserverlib: event with ID %q is not allowed by the supplied state: %s",
|
||||||
|
joinEvent.EventID(), err.Error(),
|
||||||
|
)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// A RespDirectory is the content of a response to GET /_matrix/federation/v1/query/directory
|
// A RespDirectory is the content of a response to GET /_matrix/federation/v1/query/directory
|
||||||
// This is returned when looking up a room alias from a remote server.
|
// This is returned when looking up a room alias from a remote server.
|
||||||
// See https://matrix.org/docs/spec/server_server/unstable.html#directory
|
// See https://matrix.org/docs/spec/server_server/unstable.html#directory
|
||||||
@ -104,3 +189,26 @@ type RespDirectory struct {
|
|||||||
// before it finds one that it can use to join the room.
|
// before it finds one that it can use to join the room.
|
||||||
Servers []ServerName `json:"servers"`
|
Servers []ServerName `json:"servers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkAllowedByAuthEvents(event Event, eventsByID map[string]*Event) error {
|
||||||
|
authEvents := NewAuthEvents(nil)
|
||||||
|
for _, authRef := range event.AuthEvents() {
|
||||||
|
authEvent := eventsByID[authRef.EventID]
|
||||||
|
if authEvent == nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"gomatrixserverlib: missing auth event with ID %q for event %q",
|
||||||
|
authRef.EventID, event.EventID(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err := authEvents.AddEvent(authEvent); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := Allowed(event, &authEvents); err != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"gomatrixserverlib: event with ID %q is not allowed by its auth_events: %s",
|
||||||
|
event.EventID(), err.Error(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -58,7 +58,7 @@ type VerifyJSONResult struct {
|
|||||||
// Whether the message passed the signature checks.
|
// Whether the message passed the signature checks.
|
||||||
// This will be nil if the message passed the checks.
|
// This will be nil if the message passed the checks.
|
||||||
// This will have an error if the message did not pass the checks.
|
// This will have an error if the message did not pass the checks.
|
||||||
Result error
|
Error error
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyJSONs performs bulk JSON signature verification for a list of VerifyJSONRequests.
|
// VerifyJSONs performs bulk JSON signature verification for a list of VerifyJSONRequests.
|
||||||
@ -73,7 +73,7 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
|
|||||||
for i := range requests {
|
for i := range requests {
|
||||||
ids, err := ListKeyIDs(string(requests[i].ServerName), requests[i].Message)
|
ids, err := ListKeyIDs(string(requests[i].ServerName), requests[i].Message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
results[i].Result = fmt.Errorf("gomatrixserverlib: error extracting key IDs")
|
results[i].Error = fmt.Errorf("gomatrixserverlib: error extracting key IDs")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, keyID := range ids {
|
for _, keyID := range ids {
|
||||||
@ -82,7 +82,7 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(keyIDs[i]) == 0 {
|
if len(keyIDs[i]) == 0 {
|
||||||
results[i].Result = fmt.Errorf(
|
results[i].Error = fmt.Errorf(
|
||||||
"gomatrixserverlib: not signed by %q with a supported algorithm", requests[i].ServerName,
|
"gomatrixserverlib: not signed by %q with a supported algorithm", requests[i].ServerName,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
@ -91,7 +91,7 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
|
|||||||
// This will be unset if one of the signature checks passes.
|
// This will be unset if one of the signature checks passes.
|
||||||
// This will be overwritten if one of the signature checks fails.
|
// This will be overwritten if one of the signature checks fails.
|
||||||
// Therefore this will only remain in place if the keys couldn't be downloaded.
|
// Therefore this will only remain in place if the keys couldn't be downloaded.
|
||||||
results[i].Result = fmt.Errorf(
|
results[i].Error = fmt.Errorf(
|
||||||
"gomatrixserverlib: could not download key for %q", requests[i].ServerName,
|
"gomatrixserverlib: could not download key for %q", requests[i].ServerName,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -139,7 +139,7 @@ func (k *KeyRing) isAlgorithmSupported(keyID KeyID) bool {
|
|||||||
func (k *KeyRing) publicKeyRequests(requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID) map[PublicKeyRequest]Timestamp {
|
func (k *KeyRing) publicKeyRequests(requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID) map[PublicKeyRequest]Timestamp {
|
||||||
keyRequests := map[PublicKeyRequest]Timestamp{}
|
keyRequests := map[PublicKeyRequest]Timestamp{}
|
||||||
for i := range requests {
|
for i := range requests {
|
||||||
if results[i].Result == nil {
|
if results[i].Error == nil {
|
||||||
// We've already verified this message, we don't need to refetch the keys for it.
|
// We've already verified this message, we don't need to refetch the keys for it.
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -165,7 +165,7 @@ func (k *KeyRing) checkUsingKeys(
|
|||||||
keys map[PublicKeyRequest]ServerKeys,
|
keys map[PublicKeyRequest]ServerKeys,
|
||||||
) {
|
) {
|
||||||
for i := range requests {
|
for i := range requests {
|
||||||
if results[i].Result == nil {
|
if results[i].Error == nil {
|
||||||
// We've already checked this message and it passed the signature checks.
|
// We've already checked this message and it passed the signature checks.
|
||||||
// So we can skip to the next message.
|
// So we can skip to the next message.
|
||||||
continue
|
continue
|
||||||
@ -180,7 +180,7 @@ func (k *KeyRing) checkUsingKeys(
|
|||||||
if publicKey == nil {
|
if publicKey == nil {
|
||||||
// The key wasn't valid at the timestamp we needed it to be valid at.
|
// The key wasn't valid at the timestamp we needed it to be valid at.
|
||||||
// So skip onto the next key.
|
// So skip onto the next key.
|
||||||
results[i].Result = fmt.Errorf(
|
results[i].Error = fmt.Errorf(
|
||||||
"gomatrixserverlib: key with ID %q for %q not valid at %d",
|
"gomatrixserverlib: key with ID %q for %q not valid at %d",
|
||||||
keyID, requests[i].ServerName, requests[i].AtTS,
|
keyID, requests[i].ServerName, requests[i].AtTS,
|
||||||
)
|
)
|
||||||
@ -190,11 +190,11 @@ func (k *KeyRing) checkUsingKeys(
|
|||||||
string(requests[i].ServerName), keyID, ed25519.PublicKey(publicKey), requests[i].Message,
|
string(requests[i].ServerName), keyID, ed25519.PublicKey(publicKey), requests[i].Message,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
// The signature wasn't valid, record the error and try the next key ID.
|
// The signature wasn't valid, record the error and try the next key ID.
|
||||||
results[i].Result = err
|
results[i].Error = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// The signature is valid, set the result to nil.
|
// The signature is valid, set the result to nil.
|
||||||
results[i].Result = nil
|
results[i].Error = nil
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -70,8 +70,8 @@ func TestVerifyJSONsSuccess(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(results) != 1 || results[0].Result != nil {
|
if len(results) != 1 || results[0].Error != nil {
|
||||||
t.Fatalf("VerifyJSON(): Wanted [{Result: nil}] got %#v", results)
|
t.Fatalf("VerifyJSON(): Wanted [{Error: nil}] got %#v", results)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,8 +86,8 @@ func TestVerifyJSONsUnknownServerFails(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(results) != 1 || results[0].Result == nil {
|
if len(results) != 1 || results[0].Error == nil {
|
||||||
t.Fatalf("VerifyJSON(): Wanted [{Result: <some error>}] got %#v", results)
|
t.Fatalf("VerifyJSON(): Wanted [{Error: <some error>}] got %#v", results)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -103,8 +103,8 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(results) != 1 || results[0].Result == nil {
|
if len(results) != 1 || results[0].Error == nil {
|
||||||
t.Fatalf("VerifyJSON(): Wanted [{Result: <some error>}] got %#v", results)
|
t.Fatalf("VerifyJSON(): Wanted [{Error: <some error>}] got %#v", results)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,9 +224,9 @@ func VerifyHTTPRequest(
|
|||||||
util.GetLogger(req.Context()).WithError(err).Print(message)
|
util.GetLogger(req.Context()).WithError(err).Print(message)
|
||||||
return nil, util.MessageResponse(500, message)
|
return nil, util.MessageResponse(500, message)
|
||||||
}
|
}
|
||||||
if results[0].Result != nil {
|
if results[0].Error != nil {
|
||||||
message := "Invalid request signature"
|
message := "Invalid request signature"
|
||||||
util.GetLogger(req.Context()).WithError(results[0].Result).Print(message)
|
util.GetLogger(req.Context()).WithError(results[0].Error).Print(message)
|
||||||
return nil, util.MessageResponse(401, message)
|
return nil, util.MessageResponse(401, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user