// audit-go/audit/messaging/amqp_connection_pool.go
// 2025-01-15 07:59:15 +01:00
//
// 194 lines
// 5.1 KiB
// Go

package messaging
import (
"context"
"errors"
"fmt"
"log/slog"
"sync"
"time"
)
// AmqpConnectionPool keeps up to config.PoolSize AMQP connections and hands
// them out to callers through ConnectionPoolHandle values (see NewHandle and
// GetConnection).
type AmqpConnectionPool struct {
	// config supplies the pool size and the connection parameters passed to
	// NewAmqpConnection.
	config AmqpConnectionPoolConfig
	// connectionName is passed to NewAmqpConnection for every connection.
	connectionName string
	// connections has one slot per pooled connection. A slot is nil while its
	// connection has been found closed and not yet renewed.
	connections []*AmqpConnection
	// handleOffset is the counter NewHandle uses to assign slots round-robin.
	handleOffset int
	// lock guards connections and handleOffset.
	lock sync.RWMutex
}
// ConnectionPoolHandle identifies a caller's preferred slot in an
// AmqpConnectionPool. Obtain one with NewHandle and pass it to GetConnection.
type ConnectionPoolHandle struct {
	// connectionOffset is the index of the preferred slot. The other slots are
	// tried in wrap-around order when this one has no open connection.
	connectionOffset int
}
// NewAmqpConnectionPool creates a pool and eagerly opens config.PoolSize
// connections, each created with connectionName.
//
// ctx is only used to close the connections that were already opened when
// initialization fails part-way. On failure no pool is returned.
func NewAmqpConnectionPool(ctx context.Context, config AmqpConnectionPoolConfig, connectionName string) (*AmqpConnectionPool, error) {
	// handleOffset and lock start at their zero values.
	pool := &AmqpConnectionPool{
		config:         config,
		connectionName: connectionName,
		connections:    []*AmqpConnection{},
	}

	initErr := pool.initializeConnections()
	if initErr == nil {
		return pool, nil
	}

	// Release whatever was opened before the failure.
	closeErr := pool.Close(ctx)
	if closeErr != nil {
		return nil, errors.Join(initErr, fmt.Errorf("initialize amqp connection: pool closed: %w", closeErr))
	}
	return nil, fmt.Errorf("initialize connections: %w", initErr)
}
// initializeConnections opens new connections until the pool holds
// config.PoolSize of them.
//
// It returns the first error encountered. Connections opened before that
// error stay in the pool, so the caller is responsible for closing them.
func (p *AmqpConnectionPool) initializeConnections() error {
	// Acquire the lock before reading p.connections so the size check and the
	// appends below happen atomically with respect to other pool users.
	p.lock.Lock()
	defer p.lock.Unlock()

	numMissingConnections := p.config.PoolSize - len(p.connections)
	for i := 0; i < numMissingConnections; i++ {
		if err := p.internalAddConnection(); err != nil {
			return err
		}
	}
	return nil
}
// internalAddConnection opens one new connection and appends it to
// p.connections.
//
// It modifies p.connections without locking, so the caller must hold p.lock
// for writing (initializeConnections does).
//
// The connect attempt, including internalNewConnection's single retry, is
// bounded by connectionTimeoutSeconds, matching the renewal path in
// GetConnection. This assumes AmqpConnection.Connect honours context
// cancellation; TODO confirm.
func (p *AmqpConnectionPool) internalAddConnection() error {
	ctx, cancelFn := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second)
	defer cancelFn()

	newConnection, err := p.internalNewConnection(ctx)
	if err != nil {
		return fmt.Errorf("new connection: %w", err)
	}
	p.connections = append(p.connections, newConnection)
	return nil
}
// internalNewConnection creates a connection from the pool's parameters and
// connects it, retrying once on failure with the same connection and ctx.
//
// The first failure is only logged as a warning. If the retry also fails, the
// connection is closed and the connect error is returned, joined with the
// close error when closing fails as well.
func (p *AmqpConnectionPool) internalNewConnection(ctx context.Context) (*AmqpConnection, error) {
	connection := NewAmqpConnection(&p.config.Parameters, p.connectionName)

	firstErr := connection.Connect(ctx)
	if firstErr == nil {
		return connection, nil
	}
	slog.Warn("amqp connection: failed to connect to amqp broker", slog.Any("err", firstErr))

	retryErr := connection.Connect(ctx)
	if retryErr == nil {
		return connection, nil
	}

	connectErr := fmt.Errorf("failed to connect to amqp broker: %w", retryErr)
	closeErr := connection.Close(ctx)
	if closeErr == nil {
		return nil, connectErr
	}
	return nil, errors.Join(connectErr, fmt.Errorf("close connection: %w", closeErr))
}
// Close closes every connection held by the pool.
//
// Empty (nil) slots are skipped. If any connection fails to close, the close
// errors are joined and returned and the pool keeps its current slots.
// Otherwise the pool is emptied.
func (p *AmqpConnectionPool) Close(ctx context.Context) error {
	p.lock.Lock()
	defer p.lock.Unlock()

	closeErrors := make([]error, 0)
	for _, conn := range p.connections {
		// Slots become nil when nextConnectionFromQueue finds their connection
		// closed, or when a renewal in GetConnection fails. There is nothing to
		// close there, and calling a method on a nil *AmqpConnection would
		// presumably dereference nil.
		if conn == nil {
			continue
		}
		if err := conn.Close(ctx); err != nil {
			closeErrors = append(closeErrors, fmt.Errorf("connection: close: %w", err))
		}
	}
	if len(closeErrors) > 0 {
		return errors.Join(closeErrors...)
	}
	p.connections = make([]*AmqpConnection, 0)
	return nil
}
// NewHandle returns a handle whose preferred slot is assigned round-robin
// (0, 1, ..., PoolSize-1, 0, ...), so successive handles spread their load
// across the pool's connections.
//
// If PoolSize is not positive, every handle gets slot 0.
func (p *AmqpConnectionPool) NewHandle() *ConnectionPoolHandle {
	p.lock.Lock()
	defer p.lock.Unlock()

	if p.config.PoolSize <= 0 {
		// Avoid an integer-division-by-zero panic for an empty pool.
		return &ConnectionPoolHandle{connectionOffset: 0}
	}

	offset := p.handleOffset % p.config.PoolSize
	// Store the counter modulo PoolSize so it cannot overflow into negative
	// offsets, which would later index out of range.
	p.handleOffset = (offset + 1) % p.config.PoolSize
	return &ConnectionPoolHandle{
		connectionOffset: offset,
	}
}
// GetConnection returns an open connection for handle. The handle's own slot
// is preferred; otherwise another slot's connection is used.
//
// When nextConnectionFromQueue reports that renewal is needed, the handle's
// own slot is re-created if it is still empty. This runs while holding the
// pool's write lock and is bounded by connectionTimeoutSeconds.
//
// A failed renewal is returned as an error only when no other open connection
// is available. Otherwise it is logged and the fallback connection is
// returned. If the pool has no open connection at all, an error is returned.
func (p *AmqpConnectionPool) GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) {
	// get the requested connection or another one
	conn, renew := p.nextConnectionFromQueue(handle)
	if conn != nil && !renew {
		return conn, nil
	}

	renewErr := func() error {
		p.lock.Lock()
		defer p.lock.Unlock()
		// Re-check under the write lock. Another goroutine may already have
		// renewed the slot, or Close may have emptied the pool (shutdown).
		offset := handle.connectionOffset
		if offset < 0 || offset >= len(p.connections) || p.connections[offset] != nil {
			return nil
		}
		ctx, cancelFn := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second)
		defer cancelFn()
		connection, err := p.internalNewConnection(ctx)
		if err != nil {
			return fmt.Errorf("renew connection: %w", err)
		}
		p.connections[offset] = connection
		return nil
	}()

	// A usable fallback was already found, so a failed renewal must not turn
	// this call into an error.
	if conn != nil {
		if renewErr != nil {
			slog.Warn("amqp connection pool: renewing connection failed, using another pooled connection", slog.Any("err", renewErr))
		}
		return conn, nil
	}
	if renewErr != nil {
		return nil, renewErr
	}

	// try to return the renewed connection or another one
	conn, _ = p.nextConnectionFromQueue(handle)
	if conn == nil {
		return nil, errors.New("pool is empty")
	}
	return conn, nil
}
// nextConnectionFromQueue scans the slots, starting at the handle's own slot
// and wrapping around, and returns the first connection that is not closed.
//
// Closed connections found during the scan are removed from their slot. The
// boolean result is true when a slot visited before the returned connection
// (starting with the handle's own slot) was empty or closed, i.e. the handle's
// slot needs renewing. It is always true when no open connection is found.
func (p *AmqpConnectionPool) nextConnectionFromQueue(handle *ConnectionPoolHandle) (*AmqpConnection, bool) {
	poolSize := p.config.PoolSize
	// return the next possible index (including the retry offset)
	nextIndex := func(i int) int {
		if i+handle.connectionOffset >= poolSize {
			return i + handle.connectionOffset - poolSize
		}
		return i + handle.connectionOffset
	}

	needsRenewal := false
	for i := 0; i < poolSize; i++ {
		idx := nextIndex(i)

		var conn *AmqpConnection
		p.lock.RLock()
		// the pool may be empty on shutdown
		if idx < len(p.connections) {
			conn = p.connections[idx]
		}
		p.lock.RUnlock()

		if conn == nil {
			needsRenewal = true
			continue
		}

		if conn.IsClosed() {
			p.lock.Lock()
			// The slot may have changed after the read lock was released. Close
			// may have emptied the pool, or another goroutine may have stored a
			// fresh connection there. Only clear it if it still holds this
			// closed connection.
			if idx < len(p.connections) && p.connections[idx] == conn {
				p.connections[idx] = nil
			}
			p.lock.Unlock()
			needsRenewal = true
			continue
		}

		return conn, needsRenewal
	}
	return nil, true
}