Use a Postgres database rather than Memory for Naffka (#337)

* Update naffka dep

* User Postgres database rather than Memory for Naffka
This commit is contained in:
Erik Johnston 2017-11-16 17:35:28 +00:00 committed by GitHub
parent bdc44c4bde
commit 8599a36fa6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 567 additions and 94 deletions

View File

@ -16,6 +16,7 @@ package main
import ( import (
"context" "context"
"database/sql"
"flag" "flag"
"net/http" "net/http"
"os" "os"
@ -199,7 +200,21 @@ func (m *monolith) setupFederation() {
func (m *monolith) setupKafka() { func (m *monolith) setupKafka() {
if m.cfg.Kafka.UseNaffka { if m.cfg.Kafka.UseNaffka {
naff, err := naffka.New(&naffka.MemoryDatabase{}) db, err := sql.Open("postgres", string(m.cfg.Database.Naffka))
if err != nil {
log.WithFields(log.Fields{
log.ErrorKey: err,
}).Panic("Failed to open naffka database")
}
naffkaDB, err := naffka.NewPostgresqlDatabase(db)
if err != nil {
log.WithFields(log.Fields{
log.ErrorKey: err,
}).Panic("Failed to setup naffka database")
}
naff, err := naffka.New(naffkaDB)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
log.ErrorKey: err, log.ErrorKey: err,

View File

@ -148,6 +148,8 @@ type Dendrite struct {
// The PublicRoomsAPI database stores information used to compute the public // The PublicRoomsAPI database stores information used to compute the public
// room directory. It is only accessed by the PublicRoomsAPI server. // room directory. It is only accessed by the PublicRoomsAPI server.
PublicRoomsAPI DataSource `yaml:"public_rooms_api"` PublicRoomsAPI DataSource `yaml:"public_rooms_api"`
// The Naffka database is used internally by the naffka library, if used.
Naffka DataSource `yaml:"naffka,omitempty"`
} `yaml:"database"` } `yaml:"database"`
// TURN Server Config // TURN Server Config
@ -386,6 +388,8 @@ func (config *Dendrite) check(monolithic bool) error {
if !monolithic { if !monolithic {
problems = append(problems, fmt.Sprintf("naffka can only be used in a monolithic server")) problems = append(problems, fmt.Sprintf("naffka can only be used in a monolithic server"))
} }
checkNotEmpty("database.naffka", string(config.Database.Naffka))
} else { } else {
// If we aren't using naffka then we need to have at least one kafka // If we aren't using naffka then we need to have at least one kafka
// server to talk to. // server to talk to.

2
vendor/manifest vendored
View File

@ -141,7 +141,7 @@
{ {
"importpath": "github.com/matrix-org/naffka", "importpath": "github.com/matrix-org/naffka",
"repository": "https://github.com/matrix-org/naffka", "repository": "https://github.com/matrix-org/naffka",
"revision": "d28656e34f96a8eeaab53e3b7678c9ce14af5786", "revision": "662bfd0841d0194bfe0a700d54226bb96eac574d",
"branch": "master" "branch": "master"
}, },
{ {

View File

@ -8,7 +8,8 @@ import (
// A MemoryDatabase stores the message history as arrays in memory. // A MemoryDatabase stores the message history as arrays in memory.
// It can be used to run unit tests. // It can be used to run unit tests.
// If the process is stopped then any messages that haven't been // If the process is stopped then any messages that haven't been
// processed by a consumer are lost forever. // processed by a consumer are lost forever and all offsets become
// invalid.
type MemoryDatabase struct { type MemoryDatabase struct {
topicsMutex sync.Mutex topicsMutex sync.Mutex
topics map[string]*memoryDatabaseTopic topics map[string]*memoryDatabaseTopic
@ -58,10 +59,7 @@ func (m *MemoryDatabase) getTopic(topicName string) *memoryDatabaseTopic {
// StoreMessages implements Database // StoreMessages implements Database
func (m *MemoryDatabase) StoreMessages(topic string, messages []Message) error { func (m *MemoryDatabase) StoreMessages(topic string, messages []Message) error {
if err := m.getTopic(topic).addMessages(messages); err != nil { return m.getTopic(topic).addMessages(messages)
return err
}
return nil
} }
// FetchMessages implements Database // FetchMessages implements Database
@ -73,10 +71,10 @@ func (m *MemoryDatabase) FetchMessages(topic string, startOffset, endOffset int6
if startOffset >= endOffset { if startOffset >= endOffset {
return nil, fmt.Errorf("start offset %d greater than or equal to end offset %d", startOffset, endOffset) return nil, fmt.Errorf("start offset %d greater than or equal to end offset %d", startOffset, endOffset)
} }
if startOffset < -1 { if startOffset < 0 {
return nil, fmt.Errorf("start offset %d less than -1", startOffset) return nil, fmt.Errorf("start offset %d less than 0", startOffset)
} }
return messages[startOffset+1 : endOffset], nil return messages[startOffset:endOffset], nil
} }
// MaxOffsets implements Database // MaxOffsets implements Database

View File

@ -13,6 +13,7 @@ import (
// single go process. It implements both the sarama.SyncProducer and the // single go process. It implements both the sarama.SyncProducer and the
// sarama.Consumer interfaces. This means it can act as a drop in replacement // sarama.Consumer interfaces. This means it can act as a drop in replacement
// for kafka for testing or single instance deployment. // for kafka for testing or single instance deployment.
// Does not support multiple partitions.
type Naffka struct { type Naffka struct {
db Database db Database
topicsMutex sync.Mutex topicsMutex sync.Mutex
@ -28,6 +29,7 @@ func New(db Database) (*Naffka, error) {
} }
for topicName, offset := range maxOffsets { for topicName, offset := range maxOffsets {
n.topics[topicName] = &topic{ n.topics[topicName] = &topic{
db: db,
topicName: topicName, topicName: topicName,
nextOffset: offset + 1, nextOffset: offset + 1,
} }
@ -64,7 +66,7 @@ type Database interface {
// So for a given topic the message with offset n+1 is stored after the // So for a given topic the message with offset n+1 is stored after the
// the message with offset n. // the message with offset n.
StoreMessages(topic string, messages []Message) error StoreMessages(topic string, messages []Message) error
// FetchMessages fetches all messages with an offset greater than but not // FetchMessages fetches all messages with an offset greater than and
// including startOffset and less than but not including endOffset. // including startOffset and less than but not including endOffset.
// The range of offsets requested must not overlap with those stored by a // The range of offsets requested must not overlap with those stored by a
// concurrent StoreMessages. The message offsets within the requested range // concurrent StoreMessages. The message offsets within the requested range
@ -138,6 +140,7 @@ func (n *Naffka) Partitions(topic string) ([]int32, error) {
} }
// ConsumePartition implements sarama.Consumer // ConsumePartition implements sarama.Consumer
// Note: offset is *inclusive*, i.e. it will include the message with that offset.
func (n *Naffka) ConsumePartition(topic string, partition int32, offset int64) (sarama.PartitionConsumer, error) { func (n *Naffka) ConsumePartition(topic string, partition int32, offset int64) (sarama.PartitionConsumer, error) {
if partition != 0 { if partition != 0 {
return nil, fmt.Errorf("Unknown partition ID %d", partition) return nil, fmt.Errorf("Unknown partition ID %d", partition)
@ -166,13 +169,16 @@ func (n *Naffka) Close() error {
const channelSize = 1024 const channelSize = 1024
// partitionConsumer ensures that all messages written to a particular
// topic, from an offset, get sent in order to a channel.
// Implements sarama.PartitionConsumer
type partitionConsumer struct { type partitionConsumer struct {
topic *topic topic *topic
messages chan *sarama.ConsumerMessage messages chan *sarama.ConsumerMessage
// Whether the consumer is ready for new messages or whether it // Whether the consumer is in "catchup" mode or not.
// is catching up on historic messages. // See "catchup" function for details.
// Reads and writes to this field are proctected by the topic mutex. // Reads and writes to this field are proctected by the topic mutex.
ready bool catchingUp bool
} }
// AsyncClose implements sarama.PartitionConsumer // AsyncClose implements sarama.PartitionConsumer
@ -201,66 +207,101 @@ func (c *partitionConsumer) HighWaterMarkOffset() int64 {
return c.topic.highwaterMark() return c.topic.highwaterMark()
} }
// block writes the message to the consumer blocking until the consumer is ready // catchup makes the consumer go into "catchup" mode, where messages are read
// to add the message to the channel. Once the message is successfully added to // from the database instead of directly from producers.
// the channel it will catch up by pulling historic messsages from the database. // Once the consumer is up to date, i.e. no new messages in the database, then
func (c *partitionConsumer) block(cmsg *sarama.ConsumerMessage) { // the consumer will go back into normal mode where new messages are written
c.messages <- cmsg // directly to the channel.
c.catchup(cmsg.Offset) // Must be called with the c.topic.mutex lock
func (c *partitionConsumer) catchup(fromOffset int64) {
// If we're already in catchup mode or up to date, noop
if c.catchingUp || fromOffset == c.topic.nextOffset {
return
}
c.catchingUp = true
// Due to the checks above there can only be one of these goroutines
// running at a time
go func() {
for {
// Check if we're up to date yet. If we are we exit catchup mode.
c.topic.mutex.Lock()
nextOffset := c.topic.nextOffset
if fromOffset == nextOffset {
c.catchingUp = false
c.topic.mutex.Unlock()
return
}
c.topic.mutex.Unlock()
// Limit the number of messages we request from the database to be the
// capacity of the channel.
if nextOffset > fromOffset+int64(cap(c.messages)) {
nextOffset = fromOffset + int64(cap(c.messages))
}
// Fetch the messages from the database.
msgs, err := c.topic.db.FetchMessages(c.topic.topicName, fromOffset, nextOffset)
if err != nil {
// TODO: Add option to write consumer errors to an errors channel
// as an alternative to logging the errors.
log.Print("Error reading messages: ", err)
// Wait before retrying.
// TODO: Maybe use an exponentional backoff scheme here.
// TODO: This timeout should take account of all the other goroutines
// that might be doing the same thing. (If there are a 10000 consumers
// then we don't want to end up retrying every millisecond)
time.Sleep(10 * time.Second)
continue
}
if len(msgs) == 0 {
// This should only happen if the database is corrupted and has lost the
// messages between the requested offsets.
log.Fatalf("Corrupt database returned no messages between %d and %d", fromOffset, nextOffset)
}
// Pass the messages into the consumer channel.
// Blocking each write until the channel has enough space for the message.
for i := range msgs {
c.messages <- msgs[i].consumerMessage(c.topic.topicName)
}
// Update our the offset for the next loop iteration.
fromOffset = msgs[len(msgs)-1].Offset + 1
}
}()
} }
// catchup reads historic messages from the database until the consumer has caught // notifyNewMessage tells the consumer about a new message
// up on all the historic messages. // Must be called with the c.topic.mutex lock
func (c *partitionConsumer) catchup(fromOffset int64) { func (c *partitionConsumer) notifyNewMessage(cmsg *sarama.ConsumerMessage) {
for { // If we're in "catchup" mode then the catchup routine will send the
// First check if we have caught up. // message later, since cmsg has already been written to the database
caughtUp, nextOffset := c.topic.hasCaughtUp(c, fromOffset) if c.catchingUp {
if caughtUp { return
return }
}
// Limit the number of messages we request from the database to be the
// capacity of the channel.
if nextOffset > fromOffset+int64(cap(c.messages)) {
nextOffset = fromOffset + int64(cap(c.messages))
}
// Fetch the messages from the database.
msgs, err := c.topic.db.FetchMessages(c.topic.topicName, fromOffset, nextOffset)
if err != nil {
// TODO: Add option to write consumer errors to an errors channel
// as an alternative to logging the errors.
log.Print("Error reading messages: ", err)
// Wait before retrying.
// TODO: Maybe use an exponentional backoff scheme here.
// TODO: This timeout should take account of all the other goroutines
// that might be doing the same thing. (If there are a 10000 consumers
// then we don't want to end up retrying every millisecond)
time.Sleep(10 * time.Second)
continue
}
if len(msgs) == 0 {
// This should only happen if the database is corrupted and has lost the
// messages between the requested offsets.
log.Fatalf("Corrupt database returned no messages between %d and %d", fromOffset, nextOffset)
}
// Pass the messages into the consumer channel. // Otherwise, lets try writing the message directly to the channel
// Blocking each write until the channel has enough space for the message. select {
for i := range msgs { case c.messages <- cmsg:
c.messages <- msgs[i].consumerMessage(c.topic.topicName) default:
} // The messages channel has filled up, so lets go into catchup
// Update our the offset for the next loop iteration. // mode. Once the channel starts being read from again messages
fromOffset = msgs[len(msgs)-1].Offset // will be read from the database
c.catchup(cmsg.Offset)
} }
} }
type topic struct { type topic struct {
db Database db Database
topicName string topicName string
mutex sync.Mutex mutex sync.Mutex
consumers []*partitionConsumer consumers []*partitionConsumer
// nextOffset is the offset that will be assigned to the next message in
// this topic, i.e. one greater than the last message offset.
nextOffset int64 nextOffset int64
} }
// send writes messages to a topic.
func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error { func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error {
var err error var err error
// Encode the message keys and values. // Encode the message keys and values.
@ -298,21 +339,10 @@ func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error {
t.nextOffset = offset t.nextOffset = offset
// Now notify the consumers about the messages. // Now notify the consumers about the messages.
for i := range msgs { for _, msg := range msgs {
cmsg := msgs[i].consumerMessage(t.topicName) cmsg := msg.consumerMessage(t.topicName)
for _, c := range t.consumers { for _, c := range t.consumers {
if c.ready { c.notifyNewMessage(cmsg)
select {
case c.messages <- cmsg:
default:
// The consumer wasn't ready to receive a message because
// the channel buffer was full.
// Fork a goroutine to send the message so that we don't
// block sending messages to the other consumers.
c.ready = false
go c.block(cmsg)
}
}
} }
} }
@ -330,27 +360,17 @@ func (t *topic) consume(offset int64) *partitionConsumer {
offset = t.nextOffset offset = t.nextOffset
} }
if offset == sarama.OffsetOldest { if offset == sarama.OffsetOldest {
offset = -1 offset = 0
} }
c.messages = make(chan *sarama.ConsumerMessage, channelSize) c.messages = make(chan *sarama.ConsumerMessage, channelSize)
t.consumers = append(t.consumers, c) t.consumers = append(t.consumers, c)
// Start catching up on historic messages in the background.
go c.catchup(offset)
return c
}
func (t *topic) hasCaughtUp(c *partitionConsumer, offset int64) (bool, int64) { // If we're not streaming from the latest offset we need to go into
t.mutex.Lock() // "catchup" mode
defer t.mutex.Unlock() if offset != t.nextOffset {
// Check if we have caught up while holding a lock on the topic so there c.catchup(offset)
// isn't a way for our check to race with a new message being sent on the topic.
if offset+1 == t.nextOffset {
// We've caught up, the consumer can now receive messages as they are
// sent rather than fetching them from the database.
c.ready = true
return true, t.nextOffset
} }
return false, t.nextOffset return c
} }
func (t *topic) highwaterMark() int64 { func (t *topic) highwaterMark() int64 {

View File

@ -1,6 +1,7 @@
package naffka package naffka
import ( import (
"strconv"
"testing" "testing"
"time" "time"
@ -84,3 +85,142 @@ func TestDelayedReceive(t *testing.T) {
t.Fatalf("wrong value: wanted %q got %q", value, string(result.Value)) t.Fatalf("wrong value: wanted %q got %q", value, string(result.Value))
} }
} }
func TestCatchup(t *testing.T) {
naffka, err := New(&MemoryDatabase{})
if err != nil {
t.Fatal(err)
}
producer := sarama.SyncProducer(naffka)
consumer := sarama.Consumer(naffka)
const topic = "testTopic"
const value = "Hello, World"
message := sarama.ProducerMessage{
Value: sarama.StringEncoder(value),
Topic: topic,
}
if _, _, err = producer.SendMessage(&message); err != nil {
t.Fatal(err)
}
c, err := consumer.ConsumePartition(topic, 0, sarama.OffsetOldest)
if err != nil {
t.Fatal(err)
}
var result *sarama.ConsumerMessage
select {
case result = <-c.Messages():
case _ = <-time.NewTimer(10 * time.Second).C:
t.Fatal("expected to receive a message")
}
if string(result.Value) != value {
t.Fatalf("wrong value: wanted %q got %q", value, string(result.Value))
}
currOffset := result.Offset
const value2 = "Hello, World2"
const value3 = "Hello, World3"
_, _, err = producer.SendMessage(&sarama.ProducerMessage{
Value: sarama.StringEncoder(value2),
Topic: topic,
})
if err != nil {
t.Fatal(err)
}
_, _, err = producer.SendMessage(&sarama.ProducerMessage{
Value: sarama.StringEncoder(value3),
Topic: topic,
})
if err != nil {
t.Fatal(err)
}
t.Logf("Streaming from %q", currOffset+1)
c2, err := consumer.ConsumePartition(topic, 0, currOffset+1)
if err != nil {
t.Fatal(err)
}
var result2 *sarama.ConsumerMessage
select {
case result2 = <-c2.Messages():
case _ = <-time.NewTimer(10 * time.Second).C:
t.Fatal("expected to receive a message")
}
if string(result2.Value) != value2 {
t.Fatalf("wrong value: wanted %q got %q", value2, string(result2.Value))
}
}
func TestChannelSaturation(t *testing.T) {
// The channel returned by c.Messages() has a fixed capacity
naffka, err := New(&MemoryDatabase{})
if err != nil {
t.Fatal(err)
}
producer := sarama.SyncProducer(naffka)
consumer := sarama.Consumer(naffka)
const topic = "testTopic"
const baseValue = "testValue: "
c, err := consumer.ConsumePartition(topic, 0, sarama.OffsetOldest)
if err != nil {
t.Fatal(err)
}
channelSize := cap(c.Messages())
// We want to send enough messages to fill up the channel, so lets double
// the size of the channel. And add three in case its a zero sized channel
numberMessagesToSend := 2*channelSize + 3
var sentMessages []string
for i := 0; i < numberMessagesToSend; i++ {
value := baseValue + strconv.Itoa(i)
message := sarama.ProducerMessage{
Topic: topic,
Value: sarama.StringEncoder(value),
}
sentMessages = append(sentMessages, value)
if _, _, err = producer.SendMessage(&message); err != nil {
t.Fatal(err)
}
}
var result *sarama.ConsumerMessage
j := 0
for ; j < numberMessagesToSend; j++ {
select {
case result = <-c.Messages():
case _ = <-time.NewTimer(10 * time.Second).C:
t.Fatalf("failed to receive message %d out of %d", j+1, numberMessagesToSend)
}
expectedValue := sentMessages[j]
if string(result.Value) != expectedValue {
t.Fatalf("wrong value: wanted %q got %q", expectedValue, string(result.Value))
}
}
select {
case result = <-c.Messages():
t.Fatalf("expected to only receive %d messages", numberMessagesToSend)
default:
}
}

View File

@ -0,0 +1,296 @@
package naffka
import (
"database/sql"
"sync"
"time"
)
const postgresqlSchema = `
-- The topic table assigns each topic a unique numeric ID.
CREATE SEQUENCE IF NOT EXISTS naffka_topic_nid_seq;
CREATE TABLE IF NOT EXISTS naffka_topics (
topic_name TEXT PRIMARY KEY,
topic_nid BIGINT NOT NULL DEFAULT nextval('naffka_topic_nid_seq')
);
-- The messages table contains the actual messages.
CREATE TABLE IF NOT EXISTS naffka_messages (
topic_nid BIGINT NOT NULL,
message_offset BIGINT NOT NULL,
message_key BYTEA NOT NULL,
message_value BYTEA NOT NULL,
message_timestamp_ns BIGINT NOT NULL,
UNIQUE (topic_nid, message_offset)
);
`
const insertTopicSQL = "" +
"INSERT INTO naffka_topics (topic_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING" +
" RETURNING (topic_nid)"
const selectTopicSQL = "" +
"SELECT topic_nid FROM naffka_topics WHERE topic_name = $1"
const selectTopicsSQL = "" +
"SELECT topic_name, topic_nid FROM naffka_topics"
const insertMessageSQL = "" +
"INSERT INTO naffka_messages (topic_nid, message_offset, message_key, message_value, message_timestamp_ns)" +
" VALUES ($1, $2, $3, $4, $5)"
const selectMessagesSQL = "" +
"SELECT message_offset, message_key, message_value, message_timestamp_ns" +
" FROM naffka_messages WHERE topic_nid = $1 AND $2 <= message_offset AND message_offset < $3" +
" ORDER BY message_offset ASC"
const selectMaxOffsetSQL = "" +
"SELECT message_offset FROM naffka_messages WHERE topic_nid = $1" +
" ORDER BY message_offset DESC LIMIT 1"
type postgresqlDatabase struct {
db *sql.DB
topicsMutex sync.Mutex
topicNIDs map[string]int64
insertTopicStmt *sql.Stmt
selectTopicStmt *sql.Stmt
selectTopicsStmt *sql.Stmt
insertMessageStmt *sql.Stmt
selectMessagesStmt *sql.Stmt
selectMaxOffsetStmt *sql.Stmt
}
// NewPostgresqlDatabase creates a new naffka database using a postgresql database.
// Returns an error if there was a problem setting up the database.
func NewPostgresqlDatabase(db *sql.DB) (Database, error) {
var err error
p := &postgresqlDatabase{
db: db,
topicNIDs: map[string]int64{},
}
if _, err = db.Exec(postgresqlSchema); err != nil {
return nil, err
}
for _, s := range []struct {
sql string
stmt **sql.Stmt
}{
{insertTopicSQL, &p.insertTopicStmt},
{selectTopicSQL, &p.selectTopicStmt},
{selectTopicsSQL, &p.selectTopicsStmt},
{insertMessageSQL, &p.insertMessageStmt},
{selectMessagesSQL, &p.selectMessagesStmt},
{selectMaxOffsetSQL, &p.selectMaxOffsetStmt},
} {
*s.stmt, err = db.Prepare(s.sql)
if err != nil {
return nil, err
}
}
return p, nil
}
// StoreMessages implements Database.
func (p *postgresqlDatabase) StoreMessages(topic string, messages []Message) error {
// Store the messages inside a single database transaction.
return withTransaction(p.db, func(txn *sql.Tx) error {
s := txn.Stmt(p.insertMessageStmt)
topicNID, err := p.assignTopicNID(txn, topic)
if err != nil {
return err
}
for _, m := range messages {
_, err = s.Exec(topicNID, m.Offset, m.Key, m.Value, m.Timestamp.UnixNano())
if err != nil {
return err
}
}
return nil
})
}
// FetchMessages implements Database.
func (p *postgresqlDatabase) FetchMessages(topic string, startOffset, endOffset int64) (messages []Message, err error) {
topicNID, err := p.getTopicNID(nil, topic)
if err != nil {
return
}
rows, err := p.selectMessagesStmt.Query(topicNID, startOffset, endOffset)
if err != nil {
return
}
defer rows.Close()
for rows.Next() {
var (
offset int64
key []byte
value []byte
timestampNano int64
)
if err = rows.Scan(&offset, &key, &value, &timestampNano); err != nil {
return
}
messages = append(messages, Message{
Offset: offset,
Key: key,
Value: value,
Timestamp: time.Unix(0, timestampNano),
})
}
return
}
// MaxOffsets implements Database.
func (p *postgresqlDatabase) MaxOffsets() (map[string]int64, error) {
topicNames, err := p.selectTopics()
if err != nil {
return nil, err
}
result := map[string]int64{}
for topicName, topicNID := range topicNames {
// Lookup the maximum offset.
maxOffset, err := p.selectMaxOffset(topicNID)
if err != nil {
return nil, err
}
if maxOffset > -1 {
// Don't include the topic if we haven't sent any messages on it.
result[topicName] = maxOffset
}
// Prefill the numeric ID cache.
p.addTopicNIDToCache(topicName, topicNID)
}
return result, nil
}
// selectTopics fetches the names and numeric IDs for all the topics the
// database is aware of.
func (p *postgresqlDatabase) selectTopics() (map[string]int64, error) {
rows, err := p.selectTopicsStmt.Query()
if err != nil {
return nil, err
}
defer rows.Close()
result := map[string]int64{}
for rows.Next() {
var (
topicName string
topicNID int64
)
if err = rows.Scan(&topicName, &topicNID); err != nil {
return nil, err
}
result[topicName] = topicNID
}
return result, nil
}
// selectMaxOffset selects the maximum offset for a topic.
// Returns -1 if there aren't any messages for that topic.
// Returns an error if there was a problem talking to the database.
func (p *postgresqlDatabase) selectMaxOffset(topicNID int64) (maxOffset int64, err error) {
err = p.selectMaxOffsetStmt.QueryRow(topicNID).Scan(&maxOffset)
if err == sql.ErrNoRows {
return -1, nil
}
return maxOffset, err
}
// getTopicNID finds the numeric ID for a topic.
// The txn argument is optional, this can be used outside a transaction
// by setting the txn argument to nil.
func (p *postgresqlDatabase) getTopicNID(txn *sql.Tx, topicName string) (topicNID int64, err error) {
// Get from the cache.
topicNID = p.getTopicNIDFromCache(topicName)
if topicNID != 0 {
return topicNID, nil
}
// Get from the database
s := p.selectTopicStmt
if txn != nil {
s = txn.Stmt(s)
}
err = s.QueryRow(topicName).Scan(&topicNID)
if err == sql.ErrNoRows {
return 0, nil
}
if err != nil {
return 0, err
}
// Update the shared cache.
p.addTopicNIDToCache(topicName, topicNID)
return topicNID, nil
}
// assignTopicNID assigns a new numeric ID to a topic.
// The txn argument is mandatory, this is always called inside a transaction.
func (p *postgresqlDatabase) assignTopicNID(txn *sql.Tx, topicName string) (topicNID int64, err error) {
// Check if we already have a numeric ID for the topic name.
topicNID, err = p.getTopicNID(txn, topicName)
if err != nil {
return 0, err
}
if topicNID != 0 {
return topicNID, err
}
// We don't have a numeric ID for the topic name so we add an entry to the
// topics table. If the insert stmt succeeds then it will return the ID.
err = txn.Stmt(p.insertTopicStmt).QueryRow(topicName).Scan(&topicNID)
if err == sql.ErrNoRows {
// If the insert stmt succeeded, but didn't return any rows then it
// means that someone has added a row for the topic name between us
// selecting it the first time and us inserting our own row.
// (N.B. postgres only returns modified rows when using "RETURNING")
// So we can now just select the row that someone else added.
// TODO: This is probably unnecessary since naffka writes to a topic
// from a single thread.
return p.getTopicNID(txn, topicName)
}
if err != nil {
return 0, err
}
// Update the cache.
p.addTopicNIDToCache(topicName, topicNID)
return topicNID, nil
}
// getTopicNIDFromCache returns the topicNID from the cache or returns 0 if the
// topic is not in the cache.
func (p *postgresqlDatabase) getTopicNIDFromCache(topicName string) (topicNID int64) {
p.topicsMutex.Lock()
defer p.topicsMutex.Unlock()
return p.topicNIDs[topicName]
}
// addTopicNIDToCache adds the numeric ID for the topic to the cache.
func (p *postgresqlDatabase) addTopicNIDToCache(topicName string, topicNID int64) {
p.topicsMutex.Lock()
defer p.topicsMutex.Unlock()
p.topicNIDs[topicName] = topicNID
}
// withTransaction runs a block of code passing in an SQL transaction
// If the code returns an error or panics then the transactions is rolledback
// Otherwise the transaction is committed.
func withTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
txn.Rollback()
panic(r)
} else if err != nil {
txn.Rollback()
} else {
err = txn.Commit()
}
}()
err = fn(txn)
return
}