From 4ccf6d6f67e7029a0c4f074aca93b0bcf267e5b6 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 28 Feb 2024 20:58:56 +0100 Subject: [PATCH] Cache ACLs regexes (#3336) Since #3334 didn't change much on d.m.org, this is another attempt to speed up startup. Given moderation bots like Mjolnir/Draupnir are in many rooms with quite often the same or similar ACLs, caching the compiled regexes _should_ reduce the startup time. Using a pointer to the `*regexp.Regex` ensures we only store _one_ instance of a regex in memory, instead of potentially storing it hundred of times. This should reduce memory consumption on servers with many rooms with ACLs drastically. (5.1MB vs 1.7MB with this change on my server with 8 ACL'd rooms [3 using the same ACLs]) [skip ci] --- roomserver/acls/acls.go | 47 ++++++++++++++++++++++++------ roomserver/acls/acls_test.go | 56 +++++++++++++++++++++++++++++++++--- 2 files changed, 91 insertions(+), 12 deletions(-) diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index 660f4f3b..017682e0 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -41,15 +41,21 @@ type ServerACLDatabase interface { } type ServerACLs struct { - acls map[string]*serverACL // room ID -> ACL - aclsMutex sync.RWMutex // protects the above + acls map[string]*serverACL // room ID -> ACL + aclsMutex sync.RWMutex // protects the above + aclRegexCache map[string]**regexp.Regexp // Cache from "serverName" -> pointer to a regex + aclRegexCacheMutex sync.RWMutex // protects the above } func NewServerACLs(db ServerACLDatabase) *ServerACLs { ctx := context.TODO() acls := &ServerACLs{ acls: make(map[string]*serverACL), + // Be generous when creating the cache, as in reality + // there are hundreds of servers in an ACL. + aclRegexCache: make(map[string]**regexp.Regexp, 100), } + // Look up all of the rooms that the current state server knows about. rooms, err := db.GetKnownRooms(ctx) if err != nil { @@ -67,6 +73,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { for _, event := range events { acls.OnServerACLUpdate(event) } + return acls } @@ -78,8 +85,8 @@ type ServerACL struct { type serverACL struct { ServerACL - allowedRegexes []*regexp.Regexp - deniedRegexes []*regexp.Regexp + allowedRegexes []**regexp.Regexp + deniedRegexes []**regexp.Regexp } func compileACLRegex(orig string) (*regexp.Regexp, error) { @@ -89,6 +96,25 @@ func compileACLRegex(orig string) (*regexp.Regexp, error) { return regexp.Compile(escaped) } +// cachedCompileACLRegex is a wrapper around compileACLRegex with added caching +func (s *ServerACLs) cachedCompileACLRegex(orig string) (**regexp.Regexp, error) { + s.aclRegexCacheMutex.RLock() + re, ok := s.aclRegexCache[orig] + if ok { + s.aclRegexCacheMutex.RUnlock() + return re, nil + } + s.aclRegexCacheMutex.RUnlock() + compiled, err := compileACLRegex(orig) + if err != nil { + return nil, err + } + s.aclRegexCacheMutex.Lock() + defer s.aclRegexCacheMutex.Unlock() + s.aclRegexCache[orig] = &compiled + return &compiled, nil +} + func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) { acls := &serverACL{} if err := json.Unmarshal([]byte(strippedEvent.ContentValue), &acls.ServerACL); err != nil { @@ -100,14 +126,14 @@ func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) { // special characters and then replace * and ? with their regex counterparts. // https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl for _, orig := range acls.Allowed { - if expr, err := compileACLRegex(orig); err != nil { + if expr, err := s.cachedCompileACLRegex(orig); err != nil { logrus.WithError(err).Errorf("Failed to compile allowed regex") } else { acls.allowedRegexes = append(acls.allowedRegexes, expr) } } for _, orig := range acls.Denied { - if expr, err := compileACLRegex(orig); err != nil { + if expr, err := s.cachedCompileACLRegex(orig); err != nil { logrus.WithError(err).Errorf("Failed to compile denied regex") } else { acls.deniedRegexes = append(acls.deniedRegexes, expr) @@ -118,6 +144,11 @@ func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) { "num_allowed": len(acls.allowedRegexes), "num_denied": len(acls.deniedRegexes), }).Debugf("Updating server ACLs for %q", strippedEvent.RoomID) + + // Clear out Denied and Allowed, now that we have the compiled regexes. + // They are not needed anymore from this point on. + acls.Denied = nil + acls.Allowed = nil s.aclsMutex.Lock() defer s.aclsMutex.Unlock() s.acls[strippedEvent.RoomID] = acls @@ -150,14 +181,14 @@ func (s *ServerACLs) IsServerBannedFromRoom(serverName spec.ServerName, roomID s // Check if the hostname matches one of the denied regexes. If it does then // the server is banned from the room. for _, expr := range acls.deniedRegexes { - if expr.MatchString(string(serverName)) { + if (*expr).MatchString(string(serverName)) { return true } } // Check if the hostname matches one of the allowed regexes. If it does then // the server is NOT banned from the room. for _, expr := range acls.allowedRegexes { - if expr.MatchString(string(serverName)) { + if (*expr).MatchString(string(serverName)) { return false } } diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go index 9fb6a558..efe1d209 100644 --- a/roomserver/acls/acls_test.go +++ b/roomserver/acls/acls_test.go @@ -15,8 +15,14 @@ package acls import ( + "context" "regexp" "testing" + + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/stretchr/testify/assert" ) func TestOpenACLsWithBlacklist(t *testing.T) { @@ -38,8 +44,8 @@ func TestOpenACLsWithBlacklist(t *testing.T) { ServerACL: ServerACL{ AllowIPLiterals: true, }, - allowedRegexes: []*regexp.Regexp{allowRegex}, - deniedRegexes: []*regexp.Regexp{denyRegex}, + allowedRegexes: []**regexp.Regexp{&allowRegex}, + deniedRegexes: []**regexp.Regexp{&denyRegex}, } if acls.IsServerBannedFromRoom("1.2.3.4", roomID) { @@ -77,8 +83,8 @@ func TestDefaultACLsWithWhitelist(t *testing.T) { ServerACL: ServerACL{ AllowIPLiterals: false, }, - allowedRegexes: []*regexp.Regexp{allowRegex}, - deniedRegexes: []*regexp.Regexp{}, + allowedRegexes: []**regexp.Regexp{&allowRegex}, + deniedRegexes: []**regexp.Regexp{}, } if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) { @@ -103,3 +109,45 @@ func TestDefaultACLsWithWhitelist(t *testing.T) { t.Fatal("Expected qux.com:4567 to be allowed but wasn't") } } + +var ( + content1 = `{"allow":["*"],"allow_ip_literals":false,"deny":["hello.world", "*.hello.world"]}` +) + +type dummyACLDB struct{} + +func (d dummyACLDB) GetKnownRooms(ctx context.Context) ([]string, error) { + return []string{"1", "2"}, nil +} + +func (d dummyACLDB) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { + return []tables.StrippedEvent{ + { + RoomID: "1", + ContentValue: content1, + }, + { + RoomID: "2", + ContentValue: content1, + }, + }, nil +} + +func TestCachedRegex(t *testing.T) { + db := dummyACLDB{} + wantBannedServer := spec.ServerName("hello.world") + + acls := NewServerACLs(db) + + // Check that hello.world is banned in room 1 + banned := acls.IsServerBannedFromRoom(wantBannedServer, "1") + assert.True(t, banned) + + // Check that hello.world is banned in room 2 + banned = acls.IsServerBannedFromRoom(wantBannedServer, "2") + assert.True(t, banned) + + // Check that matrix.hello.world is banned in room 2 + banned = acls.IsServerBannedFromRoom("matrix."+wantBannedServer, "2") + assert.True(t, banned) +}