// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package test import ( "context" "database/sql" "encoding/json" "sync" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) type InMemoryRelayDatabase struct { nid int64 nidMutex sync.Mutex transactions map[int64]json.RawMessage associations map[spec.ServerName][]int64 } func NewInMemoryRelayDatabase() *InMemoryRelayDatabase { return &InMemoryRelayDatabase{ nid: 1, nidMutex: sync.Mutex{}, transactions: make(map[int64]json.RawMessage), associations: make(map[spec.ServerName][]int64), } } func (d *InMemoryRelayDatabase) InsertQueueEntry( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName spec.ServerName, nid int64, ) error { if _, ok := d.associations[serverName]; !ok { d.associations[serverName] = []int64{} } d.associations[serverName] = append(d.associations[serverName], nid) return nil } func (d *InMemoryRelayDatabase) DeleteQueueEntries( ctx context.Context, txn *sql.Tx, serverName spec.ServerName, jsonNIDs []int64, ) error { for _, nid := range jsonNIDs { for index, associatedNID := range d.associations[serverName] { if associatedNID == nid { d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...) } } } return nil } func (d *InMemoryRelayDatabase) SelectQueueEntries( ctx context.Context, txn *sql.Tx, serverName spec.ServerName, limit int, ) ([]int64, error) { results := []int64{} resultCount := limit if limit > len(d.associations[serverName]) { resultCount = len(d.associations[serverName]) } if resultCount > 0 { for i := 0; i < resultCount; i++ { results = append(results, d.associations[serverName][i]) } } return results, nil } func (d *InMemoryRelayDatabase) SelectQueueEntryCount( ctx context.Context, txn *sql.Tx, serverName spec.ServerName, ) (int64, error) { return int64(len(d.associations[serverName])), nil } func (d *InMemoryRelayDatabase) InsertQueueJSON( ctx context.Context, txn *sql.Tx, json string, ) (int64, error) { d.nidMutex.Lock() defer d.nidMutex.Unlock() nid := d.nid d.transactions[nid] = []byte(json) d.nid++ return nid, nil } func (d *InMemoryRelayDatabase) DeleteQueueJSON( ctx context.Context, txn *sql.Tx, nids []int64, ) error { for _, nid := range nids { delete(d.transactions, nid) } return nil } func (d *InMemoryRelayDatabase) SelectQueueJSON( ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ) (map[int64][]byte, error) { result := make(map[int64][]byte) for _, nid := range jsonNIDs { if transaction, ok := d.transactions[nid]; ok { result[nid] = transaction } } return result, nil }