dendrite/syncapi/storage/sqlite3/stream_id_table.go
Kegsay ecd7accbad
Reshuffle where things are in the internal package (#1122)
renamed:    internal/eventcontent.go -> internal/eventutil/eventcontent.go
	renamed:    internal/events.go -> internal/eventutil/events.go
	renamed:    internal/types.go -> internal/eventutil/types.go
	renamed:    internal/http/http.go -> internal/httputil/http.go
	renamed:    internal/httpapi.go -> internal/httputil/httpapi.go
	renamed:    internal/httpapi_test.go -> internal/httputil/httpapi_test.go
	renamed:    internal/httpapis/paths.go -> internal/httputil/paths.go
	renamed:    internal/routing.go -> internal/httputil/routing.go
	renamed:    internal/basecomponent/base.go -> internal/setup/base.go
	renamed:    internal/basecomponent/flags.go -> internal/setup/flags.go
	renamed:    internal/partition_offset_table.go -> internal/sqlutil/partition_offset_table.go
	renamed:    internal/postgres.go -> internal/sqlutil/postgres.go
	renamed:    internal/postgres_wasm.go -> internal/sqlutil/postgres_wasm.go
	renamed:    internal/sql.go -> internal/sqlutil/sql.go
2020-06-12 14:55:57 +01:00

59 lines
1.5 KiB
Go

package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/types"
)
// streamIDTableSchema creates the syncapi_stream_id table when it does not
// exist yet and seeds the 'global' stream row with a position of 0.
//
// The seed INSERT uses ON CONFLICT DO NOTHING, so running the schema again
// leaves an already-present row (and its current stream_id) untouched.
//
// The stream name is written as a single-quoted SQL string literal: a
// double-quoted token is an identifier in SQL, and SQLite only falls back to
// treating it as a string when built with that legacy behaviour enabled.
const streamIDTableSchema = `
-- Global stream ID counter, used by other tables.
CREATE TABLE IF NOT EXISTS syncapi_stream_id (
stream_name TEXT NOT NULL PRIMARY KEY,
stream_id INT DEFAULT 0,
UNIQUE(stream_name)
);
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ('global', 0)
ON CONFLICT DO NOTHING;
`
// SQL statements operating on the syncapi_stream_id table. Each one takes the
// stream name as its single bound parameter ($1).
const (
	// increaseStreamIDStmt adds one to the stored stream_id of the named stream.
	increaseStreamIDStmt = "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1"

	// selectStreamIDStmt reads the stored stream_id of the named stream.
	selectStreamIDStmt = "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
)
// streamIDStatements holds the prepared statements used to advance and read
// the counter stored in the syncapi_stream_id table. Both fields are nil until
// prepare has filled them in.
type streamIDStatements struct {
	// increaseStreamIDStmt is the prepared form of the increaseStreamIDStmt query.
	increaseStreamIDStmt *sql.Stmt
	// selectStreamIDStmt is the prepared form of the selectStreamIDStmt query.
	selectStreamIDStmt *sql.Stmt
}
// prepare executes the table schema against db and then prepares the two
// statements this type needs.
//
// The receiver's statement fields are only assigned once both statements have
// been prepared successfully. If preparing the second statement fails, the
// first one is closed before returning so that it is not leaked.
//
// It returns the first error encountered, unmodified, or nil on success.
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
	if _, err = db.Exec(streamIDTableSchema); err != nil {
		return err
	}

	increaseStmt, err := db.Prepare(increaseStreamIDStmt)
	if err != nil {
		return err
	}

	selectStmt, err := db.Prepare(selectStreamIDStmt)
	if err != nil {
		// The Close error is deliberately dropped: the Prepare failure is the
		// error the caller needs to see.
		_ = increaseStmt.Close()
		return err
	}

	s.increaseStreamIDStmt = increaseStmt
	s.selectStreamIDStmt = selectStmt
	return nil
}
// nextStreamID increments the stored position of the "global" stream and then
// reads the resulting value back.
//
// Both prepared statements are passed through sqlutil.TxStmt together with txn
// before being executed. An error from either the update or the read is
// returned as-is.
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
	const streamName = "global"

	bump := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
	read := sqlutil.TxStmt(txn, s.selectStreamIDStmt)

	if _, err = bump.ExecContext(ctx, streamName); err != nil {
		return pos, err
	}

	err = read.QueryRowContext(ctx, streamName).Scan(&pos)
	return pos, err
}