From e216c2fbf0fd117ddb8b96b05d514b9987cbb0d2 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 21 Jul 2023 08:34:01 +0200 Subject: [PATCH] Update ConnectionManager to still allow component defined connections (#3154) --- internal/sqlutil/connection_manager.go | 69 ++++++++++++--------- internal/sqlutil/connection_manager_test.go | 22 +++++++ 2 files changed, 61 insertions(+), 30 deletions(-) diff --git a/internal/sqlutil/connection_manager.go b/internal/sqlutil/connection_manager.go index 4933cfaf..437da6c8 100644 --- a/internal/sqlutil/connection_manager.go +++ b/internal/sqlutil/connection_manager.go @@ -17,16 +17,21 @@ package sqlutil import ( "database/sql" "fmt" + "sync" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" ) type Connections struct { - db *sql.DB - writer Writer - globalConfig config.DatabaseOptions - processContext *process.ProcessContext + globalConfig config.DatabaseOptions + processContext *process.ProcessContext + existingConnections sync.Map +} + +type con struct { + db *sql.DB + writer Writer } func NewConnectionManager(processCtx *process.ProcessContext, globalConfig config.DatabaseOptions) *Connections { @@ -38,9 +43,13 @@ func NewConnectionManager(processCtx *process.ProcessContext, globalConfig confi func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB, Writer, error) { var err error + // If no connectionString was provided, try the global one if dbProperties.ConnectionString == "" { - // if no connectionString was provided, try the global one dbProperties = &c.globalConfig + // If we still don't have a connection string, that's a problem + if dbProperties.ConnectionString == "" { + return nil, nil, fmt.Errorf("no database connections configured") + } } writer := NewDummyWriter() @@ -48,30 +57,30 @@ func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB, writer = NewExclusiveWriter() } - if dbProperties.ConnectionString != "" && c.db == nil { - // Open a new database connection using the supplied config. - c.db, err = Open(dbProperties, writer) - if err != nil { - return nil, nil, err + existing, loaded := c.existingConnections.LoadOrStore(dbProperties.ConnectionString, &con{}) + if loaded { + // We found an existing connection + ex := existing.(*con) + return ex.db, ex.writer, nil + } + + // Open a new database connection using the supplied config. + db, err := Open(dbProperties, writer) + if err != nil { + return nil, nil, err + } + c.existingConnections.Store(dbProperties.ConnectionString, &con{db: db, writer: writer}) + go func() { + if c.processContext == nil { + return } - c.writer = writer - go func() { - if c.processContext == nil { - return - } - // If we have a ProcessContext, start a component and wait for - // Dendrite to shut down to cleanly close the database connection. - c.processContext.ComponentStarted() - <-c.processContext.WaitForShutdown() - _ = c.db.Close() - c.processContext.ComponentFinished() - }() - return c.db, c.writer, nil - } - if c.db != nil && c.writer != nil { - // Ignore the supplied config and return the global pool and - // writer. - return c.db, c.writer, nil - } - return nil, nil, fmt.Errorf("no database connections configured") + // If we have a ProcessContext, start a component and wait for + // Dendrite to shut down to cleanly close the database connection. + c.processContext.ComponentStarted() + <-c.processContext.WaitForShutdown() + _ = db.Close() + c.processContext.ComponentFinished() + }() + return db, writer, nil + } diff --git a/internal/sqlutil/connection_manager_test.go b/internal/sqlutil/connection_manager_test.go index 965d3b9b..5086684b 100644 --- a/internal/sqlutil/connection_manager_test.go +++ b/internal/sqlutil/connection_manager_test.go @@ -48,6 +48,22 @@ func TestConnectionManager(t *testing.T) { if !reflect.DeepEqual(writer, writer2) { t.Fatalf("expected database writer to be reused") } + + // This test does not work with Postgres, because we can't just simply append + // "x" or replace the database to use. + if dbType == test.DBTypePostgres { + return + } + + // Test different connection string + dbProps = &config.DatabaseOptions{ConnectionString: config.DataSource(conStr + "x")} + db3, _, err := cm.Connection(dbProps) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(db, db3) { + t.Fatalf("expected different database connection") + } }) }) @@ -115,4 +131,10 @@ func TestConnectionManager(t *testing.T) { if err == nil { t.Fatal("expected an error but got none") } + + // empty connection string is not allowed + _, _, err = cm2.Connection(&config.DatabaseOptions{}) + if err == nil { + t.Fatal("expected an error but got none") + } }