Merged PR 716929: feat: Replace AMQP connection management

So far the SDK provided a messaging API that was not thread-safe (i.e. goroutine-safe). Additionally the SDK provided a MutexAPI which made it thread-safe at the cost of removed concurrency possibilities. The changes implemented in this commit replace both implementations with a thread-safe connection pool based solution.

The api gateway is a SDK user that requires reliable high performance send capabilities with a limit amount of amqp connections. These changes in the PR try address their requirements by moving the responsibility of connection management into the SDK. From this change other SDK users will benefit as well.

Security-concept-update-needed: false.

JIRA Work Item: STACKITALO-62
This commit is contained in:
Christian Schaible 2025-01-27 13:23:54 +00:00
parent c90ce29c51
commit 5742604629
13 changed files with 2056 additions and 291 deletions

View file

@ -30,7 +30,11 @@ func TestDynamicLegacyAuditApi(t *testing.T) {
defer solaceContainer.Stop()
// Instantiate the messaging api
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString})
messagingApi, err := messaging.NewAmqpApi(
messaging.AmqpConnectionPoolConfig{
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err)
// Validator

View file

@ -32,7 +32,10 @@ func TestLegacyAuditApi(t *testing.T) {
defer solaceContainer.Stop()
// Instantiate the messaging api
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString})
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err)
// Validator
@ -579,7 +582,10 @@ func TestLegacyAuditApi_NewLegacyAuditApi(t *testing.T) {
defer solaceContainer.Stop()
// Instantiate the messaging api
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString})
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err)
// Validator

View file

@ -33,7 +33,10 @@ func TestRoutableAuditApi(t *testing.T) {
defer solaceContainer.Stop()
// Instantiate the messaging api
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString})
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err)
// Validator

View file

@ -0,0 +1,15 @@
package messaging
const AmqpTopicPrefix = "topic://"
const connectionTimeoutSeconds = 10
type AmqpConnectionConfig struct {
BrokerUrl string `json:"brokerUrl"`
Username string `json:"username"`
Password string `json:"password"`
}
type AmqpConnectionPoolConfig struct {
Parameters AmqpConnectionConfig `json:"parameters"`
PoolSize int `json:"poolSize"`
}

View file

@ -0,0 +1,232 @@
package messaging
import (
"context"
"errors"
"fmt"
"github.com/Azure/go-amqp"
"log/slog"
"sync"
"time"
)
var ConnectionClosedError = errors.New("amqp connection is closed")
type AmqpConnection struct {
connectionName string
lock sync.RWMutex
brokerUrl string
username string
password string
conn amqpConn
dialer amqpDial
}
// amqpConn is an abstraction of amqp.Conn
type amqpConn interface {
NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error)
Close() error
Done() <-chan struct{}
}
type defaultAmqpConn struct {
conn *amqp.Conn
}
func newDefaultAmqpConn(conn *amqp.Conn) *defaultAmqpConn {
return &defaultAmqpConn{
conn: conn,
}
}
func (d defaultAmqpConn) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error) {
session, err := d.conn.NewSession(ctx, opts)
if err != nil {
return nil, err
}
return newDefaultAmqpSession(session), nil
}
func (d defaultAmqpConn) Close() error {
return d.conn.Close()
}
func (d defaultAmqpConn) Done() <-chan struct{} {
return d.conn.Done()
}
var _ amqpConn = (*defaultAmqpConn)(nil)
type amqpDial interface {
Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error)
}
type amqpSession interface {
NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error)
Close(ctx context.Context) error
}
type defaultAmqpSession struct {
session *amqp.Session
}
func newDefaultAmqpSession(session *amqp.Session) *defaultAmqpSession {
return &defaultAmqpSession{
session: session,
}
}
func (s *defaultAmqpSession) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error) {
return s.session.NewSender(ctx, target, opts)
}
func (s *defaultAmqpSession) Close(ctx context.Context) error {
return s.session.Close(ctx)
}
var _ amqpSession = (*defaultAmqpSession)(nil)
type defaultAmqpDialer struct{}
func (d *defaultAmqpDialer) Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error) {
dial, err := amqp.Dial(ctx, addr, opts)
if err != nil {
return nil, err
}
return newDefaultAmqpConn(dial), nil
}
var _ amqpDial = (*defaultAmqpDialer)(nil)
func NewAmqpConnection(config AmqpConnectionConfig, connectionName string) *AmqpConnection {
return &AmqpConnection{
connectionName: connectionName,
lock: sync.RWMutex{},
brokerUrl: config.BrokerUrl,
username: config.Username,
password: config.Password,
dialer: &defaultAmqpDialer{},
}
}
func (c *AmqpConnection) NewSender(ctx context.Context, topic string) (*AmqpSenderSession, error) {
if c.conn == nil {
return nil, errors.New("connection is not initialized")
}
if c.internalIsClosed() {
return nil, ConnectionClosedError
}
c.lock.RLock()
defer c.lock.RUnlock()
// new session
newSession, err := c.conn.NewSession(ctx, nil)
if err != nil {
return nil, fmt.Errorf("new session: %w", err)
}
// new sender
newSender, err := newSession.NewSender(ctx, topic, nil)
if err != nil {
err = fmt.Errorf("new internal sender: %w", err)
closeErr := newSession.Close(ctx)
if closeErr != nil {
return nil, errors.Join(err, fmt.Errorf("close session: %w", closeErr))
}
return nil, err
}
return &AmqpSenderSession{newSession, newSender}, nil
}
func As[T any](value any, err error) (*T, error) {
if err != nil {
return nil, err
}
if value == nil {
return nil, nil
}
castedValue, isType := value.(*T)
if !isType {
return nil, fmt.Errorf("could not cast value: %T", value)
}
return castedValue, nil
}
func (c *AmqpConnection) Connect() error {
c.lock.Lock()
defer c.lock.Unlock()
subCtx, cancel := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second)
defer cancel()
if err := c.internalConnect(subCtx); err != nil {
return fmt.Errorf("internal connect: %w", err)
}
return nil
}
func (c *AmqpConnection) internalConnect(ctx context.Context) error {
if c.conn == nil {
// Set credentials if specified
auth := amqp.SASLTypeAnonymous()
if c.username != "" && c.password != "" {
auth = amqp.SASLTypePlain(c.username, c.password)
} else {
slog.Debug("amqp connection: connect: using anonymous messaging")
}
options := &amqp.ConnOptions{
SASLType: auth,
}
// Initialize connection
conn, err := c.dialer.Dial(ctx, c.brokerUrl, options)
if err != nil {
return fmt.Errorf("dial: %w", err)
}
c.conn = conn
}
return nil
}
func (c *AmqpConnection) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
if err := c.internalClose(); err != nil {
return fmt.Errorf("internal close: %w", err)
}
return nil
}
func (c *AmqpConnection) internalClose() error {
if c.conn != nil {
if err := c.conn.Close(); err != nil {
return fmt.Errorf("connection close: %w", err)
}
c.conn = nil
}
return nil
}
func (c *AmqpConnection) IsClosed() bool {
c.lock.RLock()
defer c.lock.RUnlock()
return c.internalIsClosed()
}
func (c *AmqpConnection) internalIsClosed() bool {
if c.conn == nil {
return true
}
select {
case <-c.conn.Done():
return true
default:
return false
}
}

View file

@ -0,0 +1,219 @@
package messaging
import (
"errors"
"fmt"
"log/slog"
"sync"
)
type connectionProvider interface {
NewAmqpConnection(config AmqpConnectionConfig, connectionName string) *AmqpConnection
}
type defaultAmqpConnectionProvider struct{}
func (p defaultAmqpConnectionProvider) NewAmqpConnection(config AmqpConnectionConfig, connectionName string) *AmqpConnection {
return NewAmqpConnection(config, connectionName)
}
var _ connectionProvider = (*defaultAmqpConnectionProvider)(nil)
type ConnectionPool interface {
Close() error
NewHandle() *ConnectionPoolHandle
GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error)
}
type AmqpConnectionPool struct {
config AmqpConnectionPoolConfig
connectionName string
connections []*AmqpConnection
connectionProvider connectionProvider
handleOffset int
lock sync.RWMutex
}
type ConnectionPoolHandle struct {
connectionOffset int
}
func NewAmqpConnectionPool(config AmqpConnectionPoolConfig, connectionName string) (ConnectionPool, error) {
pool := &AmqpConnectionPool{
config: config,
connectionName: connectionName,
connections: make([]*AmqpConnection, 0),
connectionProvider: defaultAmqpConnectionProvider{},
handleOffset: 0,
lock: sync.RWMutex{},
}
if err := pool.initializeConnections(); err != nil {
if closeErr := pool.Close(); closeErr != nil {
return nil, errors.Join(err, fmt.Errorf("initialize amqp connection: pool closed: %w", closeErr))
}
return nil, fmt.Errorf("initialize connections: %w", err)
}
return pool, nil
}
func (p *AmqpConnectionPool) initializeConnections() error {
if len(p.connections) < p.config.PoolSize {
p.lock.Lock()
defer p.lock.Unlock()
numMissingConnections := p.config.PoolSize - len(p.connections)
for i := 0; i < numMissingConnections; i++ {
if err := p.internalAddConnection(); err != nil {
return err
}
}
}
return nil
}
func (p *AmqpConnectionPool) internalAddConnection() error {
newConnection, err := p.internalNewConnection()
if err != nil {
return fmt.Errorf("new connection: %w", err)
}
p.connections = append(p.connections, newConnection)
return nil
}
func (p *AmqpConnectionPool) internalNewConnection() (*AmqpConnection, error) {
conn := p.connectionProvider.NewAmqpConnection(p.config.Parameters, p.connectionName)
if err := conn.Connect(); err != nil {
slog.Warn("amqp connection: failed to connect to amqp broker", slog.Any("err", err))
// retry
if err = conn.Connect(); err != nil {
connectErr := fmt.Errorf("new internal connection: %w", err)
if closeErr := conn.Close(); closeErr != nil {
// this case should never happen as the inner connection should always be null, therefore
// it should not have to be closed, i.e. be able to return errors.
return nil, errors.Join(connectErr, fmt.Errorf("close connection: %w", closeErr))
}
return nil, connectErr
}
}
return conn, nil
}
func (p *AmqpConnectionPool) Close() error {
p.lock.Lock()
defer p.lock.Unlock()
closeErrors := make([]error, 0)
for _, conn := range p.connections {
if conn != nil {
if err := conn.Close(); err != nil {
closeErrors = append(closeErrors, fmt.Errorf("pooled connection: %w", err))
}
}
}
p.connections = make([]*AmqpConnection, p.config.PoolSize)
if len(closeErrors) > 0 {
return errors.Join(closeErrors...)
}
return nil
}
func (p *AmqpConnectionPool) NewHandle() *ConnectionPoolHandle {
p.lock.Lock()
defer p.lock.Unlock()
offset := p.handleOffset
p.handleOffset += 1
offset = offset % p.config.PoolSize
return &ConnectionPoolHandle{
connectionOffset: offset,
}
}
func (p *AmqpConnectionPool) GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) {
// get the requested connection or another one
conn, addConnection := p.nextConnectionForHandle(handle)
// renew the requested connection if the request connection is closed
if conn == nil || addConnection {
p.lock.Lock()
// check that accessing the pool only with a valid index (out of bounds should only occur on shutdown)
connectionIndex := p.connectionIndex(handle, 0)
if p.connections[connectionIndex] == nil {
connection, err := p.internalNewConnection()
if err != nil {
if conn == nil {
// case: connection could not be renewed and no connection to return has been found
p.lock.Unlock()
return nil, fmt.Errorf("renew connection: %w", err)
}
// case: connection could not be renewed but another connection will be returned
slog.Warn("amqp connection pool: get connection: renew connection: ", slog.Any("err", err))
} else {
// case: connection could be renewed and will be added to pool
p.connections[connectionIndex] = connection
conn = connection
}
}
p.lock.Unlock()
}
if conn == nil {
return nil, fmt.Errorf("amqp connection pool: get connection: failed to obtain connection")
}
return conn, nil
}
func (p *AmqpConnectionPool) nextConnectionForHandle(handle *ConnectionPoolHandle) (*AmqpConnection, bool) {
// retry as long as there are remaining connections in the pool
var conn *AmqpConnection
var addConnection bool
for i := 0; i < p.config.PoolSize; i++ {
// get the next possible connection (considering the retry index)
idx := p.connectionIndex(handle, i)
p.lock.RLock()
if idx < len(p.connections) {
conn = p.connections[idx]
} else {
// handle the edge case that the pool is empty on shutdown
conn = nil
}
p.lock.RUnlock()
// remember that the requested is closed, retry with the next
if conn == nil {
addConnection = true
continue
}
// if the connection is closed, mark it by setting it to nil
if conn.IsClosed() {
p.lock.Lock()
p.connections[idx] = nil
p.lock.Unlock()
addConnection = true
continue
}
return conn, addConnection
}
return nil, true
}
func (p *AmqpConnectionPool) connectionIndex(handle *ConnectionPoolHandle, iteration int) int {
if iteration+handle.connectionOffset >= p.config.PoolSize {
return (iteration + handle.connectionOffset) % p.config.PoolSize
} else {
return iteration + handle.connectionOffset
}
}

View file

@ -0,0 +1,578 @@
package messaging
import (
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"sync"
"testing"
)
type connectionProviderMock struct {
mock.Mock
}
func (p *connectionProviderMock) NewAmqpConnection(config AmqpConnectionConfig, connectionName string) *AmqpConnection {
args := p.Called(config, connectionName)
return args.Get(0).(*AmqpConnection)
}
var _ connectionProvider = (*connectionProviderMock)(nil)
func Test_AmqpConnectionPool_GetHandle(t *testing.T) {
t.Run("next handle", func(t *testing.T) {
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
}
handle := pool.NewHandle()
assert.NotNil(t, handle)
assert.Equal(t, 0, handle.connectionOffset)
assert.Equal(t, 1, pool.handleOffset)
})
t.Run("next handle high offset", func(t *testing.T) {
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 13,
lock: sync.RWMutex{},
}
handle := pool.NewHandle()
assert.NotNil(t, handle)
assert.Equal(t, 3, handle.connectionOffset)
assert.Equal(t, 14, pool.handleOffset)
})
}
func Test_AmqpConnectionPool_internalAddConnection(t *testing.T) {
t.Run("internal add connection", func(t *testing.T) {
conn := &amqpConnMock{}
dialer := &amqpDialMock{}
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.internalAddConnection()
assert.NoError(t, err)
assert.Equal(t, 1, len(pool.connections))
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1)
dialer.AssertNumberOfCalls(t, "Dial", 1)
})
t.Run("dialer error", func(t *testing.T) {
conn := &amqpConnMock{}
dialer := &amqpDialMock{}
var c *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error")).Once()
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.internalAddConnection()
assert.NoError(t, err)
assert.Equal(t, 1, len(pool.connections))
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1)
dialer.AssertNumberOfCalls(t, "Dial", 2)
})
t.Run("repetitive dialer error", func(t *testing.T) {
dialer := &amqpDialMock{}
var c *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error"))
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.internalAddConnection()
assert.EqualError(t, err, "new connection: new internal connection: internal connect: dial: test error")
assert.Equal(t, 0, len(pool.connections))
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1)
dialer.AssertNumberOfCalls(t, "Dial", 2)
})
}
func Test_AmqpConnectionPool_initializeConnections(t *testing.T) {
t.Run("initialize connections successfully", func(t *testing.T) {
conn := &amqpConnMock{}
dialer := &amqpDialMock{}
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.initializeConnections()
assert.NoError(t, err)
assert.Equal(t, 5, len(pool.connections))
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 5)
})
t.Run("fail initialization of connections", func(t *testing.T) {
var c *amqpConnMock = nil
failingDialer := &amqpDialMock{}
failingDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error"))
failingConnection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: failingDialer,
}
conn := &amqpConnMock{}
successfulDialer := &amqpDialMock{}
successfulDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
successfulConnection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: successfulDialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(4)
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(failingConnection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.initializeConnections()
assert.EqualError(t, err, "new connection: new internal connection: internal connect: dial: test error")
assert.Equal(t, 4, len(pool.connections))
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 5)
})
}
func Test_AmqpConnectionPool_Close(t *testing.T) {
t.Run("close connection successfully", func(t *testing.T) {
// add 5 connections to the pool
conn := &amqpConnMock{}
conn.On("Close").Return(nil)
dialer := &amqpDialMock{}
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.initializeConnections()
assert.NoError(t, err)
assert.Equal(t, 5, len(pool.connections))
// close the pool
err = pool.Close()
assert.NoError(t, err)
assert.Equal(t, 5, len(pool.connections))
for _, c := range pool.connections {
assert.Nil(t, c)
}
})
t.Run("close connection fail", func(t *testing.T) {
// add 5 connections to the pool
failingConn := &amqpConnMock{}
failingConn.On("Close").Return(errors.New("test error"))
failingDialer := &amqpDialMock{}
failingDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(failingConn, nil)
failingConnection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: failingDialer,
}
successfulConn := &amqpConnMock{}
successfulConn.On("Close").Return(nil)
successfulDialer := &amqpDialMock{}
successfulDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(successfulConn, nil)
successfulConnection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: successfulDialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(2)
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(failingConnection).Times(2)
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(1)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.initializeConnections()
assert.NoError(t, err)
assert.Equal(t, 5, len(pool.connections))
// close the pool
err = pool.Close()
assert.EqualError(t, err, "pooled connection: internal close: connection close: test error\npooled connection: internal close: connection close: test error")
assert.Equal(t, 5, len(pool.connections))
for _, c := range pool.connections {
assert.Nil(t, c)
}
})
}
func Test_AmqpConnectionPool_nextConnectionForHandle(t *testing.T) {
channelReceiver := func(channel chan struct{}) <-chan struct{} {
return channel
}
newActiveConnection := func() *AmqpConnection {
channel := make(chan struct{})
conn := &amqpConnMock{}
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
return &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
conn: conn,
}
}
newClosedConnection := func() *AmqpConnection {
channel := make(chan struct{})
close(channel)
conn := &amqpConnMock{}
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
return &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
conn: conn,
}
}
t.Run("next connection for requested handle", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
for i := 0; i < 5; i++ {
connections = append(connections, newActiveConnection())
}
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
assert.NotNil(t, connection)
assert.False(t, addConnection)
})
t.Run("nil connection for requested handle", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, newActiveConnection())
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, newActiveConnection())
connections = append(connections, newActiveConnection())
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
assert.NotNil(t, connection)
assert.True(t, addConnection)
})
t.Run("closed connection for requested handle", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, newActiveConnection())
connections = append(connections, newClosedConnection())
connections = append(connections, newClosedConnection())
connections = append(connections, newActiveConnection())
connections = append(connections, newActiveConnection())
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
assert.NotNil(t, connection)
assert.True(t, addConnection)
})
t.Run("no connection for requested handle", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
assert.Nil(t, connection)
assert.True(t, addConnection)
})
t.Run("connection for requested handle with large index", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, newActiveConnection())
connections = append(connections, nil)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 23})
assert.NotNil(t, connection)
assert.False(t, addConnection)
})
t.Run("connection for requested handle nil with large index", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, newActiveConnection())
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 23})
assert.NotNil(t, connection)
assert.True(t, addConnection)
})
}
func Test_AmqpConnectionPool_GetConnection(t *testing.T) {
channelReceiver := func(channel chan struct{}) <-chan struct{} {
return channel
}
newActiveConnection := func() *AmqpConnection {
channel := make(chan struct{})
conn := &amqpConnMock{}
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
return &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
conn: conn,
}
}
t.Run("get connection for requested handle", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
for i := 0; i < 5; i++ {
connections = append(connections, newActiveConnection())
}
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
assert.NoError(t, err)
assert.NotNil(t, connection)
assert.Equal(t, connections[1], connection)
assert.Equal(t, 5, len(connections))
})
t.Run("add connection if missing", func(t *testing.T) {
connections := make([]*AmqpConnection, 5)
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(newActiveConnection())
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
connectionProvider: connectionProvider,
}
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
assert.NoError(t, err)
assert.NotNil(t, connection)
assert.Equal(t, connections[1], connection)
assert.Equal(t, 5, len(connections))
})
t.Run("add connection fails returns alternative connection", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, newActiveConnection())
connections = append(connections, nil)
connections = append(connections, newActiveConnection())
connections = append(connections, newActiveConnection())
connections = append(connections, newActiveConnection())
connectionProvider := &connectionProviderMock{}
dialer := &amqpDialMock{}
var c *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, fmt.Errorf("dial error"))
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
connectionProvider: connectionProvider,
}
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
assert.NoError(t, err)
assert.NotNil(t, connection)
assert.Nil(t, connections[1])
assert.Equal(t, connections[2], connection)
assert.Equal(t, 5, len(connections))
})
t.Run("add connection fails", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connectionProvider := &connectionProviderMock{}
dialer := &amqpDialMock{}
var c *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, fmt.Errorf("dial error"))
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
connectionProvider: connectionProvider,
}
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
assert.EqualError(t, err, "renew connection: new internal connection: internal connect: dial: dial error")
assert.Nil(t, connection)
assert.Equal(t, 5, len(connections))
})
}

View file

@ -0,0 +1,306 @@
package messaging
import (
"context"
"errors"
"github.com/Azure/go-amqp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"sync"
"testing"
)
type amqpConnMock struct {
mock.Mock
}
func (m *amqpConnMock) Done() <-chan struct{} {
args := m.Called()
return args.Get(0).(<-chan struct{})
}
func (m *amqpConnMock) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error) {
args := m.Called(ctx, opts)
return args.Get(0).(amqpSession), args.Error(1)
}
func (m *amqpConnMock) Close() error {
args := m.Called()
return args.Error(0)
}
var _ amqpConn = (*amqpConnMock)(nil)
type amqpDialMock struct {
mock.Mock
}
func (m *amqpDialMock) Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error) {
args := m.Called(ctx, addr, opts)
return args.Get(0).(amqpConn), args.Error(1)
}
var _ amqpDial = (*amqpDialMock)(nil)
type amqpSessionMock struct {
mock.Mock
}
func (m *amqpSessionMock) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error) {
args := m.Called(ctx, target, opts)
return args.Get(0).(amqpSender), args.Error(1)
}
func (m *amqpSessionMock) Close(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
var _ amqpSession = (*amqpSessionMock)(nil)
func Test_AmqpConnection_IsClosed(t *testing.T) {
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
}
channelReceiver := func(channel chan struct{}) <-chan struct{} {
return channel
}
t.Run("is closed - connection nil", func(t *testing.T) {
assert.True(t, connection.IsClosed())
})
t.Run("is closed", func(t *testing.T) {
channel := make(chan struct{})
close(channel)
amqpConnMock := &amqpConnMock{}
amqpConnMock.On("Done").Return(channelReceiver(channel))
connection.conn = amqpConnMock
assert.True(t, connection.IsClosed())
})
t.Run("is not closed", func(t *testing.T) {
channel := make(chan struct{})
amqpConnMock := &amqpConnMock{}
amqpConnMock.On("Done").Return(channelReceiver(channel))
connection.conn = amqpConnMock
assert.False(t, connection.IsClosed())
})
}
func Test_AmqpConnection_Close(t *testing.T) {
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
}
t.Run("already closed", func(t *testing.T) {
assert.NoError(t, connection.Close())
})
t.Run("close error", func(t *testing.T) {
err := errors.New("test error")
amqpConnMock := &amqpConnMock{}
amqpConnMock.On("Close").Return(err)
connection.conn = amqpConnMock
assert.EqualError(t, connection.Close(), "internal close: connection close: test error")
assert.NotNil(t, connection.conn)
amqpConnMock.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close without error", func(t *testing.T) {
amqpConnMock := &amqpConnMock{}
amqpConnMock.On("Close").Return(nil)
connection.conn = amqpConnMock
assert.Nil(t, connection.Close())
assert.Nil(t, connection.conn)
amqpConnMock.AssertNumberOfCalls(t, "Close", 1)
})
}
func Test_AmqpConnection_Connect(t *testing.T) {
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
}
t.Run("already connected", func(t *testing.T) {
connection.conn = &amqpConnMock{}
assert.NoError(t, connection.Connect())
})
t.Run("dial error", func(t *testing.T) {
connection.conn = nil
connection.username = "user"
connection.password = "pass"
amqpDialMock := &amqpDialMock{}
var c *amqpConnMock = nil
amqpDialMock.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error"))
connection.dialer = amqpDialMock
assert.EqualError(t, connection.Connect(), "internal connect: dial: test error")
assert.Nil(t, connection.conn)
})
t.Run("connect without error", func(t *testing.T) {
connection.conn = nil
amqpDialMock := &amqpDialMock{}
amqpConn := &amqpConnMock{}
amqpDialMock.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(amqpConn, nil)
connection.dialer = amqpDialMock
assert.NoError(t, connection.Connect())
assert.Equal(t, amqpConn, connection.conn)
})
}
func Test_AmqpConnection_NewSender(t *testing.T) {
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
}
channelReceiver := func(channel chan struct{}) <-chan struct{} {
return channel
}
t.Run("connection not initialized", func(t *testing.T) {
sender, err := connection.NewSender(context.Background(), "topic")
assert.EqualError(t, err, "connection is not initialized")
assert.Nil(t, sender)
})
t.Run("connection is closed", func(t *testing.T) {
channel := make(chan struct{})
close(channel)
conn := &amqpConnMock{}
conn.On("Done").Return(channelReceiver(channel))
connection.conn = conn
sender, err := connection.NewSender(context.Background(), "topic")
assert.EqualError(t, err, "amqp connection is closed")
assert.Nil(t, sender)
})
t.Run("session error", func(t *testing.T) {
channel := make(chan struct{})
var session *amqpSessionMock = nil
conn := &amqpConnMock{}
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, errors.New("test error"))
conn.On("Done").Return(channelReceiver(channel))
connection.conn = conn
sender, err := connection.NewSender(context.Background(), "topic")
assert.EqualError(t, err, "new session: test error")
assert.Nil(t, sender)
})
t.Run("sender error", func(t *testing.T) {
channel := make(chan struct{})
sessionMock := &amqpSessionMock{}
var amqpSender *amqp.Sender = nil
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, errors.New("test error"))
sessionMock.On("Close", mock.Anything).Return(nil)
conn := &amqpConnMock{}
conn.On("Done").Return(channelReceiver(channel))
conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil)
connection.conn = conn
sender, err := connection.NewSender(context.Background(), "topic")
assert.EqualError(t, err, "new internal sender: test error")
assert.Nil(t, sender)
})
t.Run("session close error", func(t *testing.T) {
channel := make(chan struct{})
sessionMock := &amqpSessionMock{}
var amqpSender *amqp.Sender = nil
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, errors.New("test error"))
sessionMock.On("Close", mock.Anything).Return(errors.New("close error"))
conn := &amqpConnMock{}
conn.On("Done").Return(channelReceiver(channel))
conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil)
connection.conn = conn
sender, err := connection.NewSender(context.Background(), "topic")
assert.EqualError(t, err, "new internal sender: test error\nclose session: close error")
assert.Nil(t, sender)
})
t.Run("get sender", func(t *testing.T) {
channel := make(chan struct{})
amqpSender := &amqp.Sender{}
sessionMock := &amqpSessionMock{}
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, nil)
conn := &amqpConnMock{}
conn.On("Done").Return(channelReceiver(channel))
conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil)
connection.conn = conn
sender, err := connection.NewSender(context.Background(), "topic")
assert.NoError(t, err)
assert.NotNil(t, sender)
assert.Equal(t, amqpSender, sender.sender)
assert.Equal(t, sessionMock, sender.session)
})
}
func Test_AmqpConnection_NewAmqpConnection(t *testing.T) {
config := AmqpConnectionConfig{
BrokerUrl: "brokerUrl",
Username: "username",
Password: "password",
}
connection := NewAmqpConnection(config, "connectionName")
assert.NotNil(t, connection)
assert.Equal(t, connection.connectionName, "connectionName")
assert.Equal(t, connection.brokerUrl, "brokerUrl")
assert.Equal(t, connection.username, "username")
assert.Equal(t, connection.password, "password")
assert.NotNil(t, connection.dialer)
}
func Test_As(t *testing.T) {
t.Run("error", func(t *testing.T) {
value, err := As[amqp.Message](nil, errors.New("test error"))
assert.EqualError(t, err, "test error")
assert.Nil(t, value)
})
t.Run("value nil", func(t *testing.T) {
value, err := As[amqp.Message](nil, nil)
assert.NoError(t, err)
assert.Nil(t, value)
})
t.Run("value not not type", func(t *testing.T) {
value, err := As[amqp.Message](struct{}{}, nil)
assert.EqualError(t, err, "could not cast value: struct {}")
assert.Nil(t, value)
})
t.Run("cast", func(t *testing.T) {
var sessionAny any = &amqpSessionMock{}
value, err := As[amqpSessionMock](sessionAny, nil)
assert.NoError(t, err)
assert.NotNil(t, value)
})
}

View file

@ -0,0 +1,78 @@
package messaging
import (
"context"
"errors"
"fmt"
"github.com/Azure/go-amqp"
"strings"
"time"
)
type amqpSender interface {
Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error
Close(ctx context.Context) error
}
type AmqpSenderSession struct {
session amqpSession
sender amqpSender
}
func (s *AmqpSenderSession) Send(
topic string,
data [][]byte,
contentType string,
applicationProperties map[string]any,
) error {
// check topic name
if !strings.HasPrefix(topic, AmqpTopicPrefix) {
return fmt.Errorf(
"topic %q name lacks mandatory prefix %q",
topic,
AmqpTopicPrefix,
)
}
if contentType == "" {
return errors.New("content-type is required")
}
// prepare the amqp message
message := amqp.Message{
Header: &amqp.MessageHeader{
Durable: true,
},
Properties: &amqp.MessageProperties{
To: &topic,
ContentType: &contentType,
},
ApplicationProperties: applicationProperties,
Data: data,
}
// send
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
defer cancelFn()
return s.sender.Send(ctx, &message, nil)
}
func (s *AmqpSenderSession) Close() error {
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
defer cancelFn()
var closeErrors []error
senderErr := s.sender.Close(ctx)
if senderErr != nil {
closeErrors = append(closeErrors, senderErr)
}
sessionErr := s.session.Close(ctx)
if sessionErr != nil {
closeErrors = append(closeErrors, sessionErr)
}
if len(closeErrors) > 0 {
return errors.Join(closeErrors...)
}
return nil
}

View file

@ -0,0 +1,186 @@
package messaging
import (
"context"
"errors"
"github.com/Azure/go-amqp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
)
type amqpSenderMock struct {
mock.Mock
}
func (m *amqpSenderMock) Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error {
return m.Called(ctx, msg, opts).Error(0)
}
func (m *amqpSenderMock) Close(ctx context.Context) error {
return m.Called(ctx).Error(0)
}
var _ amqpSender = (*amqpSenderMock)(nil)
func Test_AmqpSenderSession_Close(t *testing.T) {
t.Run("close without errors", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(nil)
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(nil)
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
err := senderSession.Close()
assert.NoError(t, err)
sender.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close with sender error", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(errors.New("sender error"))
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(nil)
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
err := senderSession.Close()
assert.EqualError(t, err, "sender error")
sender.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close with session error", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(nil)
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(errors.New("session error"))
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
err := senderSession.Close()
assert.EqualError(t, err, "session error")
sender.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close with sender and session error", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(errors.New("sender error"))
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(errors.New("session error"))
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
err := senderSession.Close()
assert.EqualError(t, err, "sender error\nsession error")
sender.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
})
}
func Test_AmqpSenderSession_Send(t *testing.T) {
t.Run("invalid topic name", func(t *testing.T) {
sender := &amqpSenderMock{}
session := &amqpSessionMock{}
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
data := [][]byte{[]byte("data")}
err := senderSession.Send("invalid", data, "application/json", map[string]interface{}{})
assert.EqualError(t, err, "topic \"invalid\" name lacks mandatory prefix \"topic://\"")
})
t.Run("content type missing", func(t *testing.T) {
sender := &amqpSenderMock{}
session := &amqpSessionMock{}
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
data := [][]byte{[]byte("data")}
err := senderSession.Send("topic://some/name", data, "", map[string]interface{}{})
assert.EqualError(t, err, "content-type is required")
})
t.Run("send", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
session := &amqpSessionMock{}
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
data := [][]byte{[]byte("data")}
applicationProperties := map[string]interface{}{}
applicationProperties["key"] = "value"
err := senderSession.Send("topic://some/name", data, "application/json", applicationProperties)
assert.NoError(t, err)
sender.AssertNumberOfCalls(t, "Send", 1)
calls := sender.Calls
assert.Equal(t, 1, len(calls))
ctx, isCtx := calls[0].Arguments[0].(context.Context)
assert.True(t, isCtx)
assert.NotNil(t, ctx)
message, isMsg := calls[0].Arguments[1].(*amqp.Message)
assert.True(t, isMsg)
assert.True(t, message.Header.Durable)
assert.Equal(t, "topic://some/name", *message.Properties.To)
assert.Equal(t, "application/json", *message.Properties.ContentType)
assert.Equal(t, applicationProperties, message.ApplicationProperties)
assert.Equal(t, data, message.Data)
senderOptions, isSenderOptions := calls[0].Arguments[2].(*amqp.SendOptions)
assert.True(t, isSenderOptions)
assert.Nil(t, senderOptions)
})
t.Run("send fails", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("send fail"))
session := &amqpSessionMock{}
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
data := [][]byte{[]byte("data")}
applicationProperties := map[string]interface{}{}
applicationProperties["key"] = "value"
err := senderSession.Send("topic://some/name", data, "application/json", applicationProperties)
assert.EqualError(t, err, "send fail")
sender.AssertNumberOfCalls(t, "Send", 1)
})
}

View file

@ -2,19 +2,13 @@ package messaging
import (
"context"
"dev.azure.com/schwarzit/schwarzit.stackit-public/audit-go.git/log"
"errors"
"fmt"
"strings"
"sync"
"time"
"dev.azure.com/schwarzit/schwarzit.stackit-public/audit-go.git/log"
"github.com/Azure/go-amqp"
)
// Default connection timeout for the AMQP connection
const connectionTimeoutSeconds = 10
// Api is an abstraction for a messaging system that can be used to send
// audit logs to the audit log system.
type Api interface {
@ -38,203 +32,122 @@ type Api interface {
Close(ctx context.Context) error
}
// MutexApi is wrapper around an API implementation that controls mutual exclusive access to the api.
type MutexApi struct {
mutex sync.Mutex
api Api
}
var _ Api = &MutexApi{}
func NewMutexApi(api Api) (Api, error) {
if api == nil {
return nil, errors.New("api is nil")
}
mutexApi := MutexApi{
mutex: sync.Mutex{},
api: api,
}
var genericApi Api = &mutexApi
return genericApi, nil
}
// Send implements Api.Send
func (m *MutexApi) Send(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.api.Send(ctx, topic, data, contentType, applicationProperties)
}
func (m *MutexApi) Close(ctx context.Context) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.api.Close(ctx)
}
// AmqpConfig provides AMQP connection related parameters.
type AmqpConfig struct {
URL string
User string
Password string
}
// AmqpSession is an abstraction providing a subset of the methods of amqp.Session
type AmqpSession interface {
NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error)
Close(ctx context.Context) error
}
type AmqpSessionWrapper struct {
session *amqp.Session
}
func (w AmqpSessionWrapper) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error) {
return w.session.NewSender(ctx, target, opts)
}
func (w AmqpSessionWrapper) Close(ctx context.Context) error {
return w.session.Close(ctx)
}
// AmqpSender is an abstraction providing a subset of the methods of amqp.Sender
type AmqpSender interface {
Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error
Close(ctx context.Context) error
}
// AmqpApi implements Api.
type AmqpApi struct {
config AmqpConfig
connection *amqp.Conn
session AmqpSession
config AmqpConnectionPoolConfig
connection *AmqpConnection
connectionPool ConnectionPool
connectionPoolHandle *ConnectionPoolHandle
senderCache map[string]*AmqpSenderSession
lock sync.RWMutex
}
var _ Api = &AmqpApi{}
func NewAmqpApi(amqpConfig AmqpConfig) (Api, error) {
amqpApi := &AmqpApi{config: amqpConfig}
if err := amqpApi.connect(); err != nil {
return nil, fmt.Errorf("connect to broker: %w", err)
}
return amqpApi, nil
}
// connect opens a new connection and session to the AMQP messaging system.
// The connection attempt will be cancelled after connectionTimeoutSeconds.
func (a *AmqpApi) connect() error {
log.AuditLogger.Info("connecting to audit messaging system")
// Set credentials if specified
auth := amqp.SASLTypeAnonymous()
if a.config.User != "" && a.config.Password != "" {
auth = amqp.SASLTypePlain(a.config.User, a.config.Password)
log.AuditLogger.Info("using username and password for messaging")
} else {
log.AuditLogger.Warn("using anonymous messaging!")
}
options := &amqp.ConnOptions{
SASLType: auth,
}
// Create new context with timeout for the connection initialization
subCtx, cancel := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second)
defer cancel()
// Initialize connection
conn, err := amqp.Dial(subCtx, a.config.URL, options)
func NewAmqpApi(amqpConfig AmqpConnectionPoolConfig) (Api, error) {
connectionPool, err := NewAmqpConnectionPool(amqpConfig, "sdk")
if err != nil {
return fmt.Errorf("dial connection to broker: %w", err)
}
a.connection = conn
// Initialize session
session, err := conn.NewSession(context.Background(), nil)
if err != nil {
return fmt.Errorf("create session: %w", err)
return nil, fmt.Errorf("new amqp connection pool: %w", err)
}
var amqpSession AmqpSession = &AmqpSessionWrapper{session: session}
a.session = amqpSession
amqpApi := &AmqpApi{config: amqpConfig,
connectionPool: connectionPool,
connectionPoolHandle: connectionPool.NewHandle(),
senderCache: make(map[string]*AmqpSenderSession),
}
return nil
var messagingApi Api = amqpApi
return messagingApi, nil
}
// Send implements Api.Send.
// If errors occur the connection to the messaging system will be closed and re-established.
func (a *AmqpApi) Send(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
err := a.trySend(ctx, topic, data, contentType, applicationProperties)
if err == nil {
func (a *AmqpApi) Send(_ context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
// create or get sender from cache
var sender = a.senderFromCache(topic)
if sender == nil {
if err := a.newSender(topic); err != nil {
return err
}
sender = a.senderFromCache(topic)
}
// first attempt to send
var sendErr error
wrappedData := [][]byte{data}
if err := sender.Send(topic, wrappedData, contentType, applicationProperties); err != nil {
sendErr = fmt.Errorf("send: %w", err)
} else {
return nil
}
// Drop the current sender, as it cannot connect to the broker anymore
log.AuditLogger.Error("message sender error, recreating", err)
// renew sender
if err := a.newSender(topic); err != nil {
return errors.Join(sendErr, err)
}
sender = a.senderFromCache(topic)
err = a.resetConnection(ctx)
// retry send
if err := sender.Send(topic, wrappedData, contentType, applicationProperties); err != nil {
return errors.Join(sendErr, fmt.Errorf("retry send: %w", err))
}
return nil
}
func (a *AmqpApi) senderFromCache(topic string) *AmqpSenderSession {
a.lock.RLock()
defer a.lock.RUnlock()
return a.senderCache[topic]
}
func (a *AmqpApi) newSender(topic string) error {
a.lock.Lock()
defer a.lock.Unlock()
connectionIsClosed := a.connection == nil || a.connection.IsClosed()
if connectionIsClosed {
connection, err := a.connectionPool.GetConnection(a.connectionPoolHandle)
if err != nil {
return fmt.Errorf("reset connection: %w", err)
return fmt.Errorf("get connection: %w", err)
}
a.connection = connection
}
return a.trySend(ctx, topic, data, contentType, applicationProperties)
}
// trySend actually sends the given data as amqp.Message to the messaging system.
func (a *AmqpApi) trySend(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
if !strings.HasPrefix(topic, AmqpTopicPrefix) {
return fmt.Errorf(
"topic %q name lacks mandatory prefix %q",
topic,
AmqpTopicPrefix,
)
}
sender, err := a.session.NewSender(ctx, topic, nil)
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
sender, err := a.connection.NewSender(ctx, topic)
cancelFn()
if err != nil {
return fmt.Errorf("new sender: %w", err)
}
defer func() {
if err := sender.Close(ctx); err != nil {
log.AuditLogger.Error("failed to close session sender", err)
}
}()
bytes := [][]byte{data}
message := amqp.Message{
Header: &amqp.MessageHeader{
Durable: true,
},
Properties: &amqp.MessageProperties{
To: &topic,
ContentType: &contentType,
},
ApplicationProperties: applicationProperties,
Data: bytes,
}
err = sender.Send(ctx, &message, nil)
if err != nil {
return fmt.Errorf("send message: %w", err)
}
a.senderCache[topic] = sender
return nil
}
// resetConnection closes the current session and connection and reconnects to the messaging system.
func (a *AmqpApi) resetConnection(ctx context.Context) error {
if err := a.Close(ctx); err != nil {
log.AuditLogger.Error("failed to close audit messaging connection", err)
}
return a.connect()
}
// Close implements Api.Close
func (a *AmqpApi) Close(ctx context.Context) error {
log.AuditLogger.Info("close audit messaging connection")
return errors.Join(a.session.Close(ctx), a.connection.Close())
func (a *AmqpApi) Close(_ context.Context) error {
log.AuditLogger.Info("close audit amqp connection pool")
a.lock.Lock()
defer a.lock.Unlock()
// cached senders
var closeErrors []error
for _, session := range a.senderCache {
if err := session.Close(); err != nil {
closeErrors = append(closeErrors, fmt.Errorf("close session: %w", err))
}
}
clear(a.senderCache)
// pool
if err := a.connectionPool.Close(); err != nil {
closeErrors = append(closeErrors, fmt.Errorf("close pool: %w", err))
}
if len(closeErrors) > 0 {
return fmt.Errorf("close: %w", errors.Join(closeErrors...))
}
return nil
}

View file

@ -4,50 +4,38 @@ import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/Azure/go-amqp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"sync"
"testing"
"time"
)
type AmqpSessionMock struct {
type connectionPoolMock struct {
mock.Mock
}
func (m *AmqpSessionMock) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error) {
args := m.Called(ctx, target, opts)
var sender AmqpSender = nil
if args.Get(0) != nil {
sender = args.Get(0).(AmqpSender)
}
err := args.Error(1)
return sender, err
func (m *connectionPoolMock) Close() error {
return m.Called().Error(0)
}
func (m *AmqpSessionMock) Close(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
func (m *connectionPoolMock) NewHandle() *ConnectionPoolHandle {
return m.Called().Get(0).(*ConnectionPoolHandle)
}
type AmqpSenderMock struct {
mock.Mock
func (m *connectionPoolMock) GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) {
return m.Called(handle).Get(0).(*AmqpConnection), m.Called(handle).Error(1)
}
func (m *AmqpSenderMock) Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error {
args := m.Called(ctx, msg, opts)
return args.Error(0)
}
func (m *AmqpSenderMock) Close(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
var _ ConnectionPool = (*connectionPoolMock)(nil)
func Test_NewAmqpMessagingApi(t *testing.T) {
_, err := NewAmqpApi(AmqpConfig{URL: "not-handled-protocol://localhost:5672"})
assert.EqualError(t, err, "connect to broker: dial connection to broker: unsupported scheme \"not-handled-protocol\"")
_, err := NewAmqpApi(
AmqpConnectionPoolConfig{
Parameters: AmqpConnectionConfig{BrokerUrl: "not-handled-protocol://localhost:5672"},
PoolSize: 1,
})
assert.EqualError(t, err, "new amqp connection pool: initialize connections: new connection: new internal connection: internal connect: dial: unsupported scheme \"not-handled-protocol\"")
}
func Test_AmqpMessagingApi_Send(t *testing.T) {
@ -63,121 +51,359 @@ func Test_AmqpMessagingApi_Send(t *testing.T) {
t.Run("Missing topic prefix", func(t *testing.T) {
defer solaceContainer.StopOnError()
api, err := NewAmqpApi(AmqpConfig{URL: solaceContainer.AmqpConnectionString})
api, err := NewAmqpApi(AmqpConnectionPoolConfig{
Parameters: AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err)
err = api.Send(ctx, "topic-name", []byte{}, "application/json", make(map[string]any))
assert.EqualError(t, err, "topic \"topic-name\" name lacks mandatory prefix \"topic://\"")
assert.EqualError(t, err, "send: topic \"topic-name\" name lacks mandatory prefix \"topic://\"\nretry send: topic \"topic-name\" name lacks mandatory prefix \"topic://\"")
})
t.Run("Close connection without errors", func(t *testing.T) {
t.Run("send successfully", func(t *testing.T) {
defer solaceContainer.StopOnError()
// Initialize the solace queue
topicSubscriptionTopicPattern := "auditlog/>"
queueName := "close-connection-without-error"
queueName := "send-successfully"
assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName))
assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern))
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-close-connection")
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-send-successfully")
assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName))
api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}}
err := api.connect()
api, err := NewAmqpApi(AmqpConnectionPoolConfig{
Parameters: AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err)
data := []byte("data")
applicationProperties := make(map[string]interface{})
applicationProperties["key"] = "value"
err = api.Send(ctx, topicName, data, "application/json", applicationProperties)
assert.NoError(t, err)
message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true)
assert.NoError(t, err)
assert.Equal(t, "data", string(message.Data[0]))
assert.Equal(t, topicName, *message.Properties.To)
assert.Equal(t, "application/json", *message.Properties.ContentType)
assert.Equal(t, applicationProperties, message.ApplicationProperties)
err = api.Close(ctx)
assert.NoError(t, err)
})
}
t.Run("New sender call returns error", func(t *testing.T) {
defer solaceContainer.StopOnError()
func Test_AmqpMessagingApi_Send_Special_Cases(t *testing.T) {
// Initialize the solace queue
topicSubscriptionTopicPattern := "auditlog/>"
queueName := "messaging-new-sender"
assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName))
assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern))
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-no-new-sender")
assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName))
channelReceiver := func(channel chan struct{}) <-chan struct{} {
return channel
}
api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}}
err := api.connect()
newActiveConnection := func() *AmqpConnection {
channel := make(chan struct{})
conn := &amqpConnMock{}
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
return &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
conn: conn,
}
}
newClosedConnection := func() *AmqpConnection {
channel := make(chan struct{})
close(channel)
conn := &amqpConnMock{}
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
return &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
conn: conn,
}
}
t.Run("connection nil sender nil", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
session := &amqpSessionMock{}
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
connection := newActiveConnection()
conn := connection.conn.(*amqpConnMock)
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
pool := &connectionPoolMock{}
pool.On("GetConnection", mock.Anything).Return(connection, nil)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
assert.NoError(t, err)
expectedError := errors.New("expected error")
// Set mock session
sessionMock := AmqpSessionMock{}
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(nil, expectedError)
sessionMock.On("Close", mock.Anything).Return(nil)
var amqpSession AmqpSession = &sessionMock
api.session = amqpSession
// It's expected that the test succeeds.
// First the session is closed as it returns the expected error
// Then the retry mechanism restarts the connection and successfully sends the data
value := "test"
err = api.Send(ctx, topicName, []byte(value), "application/json", make(map[string]any))
assert.NoError(t, err)
// Check that the mock was called
assert.True(t, sessionMock.AssertNumberOfCalls(t, "NewSender", 1))
assert.True(t, sessionMock.AssertNumberOfCalls(t, "Close", 1))
message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true)
assert.NoError(t, err)
assert.Equal(t, value, string(message.Data[0]))
assert.Equal(t, topicName, *message.Properties.To)
sender.AssertNumberOfCalls(t, "Send", 1)
session.AssertNumberOfCalls(t, "NewSender", 1)
pool.AssertNumberOfCalls(t, "GetConnection", 2)
})
t.Run("Send call on sender returns error", func(t *testing.T) {
defer solaceContainer.StopOnError()
t.Run("connection closed sender nil", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
// Initialize the solace queue
topicSubscriptionTopicPattern := "auditlog/>"
queueName := "messaging-sender-error"
assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName))
assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern))
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-sender-error")
assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName))
session := &amqpSessionMock{}
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}}
err := api.connect()
connection := newActiveConnection()
conn := connection.conn.(*amqpConnMock)
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
pool := &connectionPoolMock{}
pool.On("GetConnection", mock.Anything).Return(connection, nil)
closedConnection := newClosedConnection()
closedConnMock := closedConnection.conn.(*amqpConnMock)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connection: closedConnection,
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
assert.NoError(t, err)
expectedError := errors.New("expected error")
sender.AssertNumberOfCalls(t, "Send", 1)
session.AssertNumberOfCalls(t, "NewSender", 1)
pool.AssertNumberOfCalls(t, "GetConnection", 2)
closedConnMock.AssertNumberOfCalls(t, "Done", 1)
})
// Instantiate mock sender
senderMock := AmqpSenderMock{}
senderMock.On("Send", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedError)
senderMock.On("Close", mock.Anything).Return(nil)
var amqpSender AmqpSender = &senderMock
t.Run("connection nil get connection fail", func(t *testing.T) {
var connection *AmqpConnection = nil
// Set mock session
sessionMock := AmqpSessionMock{}
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(&amqpSender, nil)
sessionMock.On("Close", mock.Anything).Return(nil)
pool := &connectionPoolMock{}
pool.On("GetConnection", mock.Anything).Return(connection, errors.New("connection error"))
var amqpSession AmqpSession = &sessionMock
api.session = amqpSession
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: make(map[string]*AmqpSenderSession),
}
// It's expected that the test succeeds.
// First the sender and session are closed as the sender returns the expected error
// Then the retry mechanism restarts the connection and successfully sends the data
value := "test"
err = api.Send(ctx, topicName, []byte(value), "application/json", make(map[string]any))
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
assert.EqualError(t, err, "get connection: connection error")
pool.AssertNumberOfCalls(t, "GetConnection", 2)
})
t.Run("connection active sender nil", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
session := &amqpSessionMock{}
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
connection := newActiveConnection()
conn := connection.conn.(*amqpConnMock)
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connection: connection,
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
assert.NoError(t, err)
// Check that the mocks were called
assert.True(t, sessionMock.AssertNumberOfCalls(t, "NewSender", 1))
assert.True(t, sessionMock.AssertNumberOfCalls(t, "Close", 1))
assert.True(t, senderMock.AssertNumberOfCalls(t, "Send", 1))
assert.True(t, senderMock.AssertNumberOfCalls(t, "Close", 1))
sender.AssertNumberOfCalls(t, "Send", 1)
session.AssertNumberOfCalls(t, "NewSender", 1)
})
message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true)
t.Run("connection active new sender fail", func(t *testing.T) {
var sender *amqpSenderMock = nil
session := &amqpSessionMock{}
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, errors.New("new sender error"))
session.On("Close", mock.Anything).Return(nil)
connection := newActiveConnection()
conn := connection.conn.(*amqpConnMock)
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connection: connection,
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
assert.EqualError(t, err, "new sender: new internal sender: new sender error")
session.AssertNumberOfCalls(t, "NewSender", 1)
session.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("connection active sender set", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
topic := "topic://some-topic"
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connection: newActiveConnection(),
senderCache: map[string]*AmqpSenderSession{topic: {sender: sender}},
}
err := amqpApi.Send(context.Background(), topic, []byte("data"), "application/json", make(map[string]any))
assert.NoError(t, err)
assert.Equal(t, value, string(message.Data[0]))
assert.Equal(t, topicName, *message.Properties.To)
sender.AssertNumberOfCalls(t, "Send", 1)
})
t.Run("send fail", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("send error"))
session := &amqpSessionMock{}
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
topic := "topic://some-topic"
connection := newActiveConnection()
connection.conn.(*amqpConnMock).On("NewSession", mock.Anything, mock.Anything, mock.Anything).Return(session, nil)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connection: connection,
senderCache: map[string]*AmqpSenderSession{topic: {sender: sender}},
}
err := amqpApi.Send(context.Background(), topic, []byte("data"), "application/json", make(map[string]any))
assert.EqualError(t, err, "send: send error\nretry send: send error")
sender.AssertNumberOfCalls(t, "Send", 2)
})
}
func Test_AmqpMessagingApi_Close(t *testing.T) {
t.Run("close without cached senders", func(t *testing.T) {
pool := &connectionPoolMock{}
pool.On("Close").Return(nil)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Close(context.Background())
assert.NoError(t, err)
pool.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close fail without cached senders", func(t *testing.T) {
pool := &connectionPoolMock{}
pool.On("Close").Return(errors.New("close error"))
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Close(context.Background())
assert.EqualError(t, err, "close: close pool: close error")
pool.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close with cached senders", func(t *testing.T) {
pool := &connectionPoolMock{}
pool.On("Close").Return(nil)
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(nil)
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(nil)
senderSession := &AmqpSenderSession{
session: session,
sender: sender,
}
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: map[string]*AmqpSenderSession{"key": senderSession},
}
err := amqpApi.Close(context.Background())
assert.NoError(t, err)
assert.Equal(t, 0, len(amqpApi.senderCache))
pool.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
sender.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close fail with cached senders", func(t *testing.T) {
pool := &connectionPoolMock{}
pool.On("Close").Return(nil)
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(nil)
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(errors.New("close sender error"))
senderSession := &AmqpSenderSession{
session: session,
sender: sender,
}
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: map[string]*AmqpSenderSession{"key": senderSession},
}
err := amqpApi.Close(context.Background())
assert.EqualError(t, err, "close: close session: close sender error")
assert.Equal(t, 0, len(amqpApi.senderCache))
pool.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
sender.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close fail", func(t *testing.T) {
pool := &connectionPoolMock{}
pool.On("Close").Return(errors.New("close pool error"))
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(errors.New("close session error"))
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(errors.New("close sender error"))
senderSession := &AmqpSenderSession{
session: session,
sender: sender,
}
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: map[string]*AmqpSenderSession{"key": senderSession},
}
err := amqpApi.Close(context.Background())
assert.EqualError(t, err, "close: close session: close sender error\nclose session error\nclose pool: close pool error")
assert.Equal(t, 0, len(amqpApi.senderCache))
pool.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
sender.AssertNumberOfCalls(t, "Close", 1)
})
}

View file

@ -18,7 +18,6 @@ import (
)
const (
AmqpTopicPrefix = "topic://"
AmqpQueuePrefix = "queue://"
)