mirror of
https://github.com/1f349/dendrite.git
synced 2024-12-23 16:54:08 +00:00
Store keys rather than json in the keydatabase (#330)
* bump gomatrixserverlib (changes to KeyFetcher and KeyDatabase interfaces) * Store keys rather than json in the keydatabase Rather than storing the raw JSON returned from a /keys/v1/query call in the table, store the key itself. This makes keydb.Database implement the updated KeyDatabase interface.
This commit is contained in:
parent
7f85422471
commit
4124ce2ac0
@ -48,14 +48,14 @@ func NewDatabase(dataSourceName string) (*Database, error) {
|
|||||||
func (d *Database) FetchKeys(
|
func (d *Database) FetchKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
||||||
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
|
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||||
return d.statements.bulkSelectServerKeys(ctx, requests)
|
return d.statements.bulkSelectServerKeys(ctx, requests)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
||||||
func (d *Database) StoreKeys(
|
func (d *Database) StoreKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys,
|
keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||||
) error {
|
) error {
|
||||||
// TODO: Inserting all the keys within a single transaction may
|
// TODO: Inserting all the keys within a single transaction may
|
||||||
// be more efficient since the transaction overhead can be quite
|
// be more efficient since the transaction overhead can be quite
|
||||||
|
@ -17,14 +17,13 @@ package keydb
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const serverKeysSchema = `
|
const serverKeysSchema = `
|
||||||
-- A cache of server keys downloaded from remote servers.
|
-- A cache of signing keys downloaded from remote servers.
|
||||||
CREATE TABLE IF NOT EXISTS keydb_server_keys (
|
CREATE TABLE IF NOT EXISTS keydb_server_keys (
|
||||||
-- The name of the matrix server the key is for.
|
-- The name of the matrix server the key is for.
|
||||||
server_name TEXT NOT NULL,
|
server_name TEXT NOT NULL,
|
||||||
@ -33,10 +32,14 @@ CREATE TABLE IF NOT EXISTS keydb_server_keys (
|
|||||||
-- Combined server name and key ID separated by the ASCII unit separator
|
-- Combined server name and key ID separated by the ASCII unit separator
|
||||||
-- to make it easier to run bulk queries.
|
-- to make it easier to run bulk queries.
|
||||||
server_name_and_key_id TEXT NOT NULL,
|
server_name_and_key_id TEXT NOT NULL,
|
||||||
-- When the keys are valid until as a millisecond timestamp.
|
-- When the key is valid until as a millisecond timestamp.
|
||||||
|
-- 0 if this is an expired key (in which case expired_ts will be non-zero)
|
||||||
valid_until_ts BIGINT NOT NULL,
|
valid_until_ts BIGINT NOT NULL,
|
||||||
-- The raw JSON for the server key.
|
-- When the key expired as a millisecond timestamp.
|
||||||
server_key_json TEXT NOT NULL,
|
-- 0 if this is an active key (in which case valid_until_ts will be non-zero)
|
||||||
|
expired_ts BIGINT NOT NULL,
|
||||||
|
-- The base64-encoded public key.
|
||||||
|
server_key TEXT NOT NULL,
|
||||||
CONSTRAINT keydb_server_keys_unique UNIQUE (server_name, server_key_id)
|
CONSTRAINT keydb_server_keys_unique UNIQUE (server_name, server_key_id)
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -44,15 +47,16 @@ CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (se
|
|||||||
`
|
`
|
||||||
|
|
||||||
const bulkSelectServerKeysSQL = "" +
|
const bulkSelectServerKeysSQL = "" +
|
||||||
"SELECT server_name, server_key_id, server_key_json FROM keydb_server_keys" +
|
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
|
||||||
|
" server_key FROM keydb_server_keys" +
|
||||||
" WHERE server_name_and_key_id = ANY($1)"
|
" WHERE server_name_and_key_id = ANY($1)"
|
||||||
|
|
||||||
const upsertServerKeysSQL = "" +
|
const upsertServerKeysSQL = "" +
|
||||||
"INSERT INTO keydb_server_keys (server_name, server_key_id," +
|
"INSERT INTO keydb_server_keys (server_name, server_key_id," +
|
||||||
" server_name_and_key_id, valid_until_ts, server_key_json)" +
|
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
|
||||||
" VALUES ($1, $2, $3, $4, $5)" +
|
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
" ON CONFLICT ON CONSTRAINT keydb_server_keys_unique" +
|
" ON CONFLICT ON CONSTRAINT keydb_server_keys_unique" +
|
||||||
" DO UPDATE SET valid_until_ts = $4, server_key_json = $5"
|
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
|
||||||
|
|
||||||
type serverKeyStatements struct {
|
type serverKeyStatements struct {
|
||||||
bulkSelectServerKeysStmt *sql.Stmt
|
bulkSelectServerKeysStmt *sql.Stmt
|
||||||
@ -76,7 +80,7 @@ func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
|
|||||||
func (s *serverKeyStatements) bulkSelectServerKeys(
|
func (s *serverKeyStatements) bulkSelectServerKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
||||||
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
|
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||||
var nameAndKeyIDs []string
|
var nameAndKeyIDs []string
|
||||||
for request := range requests {
|
for request := range requests {
|
||||||
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
||||||
@ -87,23 +91,30 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close() // nolint: errcheck
|
defer rows.Close() // nolint: errcheck
|
||||||
results := map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys{}
|
results := map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var serverName string
|
var serverName string
|
||||||
var keyID string
|
var keyID string
|
||||||
var keyJSON []byte
|
var key string
|
||||||
if err := rows.Scan(&serverName, &keyID, &keyJSON); err != nil {
|
var validUntilTS int64
|
||||||
return nil, err
|
var expiredTS int64
|
||||||
}
|
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
|
||||||
var serverKeys gomatrixserverlib.ServerKeys
|
|
||||||
if err := json.Unmarshal(keyJSON, &serverKeys); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
r := gomatrixserverlib.PublicKeyRequest{
|
r := gomatrixserverlib.PublicKeyRequest{
|
||||||
ServerName: gomatrixserverlib.ServerName(serverName),
|
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||||
KeyID: gomatrixserverlib.KeyID(keyID),
|
KeyID: gomatrixserverlib.KeyID(keyID),
|
||||||
}
|
}
|
||||||
results[r] = serverKeys
|
vk := gomatrixserverlib.VerifyKey{}
|
||||||
|
err = vk.Key.Decode(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results[r] = gomatrixserverlib.PublicKeyLookupResult{
|
||||||
|
VerifyKey: vk,
|
||||||
|
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
|
||||||
|
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
@ -111,19 +122,16 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
|
|||||||
func (s *serverKeyStatements) upsertServerKeys(
|
func (s *serverKeyStatements) upsertServerKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request gomatrixserverlib.PublicKeyRequest,
|
request gomatrixserverlib.PublicKeyRequest,
|
||||||
keys gomatrixserverlib.ServerKeys,
|
key gomatrixserverlib.PublicKeyLookupResult,
|
||||||
) error {
|
) error {
|
||||||
keyJSON, err := json.Marshal(keys)
|
_, err := s.upsertServerKeysStmt.ExecContext(
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = s.upsertServerKeysStmt.ExecContext(
|
|
||||||
ctx,
|
ctx,
|
||||||
string(request.ServerName),
|
string(request.ServerName),
|
||||||
string(request.KeyID),
|
string(request.KeyID),
|
||||||
nameAndKeyID(request),
|
nameAndKeyID(request),
|
||||||
int64(keys.ValidUntilTS),
|
key.ValidUntilTS,
|
||||||
keyJSON,
|
key.ExpiredTS,
|
||||||
|
key.Key.Encode(),
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,6 @@ func localKeys(cfg config.Dendrite, validUntil time.Time) (*gomatrixserverlib.Se
|
|||||||
var keys gomatrixserverlib.ServerKeys
|
var keys gomatrixserverlib.ServerKeys
|
||||||
|
|
||||||
keys.ServerName = cfg.Matrix.ServerName
|
keys.ServerName = cfg.Matrix.ServerName
|
||||||
keys.FromServer = cfg.Matrix.ServerName
|
|
||||||
|
|
||||||
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
|
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
|
||||||
|
|
||||||
|
2
vendor/manifest
vendored
2
vendor/manifest
vendored
@ -135,7 +135,7 @@
|
|||||||
{
|
{
|
||||||
"importpath": "github.com/matrix-org/gomatrixserverlib",
|
"importpath": "github.com/matrix-org/gomatrixserverlib",
|
||||||
"repository": "https://github.com/matrix-org/gomatrixserverlib",
|
"repository": "https://github.com/matrix-org/gomatrixserverlib",
|
||||||
"revision": "fb17c27f65a0699b0d15f5311a530225b4aea5e0",
|
"revision": "076933f95312aae3a9476e78d6b4118e1b45d542",
|
||||||
"branch": "master"
|
"branch": "master"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -26,13 +26,31 @@ import (
|
|||||||
// When the bytes are unmarshalled from JSON they are decoded from base64.
|
// When the bytes are unmarshalled from JSON they are decoded from base64.
|
||||||
type Base64String []byte
|
type Base64String []byte
|
||||||
|
|
||||||
|
// Encode encodes the bytes as base64
|
||||||
|
func (b64 Base64String) Encode() string {
|
||||||
|
return base64.RawStdEncoding.EncodeToString(b64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode decodes the given input into this Base64String
|
||||||
|
func (b64 *Base64String) Decode(str string) error {
|
||||||
|
// We must check whether the string was encoded in a URL-safe way in order
|
||||||
|
// to use the appropriate encoding.
|
||||||
|
var err error
|
||||||
|
if strings.ContainsAny(str, "-_") {
|
||||||
|
*b64, err = base64.RawURLEncoding.DecodeString(str)
|
||||||
|
} else {
|
||||||
|
*b64, err = base64.RawStdEncoding.DecodeString(str)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalJSON encodes the bytes as base64 and then encodes the base64 as a JSON string.
|
// MarshalJSON encodes the bytes as base64 and then encodes the base64 as a JSON string.
|
||||||
// This takes a value receiver so that maps and slices of Base64String encode correctly.
|
// This takes a value receiver so that maps and slices of Base64String encode correctly.
|
||||||
func (b64 Base64String) MarshalJSON() ([]byte, error) {
|
func (b64 Base64String) MarshalJSON() ([]byte, error) {
|
||||||
// This could be made more efficient by using base64.RawStdEncoding.Encode
|
// This could be made more efficient by using base64.RawStdEncoding.Encode
|
||||||
// to write the base64 directly to the JSON. We don't need to JSON escape
|
// to write the base64 directly to the JSON. We don't need to JSON escape
|
||||||
// any of the characters used in base64.
|
// any of the characters used in base64.
|
||||||
return json.Marshal(base64.RawStdEncoding.EncodeToString(b64))
|
return json.Marshal(b64.Encode())
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalJSON decodes a JSON string and then decodes the resulting base64.
|
// UnmarshalJSON decodes a JSON string and then decodes the resulting base64.
|
||||||
@ -44,12 +62,6 @@ func (b64 *Base64String) UnmarshalJSON(raw []byte) (err error) {
|
|||||||
if err = json.Unmarshal(raw, &str); err != nil {
|
if err = json.Unmarshal(raw, &str); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// We must check whether the string was encoded in a URL-safe way in order
|
err = b64.Decode(str)
|
||||||
// to use the appropriate encoding.
|
|
||||||
if strings.ContainsAny(str, "-_") {
|
|
||||||
*b64, err = base64.RawURLEncoding.DecodeString(str)
|
|
||||||
} else {
|
|
||||||
*b64, err = base64.RawStdEncoding.DecodeString(str)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -26,10 +26,15 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrix"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Default HTTPS request timeout
|
||||||
|
const requestTimeout time.Duration = time.Duration(30) * time.Second
|
||||||
|
|
||||||
// A Client makes request to the federation listeners of matrix
|
// A Client makes request to the federation listeners of matrix
|
||||||
// homeservers
|
// homeservers
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@ -41,9 +46,16 @@ type UserInfo struct {
|
|||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient makes a new Client
|
// NewClient makes a new Client (with default timeout)
|
||||||
func NewClient() *Client {
|
func NewClient() *Client {
|
||||||
return &Client{client: http.Client{Transport: newFederationTripper()}}
|
return NewClientWithTimeout(requestTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientWithTimeout makes a new Client with a specified request timeout
|
||||||
|
func NewClientWithTimeout(timeout time.Duration) *Client {
|
||||||
|
return &Client{client: http.Client{
|
||||||
|
Transport: newFederationTripper(),
|
||||||
|
Timeout: timeout}}
|
||||||
}
|
}
|
||||||
|
|
||||||
type federationTripper struct {
|
type federationTripper struct {
|
||||||
@ -132,7 +144,7 @@ func (fc *Client) LookupUserInfo(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var response *http.Response
|
var response *http.Response
|
||||||
response, err = fc.doHTTPRequest(ctx, req)
|
response, err = fc.DoHTTPRequest(ctx, req)
|
||||||
if response != nil {
|
if response != nil {
|
||||||
defer response.Body.Close() // nolint: errcheck
|
defer response.Body.Close() // nolint: errcheck
|
||||||
}
|
}
|
||||||
@ -171,10 +183,10 @@ func (fc *Client) LookupUserInfo(
|
|||||||
// Perspective servers can use that timestamp to determine whether they can
|
// Perspective servers can use that timestamp to determine whether they can
|
||||||
// return a cached copy of the keys or whether they will need to retrieve a fresh
|
// return a cached copy of the keys or whether they will need to retrieve a fresh
|
||||||
// copy of the keys.
|
// copy of the keys.
|
||||||
// Returns the keys or an error if there was a problem talking to the server.
|
// Returns the keys returned by the server, or an error if there was a problem talking to the server.
|
||||||
func (fc *Client) LookupServerKeys( // nolint: gocyclo
|
func (fc *Client) LookupServerKeys(
|
||||||
ctx context.Context, matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
|
ctx context.Context, matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
|
||||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
) ([]ServerKeys, error) {
|
||||||
url := url.URL{
|
url := url.URL{
|
||||||
Scheme: "matrix",
|
Scheme: "matrix",
|
||||||
Host: string(matrixServer),
|
Host: string(matrixServer),
|
||||||
@ -203,48 +215,24 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var body struct {
|
||||||
|
ServerKeyList []ServerKeys `json:"server_keys"`
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", url.String(), bytes.NewBuffer(requestBytes))
|
req, err := http.NewRequest("POST", url.String(), bytes.NewBuffer(requestBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
response, err := fc.doHTTPRequest(ctx, req)
|
err = fc.DoRequestAndParseResponse(
|
||||||
if response != nil {
|
ctx, req, &body,
|
||||||
defer response.Body.Close() // nolint: errcheck
|
)
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.StatusCode != 200 {
|
return body.ServerKeyList, nil
|
||||||
var errorOutput []byte
|
|
||||||
if errorOutput, err = ioutil.ReadAll(response.Body); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("HTTP %d : %s", response.StatusCode, errorOutput)
|
|
||||||
}
|
|
||||||
|
|
||||||
var body struct {
|
|
||||||
ServerKeyList []ServerKeys `json:"server_keys"`
|
|
||||||
}
|
|
||||||
if err = json.NewDecoder(response.Body).Decode(&body); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
result := map[PublicKeyRequest]ServerKeys{}
|
|
||||||
for _, keys := range body.ServerKeyList {
|
|
||||||
keys.FromServer = matrixServer
|
|
||||||
// TODO: What happens if the same key ID appears in multiple responses?
|
|
||||||
// We should probably take the response with the highest valid_until_ts.
|
|
||||||
for keyID := range keys.VerifyKeys {
|
|
||||||
result[PublicKeyRequest{keys.ServerName, keyID}] = keys
|
|
||||||
}
|
|
||||||
for keyID := range keys.OldVerifyKeys {
|
|
||||||
result[PublicKeyRequest{keys.ServerName, keyID}] = keys
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMediaDownloadRequest creates a request for media on a homeserver and returns the http.Response or an error
|
// CreateMediaDownloadRequest creates a request for media on a homeserver and returns the http.Response or an error
|
||||||
@ -257,10 +245,70 @@ func (fc *Client) CreateMediaDownloadRequest(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return fc.doHTTPRequest(ctx, req)
|
return fc.DoHTTPRequest(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fc *Client) doHTTPRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
// DoRequestAndParseResponse calls DoHTTPRequest and then decodes the response.
|
||||||
|
//
|
||||||
|
// If the HTTP response is not a 200, an attempt is made to parse the response
|
||||||
|
// body into a gomatrix.RespError. In any case, a non-200 response will result
|
||||||
|
// in a gomatrix.HTTPError.
|
||||||
|
//
|
||||||
|
func (fc *Client) DoRequestAndParseResponse(
|
||||||
|
ctx context.Context,
|
||||||
|
req *http.Request,
|
||||||
|
result interface{},
|
||||||
|
) error {
|
||||||
|
response, err := fc.DoHTTPRequest(ctx, req)
|
||||||
|
if response != nil {
|
||||||
|
defer response.Body.Close() // nolint: errcheck
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.StatusCode/100 != 2 { // not 2xx
|
||||||
|
// Adapted from https://github.com/matrix-org/gomatrix/blob/master/client.go
|
||||||
|
var contents []byte
|
||||||
|
contents, err = ioutil.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var wrap error
|
||||||
|
var respErr gomatrix.RespError
|
||||||
|
if _ = json.Unmarshal(contents, &respErr); respErr.ErrCode != "" {
|
||||||
|
wrap = respErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we failed to decode as RespError, don't just drop the HTTP body, include it in the
|
||||||
|
// HTTP error instead (e.g proxy errors which return HTML).
|
||||||
|
msg := "Failed to " + req.Method + " JSON to " + req.RequestURI
|
||||||
|
if wrap == nil {
|
||||||
|
msg = msg + ": " + string(contents)
|
||||||
|
}
|
||||||
|
|
||||||
|
return gomatrix.HTTPError{
|
||||||
|
Code: response.StatusCode,
|
||||||
|
Message: msg,
|
||||||
|
WrappedError: wrap,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewDecoder(response.Body).Decode(result); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoHTTPRequest creates an outgoing request ID and adds it to the context
|
||||||
|
// before sending off the request and awaiting a response.
|
||||||
|
//
|
||||||
|
// If the returned error is nil, the Response will contain a non-nil
|
||||||
|
// Body which the caller is expected to close.
|
||||||
|
//
|
||||||
|
func (fc *Client) DoHTTPRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||||
reqID := util.RandomString(12)
|
reqID := util.RandomString(12)
|
||||||
logger := util.GetLogger(ctx).WithField("server", req.URL.Host).WithField("out.req.ID", reqID)
|
logger := util.GetLogger(ctx).WithField("server", req.URL.Host).WithField("out.req.ID", reqID)
|
||||||
newCtx := util.ContextWithLogger(ctx, logger)
|
newCtx := util.ContextWithLogger(ctx, logger)
|
||||||
|
@ -2,17 +2,13 @@ package gomatrixserverlib
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrix"
|
|
||||||
"github.com/matrix-org/util"
|
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
// An FederationClient is a matrix federation client that adds
|
// A FederationClient is a matrix federation client that adds
|
||||||
// "Authorization: X-Matrix" headers to requests that need ed25519 signatures
|
// "Authorization: X-Matrix" headers to requests that need ed25519 signatures
|
||||||
type FederationClient struct {
|
type FederationClient struct {
|
||||||
Client
|
Client
|
||||||
@ -34,10 +30,6 @@ func NewFederationClient(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ac *FederationClient) doRequest(ctx context.Context, r FederationRequest, resBody interface{}) error {
|
func (ac *FederationClient) doRequest(ctx context.Context, r FederationRequest, resBody interface{}) error {
|
||||||
reqID := util.RandomString(12)
|
|
||||||
logger := util.GetLogger(ctx).WithField("server", r.fields.Destination).WithField("out.req.ID", reqID)
|
|
||||||
newCtx := util.ContextWithLogger(ctx, logger)
|
|
||||||
|
|
||||||
if err := r.Sign(ac.serverName, ac.serverKeyID, ac.serverPrivateKey); err != nil {
|
if err := r.Sign(ac.serverName, ac.serverKeyID, ac.serverPrivateKey); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -47,52 +39,7 @@ func (ac *FederationClient) doRequest(ctx context.Context, r FederationRequest,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Infof("Outgoing request %s %s", req.Method, req.URL)
|
return ac.Client.DoRequestAndParseResponse(ctx, req, resBody)
|
||||||
res, err := ac.client.Do(req.WithContext(newCtx))
|
|
||||||
if res != nil {
|
|
||||||
defer res.Body.Close() // nolint: errcheck
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.Infof("Outgoing request %s %s failed with %v", req.Method, req.URL, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
contents, err := ioutil.ReadAll(res.Body)
|
|
||||||
|
|
||||||
logger.Infof("Response %d from %s %s", res.StatusCode, req.Method, req.URL)
|
|
||||||
|
|
||||||
if res.StatusCode/100 != 2 { // not 2xx
|
|
||||||
// Adapted from https://github.com/matrix-org/gomatrix/blob/master/client.go
|
|
||||||
var wrap error
|
|
||||||
var respErr gomatrix.RespError
|
|
||||||
if _ = json.Unmarshal(contents, &respErr); respErr.ErrCode != "" {
|
|
||||||
wrap = respErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we failed to decode as RespError, don't just drop the HTTP body, include it in the
|
|
||||||
// HTTP error instead (e.g proxy errors which return HTML).
|
|
||||||
msg := "Failed to " + r.Method() + " JSON to " + r.RequestURI()
|
|
||||||
if wrap == nil {
|
|
||||||
msg = msg + ": " + string(contents)
|
|
||||||
}
|
|
||||||
|
|
||||||
return gomatrix.HTTPError{
|
|
||||||
Code: res.StatusCode,
|
|
||||||
Message: msg,
|
|
||||||
WrappedError: wrap,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if resBody == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.Unmarshal(contents, resBody)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var federationPathPrefix = "/_matrix/federation/v1"
|
var federationPathPrefix = "/_matrix/federation/v1"
|
||||||
|
@ -17,6 +17,38 @@ type PublicKeyRequest struct {
|
|||||||
KeyID KeyID
|
KeyID KeyID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PublicKeyNotExpired is a magic value for PublicKeyLookupResult.ExpiredTS:
|
||||||
|
// it indicates that this is an active key which has not yet expired
|
||||||
|
const PublicKeyNotExpired = Timestamp(0)
|
||||||
|
|
||||||
|
// PublicKeyNotValid is a magic value for PublicKeyLookupResult.ValidUntilTS:
|
||||||
|
// it is used when we don't have a validity period for this key. Most likely
|
||||||
|
// it is an old key with an expiry date.
|
||||||
|
const PublicKeyNotValid = Timestamp(0)
|
||||||
|
|
||||||
|
// A PublicKeyLookupResult is the result of looking up a server signing key.
|
||||||
|
type PublicKeyLookupResult struct {
|
||||||
|
VerifyKey
|
||||||
|
// if this key has expired, the time it stopped being valid for event signing in milliseconds.
|
||||||
|
// if the key has not expired, the magic value PublicKeyNotExpired.
|
||||||
|
ExpiredTS Timestamp
|
||||||
|
// When this result is valid until in milliseconds.
|
||||||
|
// if the key has expired, the magic value PublicKeyNotValid.
|
||||||
|
ValidUntilTS Timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
// WasValidAt checks if this signing key is valid for an event signed at the
|
||||||
|
// given timestamp.
|
||||||
|
func (r PublicKeyLookupResult) WasValidAt(atTs Timestamp) bool {
|
||||||
|
if r.ExpiredTS != PublicKeyNotExpired && atTs >= r.ExpiredTS {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if r.ValidUntilTS == PublicKeyNotValid || atTs > r.ValidUntilTS {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// A KeyFetcher is a way of fetching public keys in bulk.
|
// A KeyFetcher is a way of fetching public keys in bulk.
|
||||||
type KeyFetcher interface {
|
type KeyFetcher interface {
|
||||||
// Lookup a batch of public keys.
|
// Lookup a batch of public keys.
|
||||||
@ -27,7 +59,7 @@ type KeyFetcher interface {
|
|||||||
// The result may have fewer (server name, key ID) pairs than were in the request.
|
// The result may have fewer (server name, key ID) pairs than were in the request.
|
||||||
// The result may have more (server name, key ID) pairs than were in the request.
|
// The result may have more (server name, key ID) pairs than were in the request.
|
||||||
// Returns an error if there was a problem fetching the keys.
|
// Returns an error if there was a problem fetching the keys.
|
||||||
FetchKeys(ctx context.Context, requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error)
|
FetchKeys(ctx context.Context, requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]PublicKeyLookupResult, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// A KeyDatabase is a store for caching public keys.
|
// A KeyDatabase is a store for caching public keys.
|
||||||
@ -40,7 +72,7 @@ type KeyDatabase interface {
|
|||||||
// to a concurrent FetchKeys(). This is acceptable since the database is
|
// to a concurrent FetchKeys(). This is acceptable since the database is
|
||||||
// only used as a cache for the keys, so if a FetchKeys() races with a
|
// only used as a cache for the keys, so if a FetchKeys() races with a
|
||||||
// StoreKeys() and some of the keys are missing they will be just be refetched.
|
// StoreKeys() and some of the keys are missing they will be just be refetched.
|
||||||
StoreKeys(ctx context.Context, results map[PublicKeyRequest]ServerKeys) error
|
StoreKeys(ctx context.Context, results map[PublicKeyRequest]PublicKeyLookupResult) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages.
|
// A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages.
|
||||||
@ -123,6 +155,8 @@ func (k KeyRing) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest)
|
|||||||
k.checkUsingKeys(requests, results, keyIDs, keysFromDatabase)
|
k.checkUsingKeys(requests, results, keyIDs, keysFromDatabase)
|
||||||
|
|
||||||
for i := range k.KeyFetchers {
|
for i := range k.KeyFetchers {
|
||||||
|
// TODO: we should distinguish here between expired keys, and those we don't have.
|
||||||
|
// If the key has expired, it's no use re-requesting it.
|
||||||
keyRequests := k.publicKeyRequests(requests, results, keyIDs)
|
keyRequests := k.publicKeyRequests(requests, results, keyIDs)
|
||||||
if len(keyRequests) == 0 {
|
if len(keyRequests) == 0 {
|
||||||
// There aren't any keys to fetch so we can stop here.
|
// There aren't any keys to fetch so we can stop here.
|
||||||
@ -178,7 +212,7 @@ func (k *KeyRing) publicKeyRequests(
|
|||||||
|
|
||||||
func (k *KeyRing) checkUsingKeys(
|
func (k *KeyRing) checkUsingKeys(
|
||||||
requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
|
requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
|
||||||
keys map[PublicKeyRequest]ServerKeys,
|
keys map[PublicKeyRequest]PublicKeyLookupResult,
|
||||||
) {
|
) {
|
||||||
for i := range requests {
|
for i := range requests {
|
||||||
if results[i].Error == nil {
|
if results[i].Error == nil {
|
||||||
@ -187,13 +221,12 @@ func (k *KeyRing) checkUsingKeys(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, keyID := range keyIDs[i] {
|
for _, keyID := range keyIDs[i] {
|
||||||
serverKeys, ok := keys[PublicKeyRequest{requests[i].ServerName, keyID}]
|
serverKey, ok := keys[PublicKeyRequest{requests[i].ServerName, keyID}]
|
||||||
if !ok {
|
if !ok {
|
||||||
// No key for this key ID so we continue onto the next key ID.
|
// No key for this key ID so we continue onto the next key ID.
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
publicKey := serverKeys.PublicKey(keyID, requests[i].AtTS)
|
if !serverKey.WasValidAt(requests[i].AtTS) {
|
||||||
if publicKey == nil {
|
|
||||||
// The key wasn't valid at the timestamp we needed it to be valid at.
|
// The key wasn't valid at the timestamp we needed it to be valid at.
|
||||||
// So skip onto the next key.
|
// So skip onto the next key.
|
||||||
results[i].Error = fmt.Errorf(
|
results[i].Error = fmt.Errorf(
|
||||||
@ -203,7 +236,7 @@ func (k *KeyRing) checkUsingKeys(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := VerifyJSON(
|
if err := VerifyJSON(
|
||||||
string(requests[i].ServerName), keyID, ed25519.PublicKey(publicKey), requests[i].Message,
|
string(requests[i].ServerName), keyID, ed25519.PublicKey(serverKey.Key), requests[i].Message,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
// The signature wasn't valid, record the error and try the next key ID.
|
// The signature wasn't valid, record the error and try the next key ID.
|
||||||
results[i].Error = err
|
results[i].Error = err
|
||||||
@ -229,13 +262,15 @@ type PerspectiveKeyFetcher struct {
|
|||||||
// FetchKeys implements KeyFetcher
|
// FetchKeys implements KeyFetcher
|
||||||
func (p *PerspectiveKeyFetcher) FetchKeys(
|
func (p *PerspectiveKeyFetcher) FetchKeys(
|
||||||
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
) (map[PublicKeyRequest]PublicKeyLookupResult, error) {
|
||||||
results, err := p.Client.LookupServerKeys(ctx, p.PerspectiveServerName, requests)
|
serverKeys, err := p.Client.LookupServerKeys(ctx, p.PerspectiveServerName, requests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for req, keys := range results {
|
results := map[PublicKeyRequest]PublicKeyLookupResult{}
|
||||||
|
|
||||||
|
for _, keys := range serverKeys {
|
||||||
var valid bool
|
var valid bool
|
||||||
keyIDs, err := ListKeyIDs(string(p.PerspectiveServerName), keys.Raw)
|
keyIDs, err := ListKeyIDs(string(p.PerspectiveServerName), keys.Raw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -261,12 +296,16 @@ func (p *PerspectiveKeyFetcher) FetchKeys(
|
|||||||
return nil, fmt.Errorf("gomatrixserverlib: not signed with a known key for the perspective server")
|
return nil, fmt.Errorf("gomatrixserverlib: not signed with a known key for the perspective server")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the keys are valid for the server.
|
// Check that the keys are valid for the server they claim to be
|
||||||
checks, _, _ := CheckKeys(req.ServerName, time.Unix(0, 0), keys, nil)
|
checks, _, _ := CheckKeys(keys.ServerName, time.Unix(0, 0), keys, nil)
|
||||||
if !checks.AllChecksOK {
|
if !checks.AllChecksOK {
|
||||||
// This is bad because it means that the perspective server was trying to feed us an invalid response.
|
// This is bad because it means that the perspective server was trying to feed us an invalid response.
|
||||||
return nil, fmt.Errorf("gomatrixserverlib: key response from perspective server failed checks")
|
return nil, fmt.Errorf("gomatrixserverlib: key response from perspective server failed checks")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: What happens if the same key ID appears in multiple responses?
|
||||||
|
// We should probably take the response with the highest valid_until_ts.
|
||||||
|
mapServerKeysToPublicKeyLookupResult(keys, results)
|
||||||
}
|
}
|
||||||
|
|
||||||
return results, nil
|
return results, nil
|
||||||
@ -282,7 +321,7 @@ type DirectKeyFetcher struct {
|
|||||||
// FetchKeys implements KeyFetcher
|
// FetchKeys implements KeyFetcher
|
||||||
func (d *DirectKeyFetcher) FetchKeys(
|
func (d *DirectKeyFetcher) FetchKeys(
|
||||||
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
) (map[PublicKeyRequest]PublicKeyLookupResult, error) {
|
||||||
byServer := map[ServerName]map[PublicKeyRequest]Timestamp{}
|
byServer := map[ServerName]map[PublicKeyRequest]Timestamp{}
|
||||||
for req, ts := range requests {
|
for req, ts := range requests {
|
||||||
server := byServer[req.ServerName]
|
server := byServer[req.ServerName]
|
||||||
@ -293,7 +332,7 @@ func (d *DirectKeyFetcher) FetchKeys(
|
|||||||
server[req] = ts
|
server[req] = ts
|
||||||
}
|
}
|
||||||
|
|
||||||
results := map[PublicKeyRequest]ServerKeys{}
|
results := map[PublicKeyRequest]PublicKeyLookupResult{}
|
||||||
for server, reqs := range byServer {
|
for server, reqs := range byServer {
|
||||||
// TODO: make these requests in parallel
|
// TODO: make these requests in parallel
|
||||||
serverResults, err := d.fetchKeysForServer(ctx, server, reqs)
|
serverResults, err := d.fetchKeysForServer(ctx, server, reqs)
|
||||||
@ -310,19 +349,50 @@ func (d *DirectKeyFetcher) FetchKeys(
|
|||||||
|
|
||||||
func (d *DirectKeyFetcher) fetchKeysForServer(
|
func (d *DirectKeyFetcher) fetchKeysForServer(
|
||||||
ctx context.Context, serverName ServerName, requests map[PublicKeyRequest]Timestamp,
|
ctx context.Context, serverName ServerName, requests map[PublicKeyRequest]Timestamp,
|
||||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
) (map[PublicKeyRequest]PublicKeyLookupResult, error) {
|
||||||
results, err := d.Client.LookupServerKeys(ctx, serverName, requests)
|
serverKeys, err := d.Client.LookupServerKeys(ctx, serverName, requests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for req, keys := range results {
|
results := map[PublicKeyRequest]PublicKeyLookupResult{}
|
||||||
|
for _, keys := range serverKeys {
|
||||||
// Check that the keys are valid for the server.
|
// Check that the keys are valid for the server.
|
||||||
checks, _, _ := CheckKeys(req.ServerName, time.Unix(0, 0), keys, nil)
|
checks, _, _ := CheckKeys(serverName, time.Unix(0, 0), keys, nil)
|
||||||
if !checks.AllChecksOK {
|
if !checks.AllChecksOK {
|
||||||
return nil, fmt.Errorf("gomatrixserverlib: key response direct from %q failed checks", serverName)
|
return nil, fmt.Errorf("gomatrixserverlib: key response direct from %q failed checks", serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: What happens if the same key ID appears in multiple responses?
|
||||||
|
// We should probably take the response with the highest valid_until_ts.
|
||||||
|
mapServerKeysToPublicKeyLookupResult(keys, results)
|
||||||
}
|
}
|
||||||
|
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mapServerKeysToPublicKeyLookupResult takes the (verified) result from a
|
||||||
|
// /key/v2/query call and inserts it into a PublicKeyRequest->PublicKeyLookupResult
|
||||||
|
// map.
|
||||||
|
func mapServerKeysToPublicKeyLookupResult(serverKeys ServerKeys, results map[PublicKeyRequest]PublicKeyLookupResult) {
|
||||||
|
for keyID, key := range serverKeys.VerifyKeys {
|
||||||
|
results[PublicKeyRequest{
|
||||||
|
ServerName: serverKeys.ServerName,
|
||||||
|
KeyID: keyID,
|
||||||
|
}] = PublicKeyLookupResult{
|
||||||
|
VerifyKey: key,
|
||||||
|
ValidUntilTS: serverKeys.ValidUntilTS,
|
||||||
|
ExpiredTS: PublicKeyNotExpired,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for keyID, key := range serverKeys.OldVerifyKeys {
|
||||||
|
results[PublicKeyRequest{
|
||||||
|
ServerName: serverKeys.ServerName,
|
||||||
|
KeyID: keyID,
|
||||||
|
}] = PublicKeyLookupResult{
|
||||||
|
VerifyKey: key.VerifyKey,
|
||||||
|
ValidUntilTS: PublicKeyNotValid,
|
||||||
|
ExpiredTS: key.ExpiredTS,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,7 +2,6 @@ package gomatrixserverlib
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,26 +38,44 @@ type testKeyDatabase struct{}
|
|||||||
|
|
||||||
func (db *testKeyDatabase) FetchKeys(
|
func (db *testKeyDatabase) FetchKeys(
|
||||||
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
) (map[PublicKeyRequest]PublicKeyLookupResult, error) {
|
||||||
results := map[PublicKeyRequest]ServerKeys{}
|
results := map[PublicKeyRequest]PublicKeyLookupResult{}
|
||||||
var keys ServerKeys
|
|
||||||
if err := json.Unmarshal([]byte(testKeys), &keys); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req1 := PublicKeyRequest{"localhost:8800", "ed25519:old"}
|
req1 := PublicKeyRequest{"localhost:8800", "ed25519:old"}
|
||||||
req2 := PublicKeyRequest{"localhost:8800", "ed25519:a_Obwu"}
|
req2 := PublicKeyRequest{"localhost:8800", "ed25519:a_Obwu"}
|
||||||
|
|
||||||
for req := range requests {
|
for req := range requests {
|
||||||
if req == req1 || req == req2 {
|
if req == req1 {
|
||||||
results[req] = keys
|
vk := VerifyKey{}
|
||||||
|
err := vk.Key.Decode("O2onvM62pC1io6jQKm8Nc2UyFXcd4kOmOsBIoYtZ2ik")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results[req] = PublicKeyLookupResult{
|
||||||
|
VerifyKey: vk,
|
||||||
|
ValidUntilTS: PublicKeyNotValid,
|
||||||
|
ExpiredTS: 929059200,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req == req2 {
|
||||||
|
vk := VerifyKey{}
|
||||||
|
err := vk.Key.Decode("2UwTWD4+tgTgENV7znGGNqhAOGY+BW1mRAnC6W6FBQg")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results[req] = PublicKeyLookupResult{
|
||||||
|
VerifyKey: vk,
|
||||||
|
ValidUntilTS: 1493142432964,
|
||||||
|
ExpiredTS: PublicKeyNotExpired,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *testKeyDatabase) StoreKeys(
|
func (db *testKeyDatabase) StoreKeys(
|
||||||
ctx context.Context, requests map[PublicKeyRequest]ServerKeys,
|
ctx context.Context, requests map[PublicKeyRequest]PublicKeyLookupResult,
|
||||||
) error {
|
) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -136,12 +153,12 @@ var testErrorStore = erroringKeyDatabaseError(2)
|
|||||||
|
|
||||||
func (e *erroringKeyDatabase) FetchKeys(
|
func (e *erroringKeyDatabase) FetchKeys(
|
||||||
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
|
||||||
) (map[PublicKeyRequest]ServerKeys, error) {
|
) (map[PublicKeyRequest]PublicKeyLookupResult, error) {
|
||||||
return nil, &testErrorFetch
|
return nil, &testErrorFetch
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *erroringKeyDatabase) StoreKeys(
|
func (e *erroringKeyDatabase) StoreKeys(
|
||||||
ctx context.Context, keys map[PublicKeyRequest]ServerKeys,
|
ctx context.Context, keys map[PublicKeyRequest]PublicKeyLookupResult,
|
||||||
) error {
|
) error {
|
||||||
return &testErrorStore
|
return &testErrorStore
|
||||||
}
|
}
|
||||||
|
@ -37,8 +37,6 @@ type ServerName string
|
|||||||
type ServerKeys struct {
|
type ServerKeys struct {
|
||||||
// Copy of the raw JSON for signature checking.
|
// Copy of the raw JSON for signature checking.
|
||||||
Raw []byte
|
Raw []byte
|
||||||
// The server the raw JSON was downloaded from.
|
|
||||||
FromServer ServerName
|
|
||||||
// The decoded JSON fields.
|
// The decoded JSON fields.
|
||||||
ServerKeyFields
|
ServerKeyFields
|
||||||
}
|
}
|
||||||
@ -140,7 +138,6 @@ func FetchKeysDirect(serverName ServerName, addr, sni string) (*ServerKeys, *tls
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
var keys ServerKeys
|
var keys ServerKeys
|
||||||
keys.FromServer = serverName
|
|
||||||
if err = json.NewDecoder(response.Body).Decode(&keys); err != nil {
|
if err = json.NewDecoder(response.Body).Decode(&keys); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@
|
|||||||
"golint",
|
"golint",
|
||||||
"varcheck",
|
"varcheck",
|
||||||
"structcheck",
|
"structcheck",
|
||||||
"aligncheck",
|
"maligned",
|
||||||
"ineffassign",
|
"ineffassign",
|
||||||
"gas",
|
"gas",
|
||||||
"misspell",
|
"misspell",
|
||||||
|
Loading…
Reference in New Issue
Block a user