diff --git a/internal/caching/cache_serverkeys.go b/internal/caching/cache_serverkeys.go index 8c71ffbd..b5e31575 100644 --- a/internal/caching/cache_serverkeys.go +++ b/internal/caching/cache_serverkeys.go @@ -2,6 +2,7 @@ package caching import ( "fmt" + "time" "github.com/matrix-org/gomatrixserverlib" ) @@ -23,9 +24,17 @@ func (c Caches) GetServerKey( request gomatrixserverlib.PublicKeyLookupRequest, ) (gomatrixserverlib.PublicKeyLookupResult, bool) { key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID) + now := gomatrixserverlib.AsTimestamp(time.Now()) val, found := c.ServerKeys.Get(key) if found && val != nil { if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok { + if !keyLookupResult.WasValidAt(now, true) { + // We appear to be past the key validity so don't return this + // with the results. This ensures that the cache doesn't return + // values that are not useful to us. + c.ServerKeys.Unset(key) + return gomatrixserverlib.PublicKeyLookupResult{}, false + } return keyLookupResult, true } } diff --git a/internal/caching/caches.go b/internal/caching/caches.go index 70f380ba..419623e2 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -12,4 +12,5 @@ type Caches struct { type Cache interface { Get(key string) (value interface{}, ok bool) Set(key string, value interface{}) + Unset(key string) } diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index f7901d2e..158deca4 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -68,6 +68,13 @@ func (c *InMemoryLRUCachePartition) Set(key string, value interface{}) { c.lru.Add(key, value) } +func (c *InMemoryLRUCachePartition) Unset(key string) { + if !c.mutable { + panic(fmt.Sprintf("invalid use of immutable cache tries to unset value of %q", key)) + } + c.lru.Remove(key) +} + func (c *InMemoryLRUCachePartition) Get(key string) (value interface{}, ok bool) { return c.lru.Get(key) } diff --git a/serverkeyapi/internal/api.go b/serverkeyapi/internal/api.go index 92d6a70b..7a35aa8e 100644 --- a/serverkeyapi/internal/api.go +++ b/serverkeyapi/internal/api.go @@ -22,7 +22,7 @@ func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing { // and keeping the cache up-to-date. return &gomatrixserverlib.KeyRing{ KeyDatabase: s, - KeyFetchers: []gomatrixserverlib.KeyFetcher{s}, + KeyFetchers: []gomatrixserverlib.KeyFetcher{}, } } @@ -45,15 +45,17 @@ func (s *ServerKeyAPI) FetchKeys( // because the caller gives up waiting. ctx := context.Background() results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} + now := gomatrixserverlib.AsTimestamp(time.Now()) // First consult our local database and see if we have the requested // keys. These might come from a cache, depending on the database // implementation used. - now := gomatrixserverlib.AsTimestamp(time.Now()) if dbResults, err := s.OurKeyRing.KeyDatabase.FetchKeys(ctx, requests); err == nil { // We successfully got some keys. Add them to the results and // remove them from the request list. for req, res := range dbResults { - if now > res.ValidUntilTS && res.ExpiredTS == gomatrixserverlib.PublicKeyNotExpired { + if !res.WasValidAt(now, true) { + // We appear to be past the key validity. Don't return this + // key with the results. continue } results[req] = res @@ -71,6 +73,11 @@ func (s *ServerKeyAPI) FetchKeys( // We successfully got some keys. Add them to the results and // remove them from the request list. for req, res := range fetcherResults { + if !res.WasValidAt(now, true) { + // We appear to be past the key validity. Don't return this + // key with the results. + continue + } results[req] = res delete(requests, req) } diff --git a/serverkeyapi/inthttp/client.go b/serverkeyapi/inthttp/client.go index f22b0e31..2587160d 100644 --- a/serverkeyapi/inthttp/client.go +++ b/serverkeyapi/inthttp/client.go @@ -4,7 +4,6 @@ import ( "context" "errors" "net/http" - "time" "github.com/matrix-org/dendrite/internal/caching" internalHTTP "github.com/matrix-org/dendrite/internal/http" @@ -50,7 +49,7 @@ func (s *httpServerKeyInternalAPI) KeyRing() *gomatrixserverlib.KeyRing { // the other end of the API. return &gomatrixserverlib.KeyRing{ KeyDatabase: s, - KeyFetchers: []gomatrixserverlib.KeyFetcher{s}, + KeyFetchers: []gomatrixserverlib.KeyFetcher{}, } } @@ -90,12 +89,8 @@ func (s *httpServerKeyInternalAPI) FetchKeys( response := api.QueryPublicKeysResponse{ Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult), } - now := gomatrixserverlib.AsTimestamp(time.Now()) for req, ts := range requests { if res, ok := s.cache.GetServerKey(req); ok { - if now > res.ValidUntilTS && res.ExpiredTS == gomatrixserverlib.PublicKeyNotExpired { - continue - } result[req] = res continue } diff --git a/serverkeyapi/inthttp/server.go b/serverkeyapi/inthttp/server.go index 9efe7d9d..fd4b72c7 100644 --- a/serverkeyapi/inthttp/server.go +++ b/serverkeyapi/inthttp/server.go @@ -8,7 +8,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/serverkeyapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -35,12 +34,7 @@ func AddRoutes(s api.ServerKeyInternalAPI, internalAPIMux *mux.Router, cache cac if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - store := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) - for req, res := range request.Keys { - store[req] = res - cache.StoreServerKey(req, res) - } - if err := s.StoreKeys(req.Context(), store); err != nil { + if err := s.StoreKeys(req.Context(), request.Keys); err != nil { return util.ErrorResponse(err) } return util.JSONResponse{Code: http.StatusOK, JSON: &response}