audit-go/internal/messaging/amqp_connection_pool.go
Christian Schaible (EXT) 85aae1c2e7 Merged PR 779949: feat: Refactor module structure to reflect best practices
Security-concept-update-needed: false.

JIRA Work Item: STACKITALO-259
2025-05-19 11:54:00 +00:00

231 lines
6.5 KiB
Go

package messaging
import (
"errors"
"fmt"
"log/slog"
"sync"
pkgCommon "dev.azure.com/schwarzit/schwarzit.stackit-public/audit-go.git/pkg/messaging/common"
)
// connectionProvider is the factory the pool uses to create new *AmqpConnection
// values. It exists as a seam so the pool can be given a different factory than
// the package-level NewAmqpConnection.
type connectionProvider interface {
	NewAmqpConnection(config pkgCommon.AmqpConnectionConfig, connectionName string) *AmqpConnection
}

// defaultAmqpConnectionProvider is the stateless provider installed by
// NewAmqpConnectionPool; it simply forwards to NewAmqpConnection.
type defaultAmqpConnectionProvider struct{}

// Compile-time proof that the default provider satisfies connectionProvider.
var _ connectionProvider = defaultAmqpConnectionProvider{}

// NewAmqpConnection delegates to the package-level NewAmqpConnection constructor.
func (defaultAmqpConnectionProvider) NewAmqpConnection(config pkgCommon.AmqpConnectionConfig, connectionName string) *AmqpConnection {
	return NewAmqpConnection(config, connectionName)
}
type (
	// ConnectionPool is the contract for a pool of AMQP connections handed out
	// through handles.
	ConnectionPool interface {
		// Close closes all pooled connections.
		Close() error
		// NewHandle returns a handle whose offset selects a preferred pool slot.
		NewHandle() *ConnectionPoolHandle
		// GetConnection returns a connection for the given handle.
		GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error)
	}

	// AmqpConnectionPool is the ConnectionPool implementation backed by a
	// fixed number of slots (Config.PoolSize).
	AmqpConnectionPool struct {
		// Config holds the pool size and the per-connection parameters.
		Config pkgCommon.AmqpConnectionPoolConfig
		// ConnectionName is passed to every connection the pool creates.
		ConnectionName string
		// Connections are the pool slots; a nil entry marks an empty slot
		// (never filled, found closed, or reset by Close).
		Connections []*AmqpConnection
		// ConnectionProvider creates new connections.
		ConnectionProvider connectionProvider
		// HandleOffset is the round-robin counter consumed by NewHandle.
		HandleOffset int
		// Lock guards Connections and HandleOffset.
		Lock sync.RWMutex
	}

	// ConnectionPoolHandle identifies a caller's preferred slot in the pool.
	ConnectionPoolHandle struct {
		// ConnectionOffset is the slot index where connection lookups start.
		ConnectionOffset int
	}
)
// NewDefaultAmqpConnectionPool builds a two-connection pool for config by
// delegating to NewAmqpConnectionPool.
func NewDefaultAmqpConnectionPool(config pkgCommon.AmqpConnectionConfig, connectionName string) (ConnectionPool, error) {
	return NewAmqpConnectionPool(
		pkgCommon.AmqpConnectionPoolConfig{PoolSize: 2, Parameters: config},
		connectionName,
	)
}
// NewAmqpConnectionPool creates a pool of config.PoolSize AMQP connections named
// connectionName and connects all of them eagerly.
//
// A non-positive config.PoolSize is replaced by the default size of 2. A
// negative size would otherwise make Close panic (make with a negative length),
// and Close is invoked on this function's own error path.
//
// If any connection cannot be established, the pool is closed again and the
// initialization error is returned, joined with any error from closing.
func NewAmqpConnectionPool(config pkgCommon.AmqpConnectionPoolConfig, connectionName string) (ConnectionPool, error) {
	if config.PoolSize <= 0 {
		config.PoolSize = 2
	}
	pool := &AmqpConnectionPool{
		Config:             config,
		ConnectionName:     connectionName,
		Connections:        make([]*AmqpConnection, 0, config.PoolSize),
		ConnectionProvider: defaultAmqpConnectionProvider{},
		HandleOffset:       0,
		Lock:               sync.RWMutex{},
	}
	if err := pool.initializeConnections(); err != nil {
		initErr := fmt.Errorf("initialize connections: %w", err)
		// Release whatever connections were opened before the failure.
		if closeErr := pool.Close(); closeErr != nil {
			return nil, errors.Join(initErr, fmt.Errorf("initialize amqp connection: pool closed: %w", closeErr))
		}
		return nil, initErr
	}
	return pool, nil
}
// initializeConnections tops the pool up to Config.PoolSize connections.
//
// The length check and the appends happen under the same write lock, so the
// number of missing connections cannot go stale between being computed and
// being filled. It stops at and returns the first connection error.
// Connections added before that error remain in p.Connections.
func (p *AmqpConnectionPool) initializeConnections() error {
	p.Lock.Lock()
	defer p.Lock.Unlock()

	for missing := p.Config.PoolSize - len(p.Connections); missing > 0; missing-- {
		if err := p.internalAddConnection(); err != nil {
			return err
		}
	}
	return nil
}
// internalAddConnection opens one new connection and appends it to
// p.Connections. It takes no lock itself; initializeConnections calls it while
// holding p.Lock for writing.
func (p *AmqpConnectionPool) internalAddConnection() error {
	fresh, dialErr := p.internalNewConnection()
	if dialErr == nil {
		p.Connections = append(p.Connections, fresh)
		return nil
	}
	return fmt.Errorf("new connection: %w", dialErr)
}
// internalNewConnection creates a connection via p.ConnectionProvider and
// connects it. It retries once after logging a warning on the first failure.
//
// If the retry also fails, the connection is closed and the wrapped connect
// error is returned, joined with the close error if closing failed too. It
// neither reads nor writes p.Connections.
func (p *AmqpConnectionPool) internalNewConnection() (*AmqpConnection, error) {
	candidate := p.ConnectionProvider.NewAmqpConnection(p.Config.Parameters, p.ConnectionName)

	err := candidate.Connect()
	if err == nil {
		return candidate, nil
	}
	slog.Warn("amqp connection: failed to connect to amqp broker", slog.Any("err", err))

	// Single retry before giving up.
	err = candidate.Connect()
	if err == nil {
		return candidate, nil
	}

	connectErr := fmt.Errorf("new internal connection: %w", err)
	closeErr := candidate.Close()
	if closeErr == nil {
		return nil, connectErr
	}
	// Not expected in practice: the connection never came up, so closing it
	// should have nothing to release.
	return nil, errors.Join(connectErr, fmt.Errorf("close connection: %w", closeErr))
}
// Close closes every non-nil pooled connection while holding the write lock.
// It then replaces p.Connections with Config.PoolSize empty (nil) slots.
//
// Every close failure is wrapped and reported together via errors.Join. The
// result is nil when all connections closed cleanly.
//
// NOTE(review): the slots are reset rather than removed, so a later
// GetConnection will try to reopen connections in them — confirm this is
// intended after shutdown.
func (p *AmqpConnectionPool) Close() error {
	p.Lock.Lock()
	defer p.Lock.Unlock()

	var failures []error
	for _, pooled := range p.Connections {
		if pooled == nil {
			continue
		}
		if err := pooled.Close(); err != nil {
			failures = append(failures, fmt.Errorf("pooled connection: %w", err))
		}
	}

	p.Connections = make([]*AmqpConnection, p.Config.PoolSize)
	// errors.Join yields nil when failures is empty.
	return errors.Join(failures...)
}
// NewHandle returns a handle whose ConnectionOffset cycles round-robin through
// 0..Config.PoolSize-1, spreading callers across the pool's slots.
func (p *AmqpConnectionPool) NewHandle() *ConnectionPoolHandle {
	p.Lock.Lock()
	defer p.Lock.Unlock()

	offset := p.HandleOffset % p.Config.PoolSize
	// Keep the stored counter reduced modulo the pool size. An ever-growing int
	// would eventually overflow (after ~2^31 handles where int is 32 bits) and
	// produce a negative offset. connectionIndex passes that through as a
	// negative slice index.
	p.HandleOffset = (offset + 1) % p.Config.PoolSize
	return &ConnectionPoolHandle{
		ConnectionOffset: offset,
	}
}
// GetConnection returns an open connection for handle.
//
// It prefers the connection in the handle's own slot and otherwise falls back
// to the other slots (see nextConnectionForHandle). If the handle's slot needs
// renewing, one attempt is made to open a replacement connection there. If that
// attempt fails, the fallback connection is returned when one was found.
// Otherwise the renewal error is returned.
func (p *AmqpConnectionPool) GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) {
	// get the requested connection or another one
	conn, addConnection := p.nextConnectionForHandle(handle)

	// renew the requested connection if it is missing or was found closed
	if conn == nil || addConnection {
		renewed, err := p.renewHandleConnection(handle)
		switch {
		case err != nil && conn == nil:
			// case: connection could not be renewed and no connection to return has been found
			return nil, fmt.Errorf("renew connection: %w", err)
		case err != nil:
			// case: connection could not be renewed but another connection will be returned
			slog.Warn("amqp connection pool: get connection: renew connection: ", slog.Any("err", err))
		case renewed != nil:
			// case: connection could be renewed and has been added to the pool
			conn = renewed
		}
	}

	if conn == nil {
		return nil, fmt.Errorf("amqp connection pool: get connection: failed to obtain connection")
	}
	return conn, nil
}

// renewHandleConnection opens a new connection into the handle's own slot when
// that slot exists and is empty, storing it in the pool under the write lock.
// It returns (nil, nil) when there is nothing to renew: either the slot is
// already occupied, or the index lies outside p.Connections.
func (p *AmqpConnectionPool) renewHandleConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) {
	p.Lock.Lock()
	defer p.Lock.Unlock()

	idx := p.connectionIndex(handle, 0)
	// Bounds-check before indexing: p.Connections may hold fewer than PoolSize entries.
	if idx < 0 || idx >= len(p.Connections) || p.Connections[idx] != nil {
		return nil, nil
	}

	conn, err := p.internalNewConnection()
	if err != nil {
		return nil, err
	}
	p.Connections[idx] = conn
	return conn, nil
}
// nextConnectionForHandle walks the pool's slots, starting at the handle's own
// slot and wrapping around. It returns the first connection that is present
// and not closed. Closed connections found along the way are cleared (set to
// nil) in the pool.
//
// The boolean reports whether a slot visited before the returned connection
// was empty or closed. Because a usable connection in the first slot returns
// immediately with false, true implies the handle's own slot needs renewing.
// When no usable connection exists, the result is (nil, true).
func (p *AmqpConnectionPool) nextConnectionForHandle(handle *ConnectionPoolHandle) (*AmqpConnection, bool) {
	addConnection := false
	for i := 0; i < p.Config.PoolSize; i++ {
		// get the next possible connection (considering the retry index)
		idx := p.connectionIndex(handle, i)

		var conn *AmqpConnection
		p.Lock.RLock()
		// Guard both bounds: the pool may be shorter than PoolSize, and a
		// hand-built handle could carry a negative offset.
		if idx >= 0 && idx < len(p.Connections) {
			conn = p.Connections[idx]
		}
		p.Lock.RUnlock()

		// remember that the requested slot is empty, retry with the next
		if conn == nil {
			addConnection = true
			continue
		}

		// if the connection is closed, mark its slot as empty
		if conn.IsClosed() {
			p.Lock.Lock()
			// Only clear the slot if it still holds the connection inspected above.
			// Another goroutine may have renewed it after the read lock was
			// released; overwriting with nil would drop that fresh connection from
			// the pool, so Close would never close it.
			if idx < len(p.Connections) && p.Connections[idx] == conn {
				p.Connections[idx] = nil
			}
			p.Lock.Unlock()
			addConnection = true
			continue
		}

		return conn, addConnection
	}
	return nil, true
}
// connectionIndex maps a handle and a retry iteration to a slot index. It
// wraps around at Config.PoolSize; for non-negative inputs the result lies in
// [0, PoolSize).
func (p *AmqpConnectionPool) connectionIndex(handle *ConnectionPoolHandle, iteration int) int {
	slot := iteration + handle.ConnectionOffset
	if slot < p.Config.PoolSize {
		return slot
	}
	return slot % p.Config.PoolSize
}