diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 7eb4567f..aac5bc36 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -68,12 +68,25 @@ func (d *Database) eventTypeNIDs( ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) - nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes) - if err != nil { - return nil, err + // first try the cache + fetchEventTypes := make([]string, 0, len(eventTypes)) + for _, eventType := range eventTypes { + eventTypeNID, ok := d.Cache.GetEventTypeKey(eventType) + if ok { + result[eventType] = eventTypeNID + continue + } + fetchEventTypes = append(fetchEventTypes, eventType) } - for eventType, nid := range nids { - result[eventType] = nid + if len(fetchEventTypes) > 0 { + nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, fetchEventTypes) + if err != nil { + return nil, err + } + for eventType, nid := range nids { + result[eventType] = nid + d.Cache.StoreEventTypeKey(nid, eventType) + } } return result, nil } @@ -90,13 +103,15 @@ func (d *Database) EventStateKeys( fetch = append(fetch, nid) } } - fromDB, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, fetch) - if err != nil { - return nil, err - } - for nid, key := range fromDB { - result[nid] = key - d.Cache.StoreEventStateKey(nid, key) + if len(fetch) > 0 { + fromDB, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, fetch) + if err != nil { + return nil, err + } + for nid, key := range fromDB { + result[nid] = key + d.Cache.StoreEventStateKey(nid, key) + } } return result, nil } @@ -130,6 +145,7 @@ func (d *Database) eventStateKeyNIDs( } for eventStateKey, nid := range nids { result[eventStateKey] = nid + d.Cache.StoreEventStateKey(nid, eventStateKey) } }