SelectJoinedHostsForRooms should use QueryVariadic on SQLite (#1238)

* SelectJoinedHostsForRooms should use QueryVariadic on SQLite

* Fix strings.Replace

* Fix statement
This commit is contained in:
Neil Alexander 2020-08-05 10:00:35 +01:00 committed by GitHub
parent 2197e54441
commit 22f028e141
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,6 +18,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -63,13 +64,13 @@ const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter writer *sqlutil.TransactionWriter
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -93,9 +94,6 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error)
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
return return
} }
if s.selectJoinedHostsForRoomsStmt, err = db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
return
}
return return
} }
@ -168,7 +166,8 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
iRoomIDs[i] = roomIDs[i] iRoomIDs[i] = roomIDs[i]
} }
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, iRoomIDs...) sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }