Compare commits

...

1 commit

Author SHA1 Message Date
Christian Schaible
49ebd76a9d feat: replace amqp connection 2025-01-16 10:54:08 +01:00
13 changed files with 2180 additions and 293 deletions

View file

@ -30,7 +30,11 @@ func TestDynamicLegacyAuditApi(t *testing.T) {
defer solaceContainer.Stop() defer solaceContainer.Stop()
// Instantiate the messaging api // Instantiate the messaging api
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString}) messagingApi, err := messaging.NewAmqpApi(
messaging.AmqpConnectionPoolConfig{
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err) assert.NoError(t, err)
// Validator // Validator

View file

@ -32,7 +32,10 @@ func TestLegacyAuditApi(t *testing.T) {
defer solaceContainer.Stop() defer solaceContainer.Stop()
// Instantiate the messaging api // Instantiate the messaging api
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString}) messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err) assert.NoError(t, err)
// Validator // Validator
@ -579,7 +582,10 @@ func TestLegacyAuditApi_NewLegacyAuditApi(t *testing.T) {
defer solaceContainer.Stop() defer solaceContainer.Stop()
// Instantiate the messaging api // Instantiate the messaging api
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString}) messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err) assert.NoError(t, err)
// Validator // Validator

View file

@ -33,7 +33,10 @@ func TestRoutableAuditApi(t *testing.T) {
defer solaceContainer.Stop() defer solaceContainer.Stop()
// Instantiate the messaging api // Instantiate the messaging api
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString}) messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err) assert.NoError(t, err)
// Validator // Validator

View file

@ -0,0 +1,15 @@
package messaging
const AmqpTopicPrefix = "topic://"
const connectionTimeoutSeconds = 10
type AmqpConnectionConfig struct {
BrokerUrl string `json:"brokerUrl"`
Username string `json:"username"`
Password string `json:"password"`
}
type AmqpConnectionPoolConfig struct {
Parameters AmqpConnectionConfig `json:"parameters"`
PoolSize int `json:"poolSize"`
}

View file

@ -0,0 +1,277 @@
package messaging
import (
"context"
"errors"
"fmt"
"github.com/Azure/go-amqp"
"log/slog"
"sync"
"time"
)
var ConnectionClosedError = errors.New("amqp connection is closed")
type AmqpConnection struct {
connectionName string
lock sync.RWMutex
brokerUrl string
username string
password string
conn amqpConn
dialer amqpDial
}
// amqpConn is an abstraction of amqp.Conn
type amqpConn interface {
NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error)
Close() error
Done() <-chan struct{}
}
type defaultAmqpConn struct {
conn *amqp.Conn
}
func newDefaultAmqpConn(conn *amqp.Conn) *defaultAmqpConn {
return &defaultAmqpConn{
conn: conn,
}
}
func (d defaultAmqpConn) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error) {
session, err := d.conn.NewSession(ctx, opts)
if err != nil {
return nil, err
}
return newDefaultAmqpSession(session), nil
}
func (d defaultAmqpConn) Close() error {
return d.conn.Close()
}
func (d defaultAmqpConn) Done() <-chan struct{} {
return d.conn.Done()
}
var _ amqpConn = (*defaultAmqpConn)(nil)
type amqpDial interface {
Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error)
}
type amqpSession interface {
NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error)
Close(ctx context.Context) error
}
type defaultAmqpSession struct {
session *amqp.Session
}
func newDefaultAmqpSession(session *amqp.Session) *defaultAmqpSession {
return &defaultAmqpSession{
session: session,
}
}
func (s *defaultAmqpSession) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error) {
return s.session.NewSender(ctx, target, opts)
}
func (s *defaultAmqpSession) Close(ctx context.Context) error {
return s.session.Close(ctx)
}
var _ amqpSession = (*defaultAmqpSession)(nil)
type defaultAmqpDialer struct{}
func (d *defaultAmqpDialer) Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error) {
dial, err := amqp.Dial(ctx, addr, opts)
if err != nil {
return nil, err
}
return newDefaultAmqpConn(dial), nil
}
var _ amqpDial = (*defaultAmqpDialer)(nil)
func NewAmqpConnection(config *AmqpConnectionConfig, connectionName string) *AmqpConnection {
return &AmqpConnection{
connectionName: connectionName,
lock: sync.RWMutex{},
brokerUrl: config.BrokerUrl,
username: config.Username,
password: config.Password,
dialer: &defaultAmqpDialer{},
}
}
func (c *AmqpConnection) NewSender(ctx context.Context, topic string) (*AmqpSenderSession, error) {
if c.conn == nil {
return nil, errors.New("connection is not initialized")
}
if c.internalIsClosed() {
return nil, ConnectionClosedError
}
c.lock.RLock()
defer c.lock.RUnlock()
// new session
newSession, err := c.conn.NewSession(ctx, nil)
if err != nil {
return nil, fmt.Errorf("new session: %w", err)
}
// new sender
newSender, err := newSession.NewSender(ctx, topic, nil)
if err != nil {
err = fmt.Errorf("new internal sender: %w", err)
closeErr := newSession.Close(ctx)
if closeErr != nil {
return nil, errors.Join(err, fmt.Errorf("close session: %w", closeErr))
}
return nil, err
}
return &AmqpSenderSession{newSession, newSender}, nil
}
func (c *AmqpConnection) ResetConnectionAndRetryIfErrorWithReturnValue(opName string, fn func(ctx context.Context) (any, error)) (any, error) {
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
result, err := fn(ctx)
cancelFn()
if err != nil {
slog.Info(fmt.Sprintf("amqp connection: %s", opName), slog.Any("connection", c.connectionName), slog.Any("err", err))
err := c.ResetConnection(context.Background())
if err != nil {
return nil, fmt.Errorf("reset connection: %w", err)
}
newCtx, closeFn := context.WithTimeout(context.Background(), 10*time.Second)
defer closeFn()
// Retry
return fn(newCtx)
}
return result, nil
}
func (c *AmqpConnection) ResetConnection(ctx context.Context) error {
c.lock.Lock()
defer c.lock.Unlock()
err := c.internalClose()
if err != nil {
slog.Warn("amqp connection: reset: failed to close amqp connection", slog.Any("err", err))
}
subCtx, cancel := context.WithTimeout(ctx, connectionTimeoutSeconds*time.Second)
err = c.internalConnect(subCtx)
cancel()
if err != nil {
slog.Warn("amqp connection: reset: failed to connect to amqp server, retry..", slog.Any("err", err))
subCtx, cancel = context.WithTimeout(ctx, connectionTimeoutSeconds*time.Second)
err = c.internalConnect(subCtx)
cancel()
if err != nil {
return fmt.Errorf("connect: %w", err)
}
}
return nil
}
func As[T any](value any, err error) (*T, error) {
if err != nil {
return nil, err
}
if value == nil {
return nil, nil
}
castedValue, isType := value.(*T)
if !isType {
return nil, fmt.Errorf("could not cast value: %T", value)
}
return castedValue, nil
}
func (c *AmqpConnection) Connect() error {
c.lock.Lock()
defer c.lock.Unlock()
subCtx, cancel := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second)
defer cancel()
if err := c.internalConnect(subCtx); err != nil {
return fmt.Errorf("internal connection connect: %w", err)
}
return nil
}
func (c *AmqpConnection) internalConnect(ctx context.Context) error {
if c.conn == nil {
// Set credentials if specified
auth := amqp.SASLTypeAnonymous()
if c.username != "" && c.password != "" {
auth = amqp.SASLTypePlain(c.username, c.password)
} else {
slog.Debug("amqp connection: connect: using anonymous messaging")
}
options := &amqp.ConnOptions{
SASLType: auth,
}
// Initialize connection
conn, err := c.dialer.Dial(ctx, c.brokerUrl, options)
if err != nil {
return fmt.Errorf("dial: %w", err)
}
c.conn = conn
}
return nil
}
func (c *AmqpConnection) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
if err := c.internalClose(); err != nil {
return fmt.Errorf("close: %w", err)
}
return nil
}
func (c *AmqpConnection) internalClose() error {
if c.conn != nil {
if err := c.conn.Close(); err != nil {
return fmt.Errorf("internal connection close: %w", err)
}
c.conn = nil
}
return nil
}
func (c *AmqpConnection) IsClosed() bool {
c.lock.RLock()
defer c.lock.RUnlock()
return c.internalIsClosed()
}
func (c *AmqpConnection) internalIsClosed() bool {
if c.conn == nil {
return true
}
select {
case <-c.conn.Done():
return true
default:
return false
}
}

View file

@ -0,0 +1,213 @@
package messaging
import (
"errors"
"fmt"
"log/slog"
"sync"
)
type connectionProvider interface {
NewAmqpConnection(config *AmqpConnectionConfig, connectionName string) *AmqpConnection
}
type defaultAmqpConnectionProvider struct{}
func (p defaultAmqpConnectionProvider) NewAmqpConnection(config *AmqpConnectionConfig, connectionName string) *AmqpConnection {
return NewAmqpConnection(config, connectionName)
}
var _ connectionProvider = (*defaultAmqpConnectionProvider)(nil)
type ConnectionPool interface {
Close() error
NewHandle() *ConnectionPoolHandle
GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error)
}
type AmqpConnectionPool struct {
config AmqpConnectionPoolConfig
connectionName string
connections []*AmqpConnection
connectionProvider connectionProvider
handleOffset int
lock sync.RWMutex
}
type ConnectionPoolHandle struct {
connectionOffset int
}
func NewAmqpConnectionPool(config AmqpConnectionPoolConfig, connectionName string) (ConnectionPool, error) {
pool := &AmqpConnectionPool{
config: config,
connectionName: connectionName,
connections: make([]*AmqpConnection, 0),
connectionProvider: defaultAmqpConnectionProvider{},
handleOffset: 0,
lock: sync.RWMutex{},
}
if err := pool.initializeConnections(); err != nil {
if closeErr := pool.Close(); closeErr != nil {
return nil, errors.Join(err, fmt.Errorf("initialize amqp connection: pool closed: %w", closeErr))
}
return nil, fmt.Errorf("initialize connections: %w", err)
}
return pool, nil
}
func (p *AmqpConnectionPool) initializeConnections() error {
if len(p.connections) < p.config.PoolSize {
p.lock.Lock()
defer p.lock.Unlock()
numMissingConnections := p.config.PoolSize - len(p.connections)
for i := 0; i < numMissingConnections; i++ {
if err := p.internalAddConnection(); err != nil {
return err
}
}
}
return nil
}
func (p *AmqpConnectionPool) internalAddConnection() error {
newConnection, err := p.internalNewConnection()
if err != nil {
return fmt.Errorf("new connection: %w", err)
}
p.connections = append(p.connections, newConnection)
return nil
}
func (p *AmqpConnectionPool) internalNewConnection() (*AmqpConnection, error) {
conn := p.connectionProvider.NewAmqpConnection(&p.config.Parameters, p.connectionName)
if err := conn.Connect(); err != nil {
slog.Warn("amqp connection: failed to connect to amqp broker", slog.Any("err", err))
// retry
if err = conn.Connect(); err != nil {
connectErr := fmt.Errorf("failed to connect to amqp broker: %w", err)
if closeErr := conn.Close(); closeErr != nil {
// this case should never happen as the inner connection should always be null, therefore
// it should not have to be closed, i.e. be able to return errors.
return nil, errors.Join(connectErr, fmt.Errorf("close connection: %w", closeErr))
}
return nil, connectErr
}
}
return conn, nil
}
func (p *AmqpConnectionPool) Close() error {
p.lock.Lock()
defer p.lock.Unlock()
closeErrors := make([]error, 0)
for _, conn := range p.connections {
if err := conn.Close(); err != nil {
closeErrors = append(closeErrors, fmt.Errorf("connection: close: %w", err))
}
}
p.connections = make([]*AmqpConnection, 0)
if len(closeErrors) > 0 {
return errors.Join(closeErrors...)
}
return nil
}
func (p *AmqpConnectionPool) NewHandle() *ConnectionPoolHandle {
p.lock.Lock()
defer p.lock.Unlock()
offset := p.handleOffset
p.handleOffset += 1
offset = offset % p.config.PoolSize
return &ConnectionPoolHandle{
connectionOffset: offset,
}
}
func (p *AmqpConnectionPool) GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) {
// get the requested connection or another one
conn, addConnection := p.nextConnectionForHandle(handle)
// renew the requested connection if the request connection is closed
if conn == nil || addConnection {
p.lock.Lock()
// check that accessing the pool only with a valid index (out of bounds should only occur on shutdown)
if handle.connectionOffset < len(p.connections) && p.connections[handle.connectionOffset] == nil {
connection, err := p.internalNewConnection()
if err != nil {
if conn == nil {
// case: connection could not be renewed and no connection to return has been found
p.lock.Unlock()
return nil, fmt.Errorf("renew connection: %w", err)
}
// case: connection could not be renewed but another connection will be returned
slog.Warn("amqp connection pool: get connection: renew connection: ", slog.Any("err", err))
} else {
// case: connection could be renewed and will be added to pool
p.connections[handle.connectionOffset] = connection
conn = connection
}
}
p.lock.Unlock()
}
return conn, nil
}
func (p *AmqpConnectionPool) nextConnectionForHandle(handle *ConnectionPoolHandle) (*AmqpConnection, bool) {
// return the next possible index (including the retry offset)
nextIndex := func(idx int) int {
if idx+handle.connectionOffset >= p.config.PoolSize {
return (idx + handle.connectionOffset) % p.config.PoolSize
} else {
return idx + handle.connectionOffset
}
}
// retry as long as there are remaining connections in the pool
var conn *AmqpConnection
var addConnection bool
for i := 0; i < p.config.PoolSize; i++ {
// get the next possible connection (considering the retry index)
idx := nextIndex(i)
p.lock.RLock()
if idx < len(p.connections) {
conn = p.connections[idx]
} else {
// handle the edge case that the pool is empty on shutdown
conn = nil
}
p.lock.RUnlock()
// remember that the requested is closed, retry with the next
if conn == nil {
addConnection = true
continue
}
// if the connection is closed, mark it by setting it to nil
if conn.IsClosed() {
p.lock.Lock()
p.connections[idx] = nil
p.lock.Unlock()
addConnection = true
continue
}
return conn, addConnection
}
return nil, true
}

View file

@ -0,0 +1,572 @@
package messaging
import (
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"sync"
"testing"
)
type connectionProviderMock struct {
mock.Mock
}
func (p *connectionProviderMock) NewAmqpConnection(config *AmqpConnectionConfig, connectionName string) *AmqpConnection {
args := p.Called(config, connectionName)
return args.Get(0).(*AmqpConnection)
}
var _ connectionProvider = (*connectionProviderMock)(nil)
func Test_AmqpConnectionPool_GetHandle(t *testing.T) {
t.Run("next handle", func(t *testing.T) {
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
}
handle := pool.NewHandle()
assert.NotNil(t, handle)
assert.Equal(t, 0, handle.connectionOffset)
assert.Equal(t, 1, pool.handleOffset)
})
t.Run("next handle high offset", func(t *testing.T) {
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 13,
lock: sync.RWMutex{},
}
handle := pool.NewHandle()
assert.NotNil(t, handle)
assert.Equal(t, 3, handle.connectionOffset)
assert.Equal(t, 14, pool.handleOffset)
})
}
func Test_AmqpConnectionPool_internalAddConnection(t *testing.T) {
t.Run("internal add connection", func(t *testing.T) {
conn := &amqpConnMock{}
dialer := &amqpDialMock{}
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.internalAddConnection()
assert.NoError(t, err)
assert.Equal(t, 1, len(pool.connections))
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1)
dialer.AssertNumberOfCalls(t, "Dial", 1)
})
t.Run("dialer error", func(t *testing.T) {
conn := &amqpConnMock{}
dialer := &amqpDialMock{}
var c *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error")).Once()
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.internalAddConnection()
assert.NoError(t, err)
assert.Equal(t, 1, len(pool.connections))
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1)
dialer.AssertNumberOfCalls(t, "Dial", 2)
})
t.Run("repetitive dialer error", func(t *testing.T) {
dialer := &amqpDialMock{}
var c *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error"))
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.internalAddConnection()
assert.EqualError(t, err, "new connection: failed to connect to amqp broker: internal connection connect: dial: test error")
assert.Equal(t, 0, len(pool.connections))
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1)
dialer.AssertNumberOfCalls(t, "Dial", 2)
})
}
func Test_AmqpConnectionPool_initializeConnections(t *testing.T) {
t.Run("initialize connections successfully", func(t *testing.T) {
conn := &amqpConnMock{}
dialer := &amqpDialMock{}
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.initializeConnections()
assert.NoError(t, err)
assert.Equal(t, 5, len(pool.connections))
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 5)
})
t.Run("fail initialization of connections", func(t *testing.T) {
var c *amqpConnMock = nil
failingDialer := &amqpDialMock{}
failingDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error"))
failingConnection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: failingDialer,
}
conn := &amqpConnMock{}
successfulDialer := &amqpDialMock{}
successfulDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
successfulConnection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: successfulDialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(4)
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(failingConnection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.initializeConnections()
assert.EqualError(t, err, "new connection: failed to connect to amqp broker: internal connection connect: dial: test error")
assert.Equal(t, 4, len(pool.connections))
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 5)
})
}
func Test_AmqpConnectionPool_Close(t *testing.T) {
t.Run("close connection successfully", func(t *testing.T) {
// add 5 connections to the pool
conn := &amqpConnMock{}
conn.On("Close").Return(nil)
dialer := &amqpDialMock{}
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.initializeConnections()
assert.NoError(t, err)
assert.Equal(t, 5, len(pool.connections))
// close the pool
err = pool.Close()
assert.NoError(t, err)
assert.Equal(t, 0, len(pool.connections))
})
t.Run("close connection fail", func(t *testing.T) {
// add 5 connections to the pool
failingConn := &amqpConnMock{}
failingConn.On("Close").Return(errors.New("test error"))
failingDialer := &amqpDialMock{}
failingDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(failingConn, nil)
failingConnection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: failingDialer,
}
successfulConn := &amqpConnMock{}
successfulConn.On("Close").Return(nil)
successfulDialer := &amqpDialMock{}
successfulDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(successfulConn, nil)
successfulConnection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: successfulDialer,
}
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(2)
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(failingConnection).Times(2)
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(1)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connectionProvider: connectionProvider,
}
err := pool.initializeConnections()
assert.NoError(t, err)
assert.Equal(t, 5, len(pool.connections))
// close the pool
err = pool.Close()
assert.EqualError(t, err, "connection: close: close: internal connection close: test error\nconnection: close: close: internal connection close: test error")
assert.Equal(t, 0, len(pool.connections))
})
}
func Test_AmqpConnectionPool_nextConnectionForHandle(t *testing.T) {
channelReceiver := func(channel chan struct{}) <-chan struct{} {
return channel
}
newActiveConnection := func() *AmqpConnection {
channel := make(chan struct{})
conn := &amqpConnMock{}
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
return &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
conn: conn,
}
}
newClosedConnection := func() *AmqpConnection {
channel := make(chan struct{})
close(channel)
conn := &amqpConnMock{}
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
return &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
conn: conn,
}
}
t.Run("next connection for requested handle", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
for i := 0; i < 5; i++ {
connections = append(connections, newActiveConnection())
}
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
assert.NotNil(t, connection)
assert.False(t, addConnection)
})
t.Run("nil connection for requested handle", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, newActiveConnection())
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, newActiveConnection())
connections = append(connections, newActiveConnection())
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
assert.NotNil(t, connection)
assert.True(t, addConnection)
})
t.Run("closed connection for requested handle", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, newActiveConnection())
connections = append(connections, newClosedConnection())
connections = append(connections, newClosedConnection())
connections = append(connections, newActiveConnection())
connections = append(connections, newActiveConnection())
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
assert.NotNil(t, connection)
assert.True(t, addConnection)
})
t.Run("no connection for requested handle", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
assert.Nil(t, connection)
assert.True(t, addConnection)
})
t.Run("connection for requested handle with large index", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, newActiveConnection())
connections = append(connections, nil)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 23})
assert.NotNil(t, connection)
assert.False(t, addConnection)
})
t.Run("connection for requested handle nil with large index", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, newActiveConnection())
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 23})
assert.NotNil(t, connection)
assert.True(t, addConnection)
})
}
func Test_AmqpConnectionPool_GetConnection(t *testing.T) {
channelReceiver := func(channel chan struct{}) <-chan struct{} {
return channel
}
newActiveConnection := func() *AmqpConnection {
channel := make(chan struct{})
conn := &amqpConnMock{}
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
return &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
conn: conn,
}
}
t.Run("get connection for requested handle", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
for i := 0; i < 5; i++ {
connections = append(connections, newActiveConnection())
}
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
}
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
assert.NoError(t, err)
assert.NotNil(t, connection)
assert.Equal(t, connections[1], connection)
assert.Equal(t, 5, len(connections))
})
t.Run("add connection if missing", func(t *testing.T) {
connections := make([]*AmqpConnection, 5)
connectionProvider := &connectionProviderMock{}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(newActiveConnection())
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
connectionProvider: connectionProvider,
}
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
assert.NoError(t, err)
assert.NotNil(t, connection)
assert.Equal(t, connections[1], connection)
assert.Equal(t, 5, len(connections))
})
t.Run("add connection fails returns alternative connection", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, newActiveConnection())
connections = append(connections, nil)
connections = append(connections, newActiveConnection())
connections = append(connections, newActiveConnection())
connections = append(connections, newActiveConnection())
connectionProvider := &connectionProviderMock{}
dialer := &amqpDialMock{}
var c *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, fmt.Errorf("dial error"))
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
connectionProvider: connectionProvider,
}
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
assert.NoError(t, err)
assert.NotNil(t, connection)
assert.Nil(t, connections[1])
assert.Equal(t, connections[2], connection)
assert.Equal(t, 5, len(connections))
})
t.Run("add connection fails", func(t *testing.T) {
connections := make([]*AmqpConnection, 0)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connections = append(connections, nil)
connectionProvider := &connectionProviderMock{}
dialer := &amqpDialMock{}
var c *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, fmt.Errorf("dial error"))
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
dialer: dialer,
}
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
pool := AmqpConnectionPool{
config: AmqpConnectionPoolConfig{PoolSize: 5},
handleOffset: 0,
lock: sync.RWMutex{},
connections: connections,
connectionProvider: connectionProvider,
}
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
assert.EqualError(t, err, "renew connection: failed to connect to amqp broker: internal connection connect: dial: dial error")
assert.Nil(t, connection)
assert.Equal(t, 5, len(connections))
})
}

View file

@ -0,0 +1,415 @@
package messaging
import (
"context"
"errors"
"github.com/Azure/go-amqp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"sync"
"testing"
)
type amqpConnMock struct {
mock.Mock
}
func (m *amqpConnMock) Done() <-chan struct{} {
args := m.Called()
return args.Get(0).(<-chan struct{})
}
func (m *amqpConnMock) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error) {
args := m.Called(ctx, opts)
return args.Get(0).(amqpSession), args.Error(1)
}
func (m *amqpConnMock) Close() error {
args := m.Called()
return args.Error(0)
}
var _ amqpConn = (*amqpConnMock)(nil)
type amqpDialMock struct {
mock.Mock
}
func (m *amqpDialMock) Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error) {
args := m.Called(ctx, addr, opts)
return args.Get(0).(amqpConn), args.Error(1)
}
var _ amqpDial = (*amqpDialMock)(nil)
type amqpSessionMock struct {
mock.Mock
}
func (m *amqpSessionMock) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error) {
args := m.Called(ctx, target, opts)
return args.Get(0).(amqpSender), args.Error(1)
}
func (m *amqpSessionMock) Close(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
var _ amqpSession = (*amqpSessionMock)(nil)
func Test_AmqpConnection_IsClosed(t *testing.T) {
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
}
channelReceiver := func(channel chan struct{}) <-chan struct{} {
return channel
}
t.Run("is closed - connection nil", func(t *testing.T) {
assert.True(t, connection.IsClosed())
})
t.Run("is closed", func(t *testing.T) {
channel := make(chan struct{})
close(channel)
amqpConnMock := &amqpConnMock{}
amqpConnMock.On("Done").Return(channelReceiver(channel))
connection.conn = amqpConnMock
assert.True(t, connection.IsClosed())
})
t.Run("is not closed", func(t *testing.T) {
channel := make(chan struct{})
amqpConnMock := &amqpConnMock{}
amqpConnMock.On("Done").Return(channelReceiver(channel))
connection.conn = amqpConnMock
assert.False(t, connection.IsClosed())
})
}
func Test_AmqpConnection_Close(t *testing.T) {
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
}
t.Run("already closed", func(t *testing.T) {
assert.NoError(t, connection.Close())
})
t.Run("close error", func(t *testing.T) {
err := errors.New("test error")
amqpConnMock := &amqpConnMock{}
amqpConnMock.On("Close").Return(err)
connection.conn = amqpConnMock
assert.EqualError(t, connection.Close(), "close: internal connection close: test error")
assert.NotNil(t, connection.conn)
amqpConnMock.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close without error", func(t *testing.T) {
amqpConnMock := &amqpConnMock{}
amqpConnMock.On("Close").Return(nil)
connection.conn = amqpConnMock
assert.Nil(t, connection.Close())
assert.Nil(t, connection.conn)
amqpConnMock.AssertNumberOfCalls(t, "Close", 1)
})
}
func Test_AmqpConnection_Connect(t *testing.T) {
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
}
t.Run("already connected", func(t *testing.T) {
connection.conn = &amqpConnMock{}
assert.NoError(t, connection.Connect())
})
t.Run("dial error", func(t *testing.T) {
connection.conn = nil
connection.username = "user"
connection.password = "pass"
amqpDialMock := &amqpDialMock{}
var c *amqpConnMock = nil
amqpDialMock.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error"))
connection.dialer = amqpDialMock
assert.EqualError(t, connection.Connect(), "internal connection connect: dial: test error")
assert.Nil(t, connection.conn)
})
t.Run("connect without error", func(t *testing.T) {
connection.conn = nil
amqpDialMock := &amqpDialMock{}
amqpConn := &amqpConnMock{}
amqpDialMock.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(amqpConn, nil)
connection.dialer = amqpDialMock
assert.NoError(t, connection.Connect())
assert.Equal(t, amqpConn, connection.conn)
})
}
func Test_AmqpConnection_ResetConnection(t *testing.T) {
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
}
t.Run("reset connection - connect on first attempt", func(t *testing.T) {
connection.conn = nil
dialer := &amqpDialMock{}
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(&amqpConnMock{}, nil)
connection.dialer = dialer
assert.NoError(t, connection.ResetConnection(context.Background()))
assert.NotNil(t, connection.conn)
dialer.AssertNumberOfCalls(t, "Dial", 1)
})
t.Run("reset connection - connect on second attempt", func(t *testing.T) {
connection.conn = nil
dialer := &amqpDialMock{}
var conn *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, errors.New("test error")).Once()
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(&amqpConnMock{}, nil)
connection.dialer = dialer
assert.NoError(t, connection.ResetConnection(context.Background()))
assert.NotNil(t, connection.conn)
dialer.AssertNumberOfCalls(t, "Dial", 2)
})
t.Run("reset connection - fail continuously", func(t *testing.T) {
connection.conn = nil
dialer := &amqpDialMock{}
var conn *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, errors.New("test error"))
connection.dialer = dialer
assert.EqualError(t, connection.ResetConnection(context.Background()), "connect: dial: test error")
assert.Nil(t, connection.conn)
dialer.AssertNumberOfCalls(t, "Dial", 2)
})
}
type ResettableFunction struct {
mock.Mock
}
func (f *ResettableFunction) Run(ctx context.Context) (any, error) {
args := f.Called(ctx)
return args.Get(0), args.Error(1)
}
func Test_AmqpConnection_ResetConnectionAndRetryIfErrorWithReturnValue(t *testing.T) {
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
}
t.Run("no error", func(t *testing.T) {
value, err := connection.ResetConnectionAndRetryIfErrorWithReturnValue("test", func(ctx context.Context) (any, error) {
return struct{}{}, nil
})
assert.NoError(t, err)
assert.Equal(t, struct{}{}, value)
})
t.Run("reset - success", func(t *testing.T) {
connection.conn = nil
dialer := &amqpDialMock{}
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(&amqpConnMock{}, nil)
connection.dialer = dialer
resettableFunction := &ResettableFunction{}
resettableFunction.On("Run", mock.Anything).Return(struct{}{}, errors.New("test error")).Once()
resettableFunction.On("Run", mock.Anything).Return(struct{}{}, nil)
value, err := connection.ResetConnectionAndRetryIfErrorWithReturnValue("test", resettableFunction.Run)
assert.NoError(t, err)
assert.Equal(t, struct{}{}, value)
resettableFunction.AssertNumberOfCalls(t, "Run", 2)
})
t.Run("reset - fail", func(t *testing.T) {
connection.conn = nil
dialer := &amqpDialMock{}
var conn *amqpConnMock = nil
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, errors.New("test error"))
connection.dialer = dialer
resettableFunction := &ResettableFunction{}
resettableFunction.On("Run", mock.Anything).Return(struct{}{}, errors.New("test error"))
value, err := connection.ResetConnectionAndRetryIfErrorWithReturnValue("test", resettableFunction.Run)
assert.EqualError(t, err, "reset connection: connect: dial: test error")
assert.Nil(t, value)
resettableFunction.AssertNumberOfCalls(t, "Run", 1)
})
}
func Test_AmqpConnection_NewSender(t *testing.T) {
connection := &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
}
channelReceiver := func(channel chan struct{}) <-chan struct{} {
return channel
}
t.Run("connection not initialized", func(t *testing.T) {
sender, err := connection.NewSender(context.Background(), "topic")
assert.EqualError(t, err, "connection is not initialized")
assert.Nil(t, sender)
})
t.Run("connection is closed", func(t *testing.T) {
channel := make(chan struct{})
close(channel)
conn := &amqpConnMock{}
conn.On("Done").Return(channelReceiver(channel))
connection.conn = conn
sender, err := connection.NewSender(context.Background(), "topic")
assert.EqualError(t, err, "amqp connection is closed")
assert.Nil(t, sender)
})
t.Run("session error", func(t *testing.T) {
channel := make(chan struct{})
var session *amqpSessionMock = nil
conn := &amqpConnMock{}
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, errors.New("test error"))
conn.On("Done").Return(channelReceiver(channel))
connection.conn = conn
sender, err := connection.NewSender(context.Background(), "topic")
assert.EqualError(t, err, "new session: test error")
assert.Nil(t, sender)
})
t.Run("sender error", func(t *testing.T) {
channel := make(chan struct{})
sessionMock := &amqpSessionMock{}
var amqpSender *amqp.Sender = nil
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, errors.New("test error"))
sessionMock.On("Close", mock.Anything).Return(nil)
conn := &amqpConnMock{}
conn.On("Done").Return(channelReceiver(channel))
conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil)
connection.conn = conn
sender, err := connection.NewSender(context.Background(), "topic")
assert.EqualError(t, err, "new internal sender: test error")
assert.Nil(t, sender)
})
t.Run("session close error", func(t *testing.T) {
channel := make(chan struct{})
sessionMock := &amqpSessionMock{}
var amqpSender *amqp.Sender = nil
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, errors.New("test error"))
sessionMock.On("Close", mock.Anything).Return(errors.New("close error"))
conn := &amqpConnMock{}
conn.On("Done").Return(channelReceiver(channel))
conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil)
connection.conn = conn
sender, err := connection.NewSender(context.Background(), "topic")
assert.EqualError(t, err, "new internal sender: test error\nclose session: close error")
assert.Nil(t, sender)
})
t.Run("get sender", func(t *testing.T) {
channel := make(chan struct{})
amqpSender := &amqp.Sender{}
sessionMock := &amqpSessionMock{}
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, nil)
conn := &amqpConnMock{}
conn.On("Done").Return(channelReceiver(channel))
conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil)
connection.conn = conn
sender, err := connection.NewSender(context.Background(), "topic")
assert.NoError(t, err)
assert.NotNil(t, sender)
assert.Equal(t, amqpSender, sender.sender)
assert.Equal(t, sessionMock, sender.session)
})
}
func Test_AmqpConnection_NewAmqpConnection(t *testing.T) {
config := AmqpConnectionConfig{
BrokerUrl: "brokerUrl",
Username: "username",
Password: "password",
}
connection := NewAmqpConnection(&config, "connectionName")
assert.NotNil(t, connection)
assert.Equal(t, connection.connectionName, "connectionName")
assert.Equal(t, connection.brokerUrl, "brokerUrl")
assert.Equal(t, connection.username, "username")
assert.Equal(t, connection.password, "password")
assert.NotNil(t, connection.dialer)
}
func Test_As(t *testing.T) {
t.Run("error", func(t *testing.T) {
value, err := As[amqp.Message](nil, errors.New("test error"))
assert.EqualError(t, err, "test error")
assert.Nil(t, value)
})
t.Run("value nil", func(t *testing.T) {
value, err := As[amqp.Message](nil, nil)
assert.NoError(t, err)
assert.Nil(t, value)
})
t.Run("value not not type", func(t *testing.T) {
value, err := As[amqp.Message](struct{}{}, nil)
assert.EqualError(t, err, "could not cast value: struct {}")
assert.Nil(t, value)
})
t.Run("cast", func(t *testing.T) {
var sessionAny any = &amqpSessionMock{}
value, err := As[amqpSessionMock](sessionAny, nil)
assert.NoError(t, err)
assert.NotNil(t, value)
})
}

View file

@ -0,0 +1,86 @@
package messaging
import (
"context"
"errors"
"fmt"
"github.com/Azure/go-amqp"
"log/slog"
"strings"
"time"
)
type amqpSender interface {
Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error
Close(ctx context.Context) error
}
type AmqpSenderSession struct {
session amqpSession
sender amqpSender
}
func (s *AmqpSenderSession) Send(
topic string,
data [][]byte,
contentType string,
applicationProperties map[string]any,
) error {
// check topic name
if !strings.HasPrefix(topic, AmqpTopicPrefix) {
return fmt.Errorf(
"topic %q name lacks mandatory prefix %q",
topic,
AmqpTopicPrefix,
)
}
if contentType == "" {
return errors.New("content-type is required")
}
// prepare the amqp message
message := amqp.Message{
Header: &amqp.MessageHeader{
Durable: true,
},
Properties: &amqp.MessageProperties{
To: &topic,
ContentType: &contentType,
},
ApplicationProperties: applicationProperties,
Data: data,
}
// send
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
defer cancelFn()
return s.sender.Send(ctx, &message, nil)
}
func (s *AmqpSenderSession) Close() error {
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
defer cancelFn()
var closeErrors []error
senderErr := s.sender.Close(ctx)
if senderErr != nil {
closeErrors = append(closeErrors, senderErr)
}
sessionErr := s.session.Close(ctx)
if sessionErr != nil {
closeErrors = append(closeErrors, sessionErr)
}
if len(closeErrors) > 0 {
return errors.Join(closeErrors...)
}
return nil
}
func (s *AmqpSenderSession) CloseSilently() {
err := s.Close()
if err != nil {
slog.Error("error closing sender session", slog.Any("err", err))
}
}

View file

@ -0,0 +1,186 @@
package messaging
import (
"context"
"errors"
"github.com/Azure/go-amqp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
)
type amqpSenderMock struct {
mock.Mock
}
func (m *amqpSenderMock) Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error {
return m.Called(ctx, msg, opts).Error(0)
}
func (m *amqpSenderMock) Close(ctx context.Context) error {
return m.Called(ctx).Error(0)
}
var _ amqpSender = (*amqpSenderMock)(nil)
func Test_AmqpSenderSession_Close(t *testing.T) {
t.Run("close without errors", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(nil)
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(nil)
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
err := senderSession.Close()
assert.NoError(t, err)
sender.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close with sender error", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(errors.New("sender error"))
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(nil)
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
err := senderSession.Close()
assert.EqualError(t, err, "sender error")
sender.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close with session error", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(nil)
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(errors.New("session error"))
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
err := senderSession.Close()
assert.EqualError(t, err, "session error")
sender.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close with sender and session error", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(errors.New("sender error"))
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(errors.New("session error"))
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
err := senderSession.Close()
assert.EqualError(t, err, "sender error\nsession error")
sender.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
})
}
func Test_AmqpSenderSession_Send(t *testing.T) {
t.Run("invalid topic name", func(t *testing.T) {
sender := &amqpSenderMock{}
session := &amqpSessionMock{}
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
data := [][]byte{[]byte("data")}
err := senderSession.Send("invalid", data, "application/json", map[string]interface{}{})
assert.EqualError(t, err, "topic \"invalid\" name lacks mandatory prefix \"topic://\"")
})
t.Run("content type missing", func(t *testing.T) {
sender := &amqpSenderMock{}
session := &amqpSessionMock{}
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
data := [][]byte{[]byte("data")}
err := senderSession.Send("topic://some/name", data, "", map[string]interface{}{})
assert.EqualError(t, err, "content-type is required")
})
t.Run("send", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
session := &amqpSessionMock{}
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
data := [][]byte{[]byte("data")}
applicationProperties := map[string]interface{}{}
applicationProperties["key"] = "value"
err := senderSession.Send("topic://some/name", data, "application/json", applicationProperties)
assert.NoError(t, err)
sender.AssertNumberOfCalls(t, "Send", 1)
calls := sender.Calls
assert.Equal(t, 1, len(calls))
ctx, isCtx := calls[0].Arguments[0].(context.Context)
assert.True(t, isCtx)
assert.NotNil(t, ctx)
message, isMsg := calls[0].Arguments[1].(*amqp.Message)
assert.True(t, isMsg)
assert.True(t, message.Header.Durable)
assert.Equal(t, "topic://some/name", *message.Properties.To)
assert.Equal(t, "application/json", *message.Properties.ContentType)
assert.Equal(t, applicationProperties, message.ApplicationProperties)
assert.Equal(t, data, message.Data)
senderOptions, isSenderOptions := calls[0].Arguments[2].(*amqp.SendOptions)
assert.True(t, isSenderOptions)
assert.Nil(t, senderOptions)
})
t.Run("send fails", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("send fail"))
session := &amqpSessionMock{}
senderSession := &AmqpSenderSession{
sender: sender,
session: session,
}
data := [][]byte{[]byte("data")}
applicationProperties := map[string]interface{}{}
applicationProperties["key"] = "value"
err := senderSession.Send("topic://some/name", data, "application/json", applicationProperties)
assert.EqualError(t, err, "send fail")
sender.AssertNumberOfCalls(t, "Send", 1)
})
}

View file

@ -2,19 +2,13 @@ package messaging
import ( import (
"context" "context"
"dev.azure.com/schwarzit/schwarzit.stackit-public/audit-go.git/log"
"errors" "errors"
"fmt" "fmt"
"strings"
"sync" "sync"
"time" "time"
"dev.azure.com/schwarzit/schwarzit.stackit-public/audit-go.git/log"
"github.com/Azure/go-amqp"
) )
// Default connection timeout for the AMQP connection
const connectionTimeoutSeconds = 10
// Api is an abstraction for a messaging system that can be used to send // Api is an abstraction for a messaging system that can be used to send
// audit logs to the audit log system. // audit logs to the audit log system.
type Api interface { type Api interface {
@ -38,203 +32,99 @@ type Api interface {
Close(ctx context.Context) error Close(ctx context.Context) error
} }
// MutexApi is wrapper around an API implementation that controls mutual exclusive access to the api.
type MutexApi struct {
mutex sync.Mutex
api Api
}
var _ Api = &MutexApi{}
func NewMutexApi(api Api) (Api, error) {
if api == nil {
return nil, errors.New("api is nil")
}
mutexApi := MutexApi{
mutex: sync.Mutex{},
api: api,
}
var genericApi Api = &mutexApi
return genericApi, nil
}
// Send implements Api.Send
func (m *MutexApi) Send(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.api.Send(ctx, topic, data, contentType, applicationProperties)
}
func (m *MutexApi) Close(ctx context.Context) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.api.Close(ctx)
}
// AmqpConfig provides AMQP connection related parameters.
type AmqpConfig struct {
URL string
User string
Password string
}
// AmqpSession is an abstraction providing a subset of the methods of amqp.Session
type AmqpSession interface {
NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error)
Close(ctx context.Context) error
}
type AmqpSessionWrapper struct {
session *amqp.Session
}
func (w AmqpSessionWrapper) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error) {
return w.session.NewSender(ctx, target, opts)
}
func (w AmqpSessionWrapper) Close(ctx context.Context) error {
return w.session.Close(ctx)
}
// AmqpSender is an abstraction providing a subset of the methods of amqp.Sender
type AmqpSender interface {
Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error
Close(ctx context.Context) error
}
// AmqpApi implements Api. // AmqpApi implements Api.
type AmqpApi struct { type AmqpApi struct {
config AmqpConfig config AmqpConnectionPoolConfig
connection *amqp.Conn connection *AmqpConnection
session AmqpSession connectionPool ConnectionPool
connectionPoolHandle *ConnectionPoolHandle
senderCache map[string]*AmqpSenderSession
lock sync.RWMutex
} }
var _ Api = &AmqpApi{} var _ Api = &AmqpApi{}
func NewAmqpApi(amqpConfig AmqpConfig) (Api, error) { func NewAmqpApi(amqpConfig AmqpConnectionPoolConfig) (Api, error) {
amqpApi := &AmqpApi{config: amqpConfig} connectionPool, err := NewAmqpConnectionPool(amqpConfig, "sdk")
if err := amqpApi.connect(); err != nil {
return nil, fmt.Errorf("connect to broker: %w", err)
}
return amqpApi, nil
}
// connect opens a new connection and session to the AMQP messaging system.
// The connection attempt will be cancelled after connectionTimeoutSeconds.
func (a *AmqpApi) connect() error {
log.AuditLogger.Info("connecting to audit messaging system")
// Set credentials if specified
auth := amqp.SASLTypeAnonymous()
if a.config.User != "" && a.config.Password != "" {
auth = amqp.SASLTypePlain(a.config.User, a.config.Password)
log.AuditLogger.Info("using username and password for messaging")
} else {
log.AuditLogger.Warn("using anonymous messaging!")
}
options := &amqp.ConnOptions{
SASLType: auth,
}
// Create new context with timeout for the connection initialization
subCtx, cancel := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second)
defer cancel()
// Initialize connection
conn, err := amqp.Dial(subCtx, a.config.URL, options)
if err != nil { if err != nil {
return fmt.Errorf("dial connection to broker: %w", err) return nil, fmt.Errorf("new amqp connection pool: %w", err)
}
a.connection = conn
// Initialize session
session, err := conn.NewSession(context.Background(), nil)
if err != nil {
return fmt.Errorf("create session: %w", err)
} }
var amqpSession AmqpSession = &AmqpSessionWrapper{session: session} amqpApi := &AmqpApi{config: amqpConfig,
a.session = amqpSession connectionPool: connectionPool,
connectionPoolHandle: connectionPool.NewHandle(),
senderCache: make(map[string]*AmqpSenderSession),
}
return nil var messagingApi Api = amqpApi
return messagingApi, nil
} }
// Send implements Api.Send. // Send implements Api.Send.
// If errors occur the connection to the messaging system will be closed and re-established. // If errors occur the connection to the messaging system will be closed and re-established.
func (a *AmqpApi) Send(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error { func (a *AmqpApi) Send(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
err := a.trySend(ctx, topic, data, contentType, applicationProperties)
if err == nil {
return nil
}
// Drop the current sender, as it cannot connect to the broker anymore a.lock.RLock()
log.AuditLogger.Error("message sender error, recreating", err) connectionIsClosed := a.connection == nil || a.connection.IsClosed()
a.lock.RUnlock()
err = a.resetConnection(ctx) if connectionIsClosed {
if err != nil { connection, err := a.connectionPool.GetConnection(a.connectionPoolHandle)
return fmt.Errorf("reset connection: %w", err) if err != nil {
} return fmt.Errorf("get connection: %w", err)
return a.trySend(ctx, topic, data, contentType, applicationProperties)
}
// trySend actually sends the given data as amqp.Message to the messaging system.
func (a *AmqpApi) trySend(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
if !strings.HasPrefix(topic, AmqpTopicPrefix) {
return fmt.Errorf(
"topic %q name lacks mandatory prefix %q",
topic,
AmqpTopicPrefix,
)
}
sender, err := a.session.NewSender(ctx, topic, nil)
if err != nil {
return fmt.Errorf("new sender: %w", err)
}
defer func() {
if err := sender.Close(ctx); err != nil {
log.AuditLogger.Error("failed to close session sender", err)
} }
}() a.lock.Lock()
a.connection = connection
bytes := [][]byte{data} a.lock.Unlock()
message := amqp.Message{
Header: &amqp.MessageHeader{
Durable: true,
},
Properties: &amqp.MessageProperties{
To: &topic,
ContentType: &contentType,
},
ApplicationProperties: applicationProperties,
Data: bytes,
} }
err = sender.Send(ctx, &message, nil) a.lock.RLock()
if err != nil { var sender = a.senderCache[topic]
return fmt.Errorf("send message: %w", err) a.lock.RUnlock()
if sender == nil {
a.lock.RLock()
ctx, cancelFn := context.WithTimeout(ctx, 10*time.Second)
senderSession, err := a.connection.NewSender(ctx, topic)
cancelFn()
a.lock.RUnlock()
if err != nil {
return fmt.Errorf("new sender: %w", err)
}
a.lock.Lock()
a.senderCache[topic] = senderSession
a.lock.Unlock()
sender = senderSession
} }
wrappedData := [][]byte{data}
if err := sender.Send(topic, wrappedData, contentType, applicationProperties); err != nil {
return fmt.Errorf("send: %w", err)
}
return nil return nil
} }
// resetConnection closes the current session and connection and reconnects to the messaging system. // Close implements Api.Close
func (a *AmqpApi) resetConnection(ctx context.Context) error { func (a *AmqpApi) Close(_ context.Context) error {
if err := a.Close(ctx); err != nil { log.AuditLogger.Info("close audit amqp connection pool")
log.AuditLogger.Error("failed to close audit messaging connection", err)
a.lock.Lock()
defer a.lock.Unlock()
// cached senders
var closeErrors []error
for _, session := range a.senderCache {
if err := session.Close(); err != nil {
closeErrors = append(closeErrors, fmt.Errorf("close session: %w", err))
}
}
clear(a.senderCache)
// pool
if err := a.connectionPool.Close(); err != nil {
closeErrors = append(closeErrors, fmt.Errorf("close pool: %w", err))
} }
return a.connect() if len(closeErrors) > 0 {
} return fmt.Errorf("close: %w", errors.Join(closeErrors...))
}
// Close implements Api.Close return nil
func (a *AmqpApi) Close(ctx context.Context) error {
log.AuditLogger.Info("close audit messaging connection")
return errors.Join(a.session.Close(ctx), a.connection.Close())
} }

View file

@ -4,50 +4,38 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"testing"
"time"
"github.com/Azure/go-amqp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"sync"
"testing"
"time"
) )
type AmqpSessionMock struct { type connectionPoolMock struct {
mock.Mock mock.Mock
} }
func (m *AmqpSessionMock) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error) { func (m *connectionPoolMock) Close() error {
args := m.Called(ctx, target, opts) return m.Called().Error(0)
var sender AmqpSender = nil
if args.Get(0) != nil {
sender = args.Get(0).(AmqpSender)
}
err := args.Error(1)
return sender, err
} }
func (m *AmqpSessionMock) Close(ctx context.Context) error { func (m *connectionPoolMock) NewHandle() *ConnectionPoolHandle {
args := m.Called(ctx) return m.Called().Get(0).(*ConnectionPoolHandle)
return args.Error(0)
} }
type AmqpSenderMock struct { func (m *connectionPoolMock) GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) {
mock.Mock return m.Called(handle).Get(0).(*AmqpConnection), m.Called(handle).Error(1)
} }
func (m *AmqpSenderMock) Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error { var _ ConnectionPool = (*connectionPoolMock)(nil)
args := m.Called(ctx, msg, opts)
return args.Error(0)
}
func (m *AmqpSenderMock) Close(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func Test_NewAmqpMessagingApi(t *testing.T) { func Test_NewAmqpMessagingApi(t *testing.T) {
_, err := NewAmqpApi(AmqpConfig{URL: "not-handled-protocol://localhost:5672"}) _, err := NewAmqpApi(
assert.EqualError(t, err, "connect to broker: dial connection to broker: unsupported scheme \"not-handled-protocol\"") AmqpConnectionPoolConfig{
Parameters: AmqpConnectionConfig{BrokerUrl: "not-handled-protocol://localhost:5672"},
PoolSize: 1,
})
assert.EqualError(t, err, "new amqp connection pool: initialize connections: new connection: failed to connect to amqp broker: internal connection connect: dial: unsupported scheme \"not-handled-protocol\"")
} }
func Test_AmqpMessagingApi_Send(t *testing.T) { func Test_AmqpMessagingApi_Send(t *testing.T) {
@ -63,121 +51,354 @@ func Test_AmqpMessagingApi_Send(t *testing.T) {
t.Run("Missing topic prefix", func(t *testing.T) { t.Run("Missing topic prefix", func(t *testing.T) {
defer solaceContainer.StopOnError() defer solaceContainer.StopOnError()
api, err := NewAmqpApi(AmqpConfig{URL: solaceContainer.AmqpConnectionString}) api, err := NewAmqpApi(AmqpConnectionPoolConfig{
Parameters: AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err) assert.NoError(t, err)
err = api.Send(ctx, "topic-name", []byte{}, "application/json", make(map[string]any)) err = api.Send(ctx, "topic-name", []byte{}, "application/json", make(map[string]any))
assert.EqualError(t, err, "topic \"topic-name\" name lacks mandatory prefix \"topic://\"") assert.EqualError(t, err, "send: topic \"topic-name\" name lacks mandatory prefix \"topic://\"")
}) })
t.Run("Close connection without errors", func(t *testing.T) { t.Run("send successfully", func(t *testing.T) {
defer solaceContainer.StopOnError() defer solaceContainer.StopOnError()
// Initialize the solace queue // Initialize the solace queue
topicSubscriptionTopicPattern := "auditlog/>" topicSubscriptionTopicPattern := "auditlog/>"
queueName := "close-connection-without-error" queueName := "send-successfully"
assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName)) assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName))
assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern)) assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern))
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-close-connection") topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-send-successfully")
assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName)) assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName))
api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}} api, err := NewAmqpApi(AmqpConnectionPoolConfig{
err := api.connect() Parameters: AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
PoolSize: 1,
})
assert.NoError(t, err) assert.NoError(t, err)
data := []byte("data")
applicationProperties := make(map[string]interface{})
applicationProperties["key"] = "value"
err = api.Send(ctx, topicName, data, "application/json", applicationProperties)
assert.NoError(t, err)
message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true)
assert.NoError(t, err)
assert.Equal(t, "data", string(message.Data[0]))
assert.Equal(t, topicName, *message.Properties.To)
assert.Equal(t, "application/json", *message.Properties.ContentType)
assert.Equal(t, applicationProperties, message.ApplicationProperties)
err = api.Close(ctx) err = api.Close(ctx)
assert.NoError(t, err) assert.NoError(t, err)
}) })
}
t.Run("New sender call returns error", func(t *testing.T) { func Test_AmqpMessagingApi_Send_Special_Cases(t *testing.T) {
defer solaceContainer.StopOnError()
// Initialize the solace queue channelReceiver := func(channel chan struct{}) <-chan struct{} {
topicSubscriptionTopicPattern := "auditlog/>" return channel
queueName := "messaging-new-sender" }
assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName))
assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern))
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-no-new-sender")
assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName))
api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}} newActiveConnection := func() *AmqpConnection {
err := api.connect() channel := make(chan struct{})
conn := &amqpConnMock{}
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
return &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
conn: conn,
}
}
newClosedConnection := func() *AmqpConnection {
channel := make(chan struct{})
close(channel)
conn := &amqpConnMock{}
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
return &AmqpConnection{
connectionName: "test",
lock: sync.RWMutex{},
conn: conn,
}
}
t.Run("connection nil sender nil", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
session := &amqpSessionMock{}
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
connection := newActiveConnection()
conn := connection.conn.(*amqpConnMock)
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
pool := &connectionPoolMock{}
pool.On("GetConnection", mock.Anything).Return(connection, nil)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
assert.NoError(t, err) assert.NoError(t, err)
expectedError := errors.New("expected error") sender.AssertNumberOfCalls(t, "Send", 1)
session.AssertNumberOfCalls(t, "NewSender", 1)
// Set mock session pool.AssertNumberOfCalls(t, "GetConnection", 2)
sessionMock := AmqpSessionMock{}
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(nil, expectedError)
sessionMock.On("Close", mock.Anything).Return(nil)
var amqpSession AmqpSession = &sessionMock
api.session = amqpSession
// It's expected that the test succeeds.
// First the session is closed as it returns the expected error
// Then the retry mechanism restarts the connection and successfully sends the data
value := "test"
err = api.Send(ctx, topicName, []byte(value), "application/json", make(map[string]any))
assert.NoError(t, err)
// Check that the mock was called
assert.True(t, sessionMock.AssertNumberOfCalls(t, "NewSender", 1))
assert.True(t, sessionMock.AssertNumberOfCalls(t, "Close", 1))
message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true)
assert.NoError(t, err)
assert.Equal(t, value, string(message.Data[0]))
assert.Equal(t, topicName, *message.Properties.To)
}) })
t.Run("Send call on sender returns error", func(t *testing.T) { t.Run("connection closed sender nil", func(t *testing.T) {
defer solaceContainer.StopOnError() sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
// Initialize the solace queue session := &amqpSessionMock{}
topicSubscriptionTopicPattern := "auditlog/>" session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
queueName := "messaging-sender-error"
assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName))
assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern))
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-sender-error")
assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName))
api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}} connection := newActiveConnection()
err := api.connect() conn := connection.conn.(*amqpConnMock)
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
pool := &connectionPoolMock{}
pool.On("GetConnection", mock.Anything).Return(connection, nil)
closedConnection := newClosedConnection()
closedConnMock := closedConnection.conn.(*amqpConnMock)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connection: closedConnection,
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
assert.NoError(t, err) assert.NoError(t, err)
expectedError := errors.New("expected error") sender.AssertNumberOfCalls(t, "Send", 1)
session.AssertNumberOfCalls(t, "NewSender", 1)
pool.AssertNumberOfCalls(t, "GetConnection", 2)
closedConnMock.AssertNumberOfCalls(t, "Done", 1)
})
// Instantiate mock sender t.Run("connection nil get connection fail", func(t *testing.T) {
senderMock := AmqpSenderMock{} var connection *AmqpConnection = nil
senderMock.On("Send", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedError)
senderMock.On("Close", mock.Anything).Return(nil)
var amqpSender AmqpSender = &senderMock
// Set mock session pool := &connectionPoolMock{}
sessionMock := AmqpSessionMock{} pool.On("GetConnection", mock.Anything).Return(connection, errors.New("connection error"))
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(&amqpSender, nil)
sessionMock.On("Close", mock.Anything).Return(nil)
var amqpSession AmqpSession = &sessionMock amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
api.session = amqpSession connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: make(map[string]*AmqpSenderSession),
}
// It's expected that the test succeeds. err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
// First the sender and session are closed as the sender returns the expected error assert.EqualError(t, err, "get connection: connection error")
// Then the retry mechanism restarts the connection and successfully sends the data
value := "test" pool.AssertNumberOfCalls(t, "GetConnection", 2)
err = api.Send(ctx, topicName, []byte(value), "application/json", make(map[string]any)) })
t.Run("connection active sender nil", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
session := &amqpSessionMock{}
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
connection := newActiveConnection()
conn := connection.conn.(*amqpConnMock)
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connection: connection,
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
assert.NoError(t, err) assert.NoError(t, err)
// Check that the mocks were called sender.AssertNumberOfCalls(t, "Send", 1)
assert.True(t, sessionMock.AssertNumberOfCalls(t, "NewSender", 1)) session.AssertNumberOfCalls(t, "NewSender", 1)
assert.True(t, sessionMock.AssertNumberOfCalls(t, "Close", 1)) })
assert.True(t, senderMock.AssertNumberOfCalls(t, "Send", 1))
assert.True(t, senderMock.AssertNumberOfCalls(t, "Close", 1))
message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true) t.Run("connection active new sender fail", func(t *testing.T) {
var sender *amqpSenderMock = nil
session := &amqpSessionMock{}
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, errors.New("new sender error"))
session.On("Close", mock.Anything).Return(nil)
connection := newActiveConnection()
conn := connection.conn.(*amqpConnMock)
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connection: connection,
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
assert.EqualError(t, err, "new sender: new internal sender: new sender error")
session.AssertNumberOfCalls(t, "NewSender", 1)
session.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("connection active sender set", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
topic := "topic://some-topic"
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connection: newActiveConnection(),
senderCache: map[string]*AmqpSenderSession{topic: {sender: sender}},
}
err := amqpApi.Send(context.Background(), topic, []byte("data"), "application/json", make(map[string]any))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, value, string(message.Data[0]))
assert.Equal(t, topicName, *message.Properties.To) sender.AssertNumberOfCalls(t, "Send", 1)
})
t.Run("send fail", func(t *testing.T) {
sender := &amqpSenderMock{}
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("send error"))
topic := "topic://some-topic"
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connection: newActiveConnection(),
senderCache: map[string]*AmqpSenderSession{topic: {sender: sender}},
}
err := amqpApi.Send(context.Background(), topic, []byte("data"), "application/json", make(map[string]any))
assert.EqualError(t, err, "send: send error")
sender.AssertNumberOfCalls(t, "Send", 1)
})
}
func Test_AmqpMessagingApi_Close(t *testing.T) {
t.Run("close without cached senders", func(t *testing.T) {
pool := &connectionPoolMock{}
pool.On("Close").Return(nil)
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Close(context.Background())
assert.NoError(t, err)
pool.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close fail without cached senders", func(t *testing.T) {
pool := &connectionPoolMock{}
pool.On("Close").Return(errors.New("close error"))
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: make(map[string]*AmqpSenderSession),
}
err := amqpApi.Close(context.Background())
assert.EqualError(t, err, "close: close pool: close error")
pool.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close with cached senders", func(t *testing.T) {
pool := &connectionPoolMock{}
pool.On("Close").Return(nil)
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(nil)
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(nil)
senderSession := &AmqpSenderSession{
session: session,
sender: sender,
}
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: map[string]*AmqpSenderSession{"key": senderSession},
}
err := amqpApi.Close(context.Background())
assert.NoError(t, err)
assert.Equal(t, 0, len(amqpApi.senderCache))
pool.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
sender.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close fail with cached senders", func(t *testing.T) {
pool := &connectionPoolMock{}
pool.On("Close").Return(nil)
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(nil)
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(errors.New("close sender error"))
senderSession := &AmqpSenderSession{
session: session,
sender: sender,
}
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: map[string]*AmqpSenderSession{"key": senderSession},
}
err := amqpApi.Close(context.Background())
assert.EqualError(t, err, "close: close session: close sender error")
assert.Equal(t, 0, len(amqpApi.senderCache))
pool.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
sender.AssertNumberOfCalls(t, "Close", 1)
})
t.Run("close fail", func(t *testing.T) {
pool := &connectionPoolMock{}
pool.On("Close").Return(errors.New("close pool error"))
session := &amqpSessionMock{}
session.On("Close", mock.Anything).Return(errors.New("close session error"))
sender := &amqpSenderMock{}
sender.On("Close", mock.Anything).Return(errors.New("close sender error"))
senderSession := &AmqpSenderSession{
session: session,
sender: sender,
}
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
connectionPool: pool,
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
senderCache: map[string]*AmqpSenderSession{"key": senderSession},
}
err := amqpApi.Close(context.Background())
assert.EqualError(t, err, "close: close session: close sender error\nclose session error\nclose pool: close pool error")
assert.Equal(t, 0, len(amqpApi.senderCache))
pool.AssertNumberOfCalls(t, "Close", 1)
session.AssertNumberOfCalls(t, "Close", 1)
sender.AssertNumberOfCalls(t, "Close", 1)
}) })
} }

View file

@ -18,7 +18,6 @@ import (
) )
const ( const (
AmqpTopicPrefix = "topic://"
AmqpQueuePrefix = "queue://" AmqpQueuePrefix = "queue://"
) )