mirror of
https://dev.azure.com/schwarzit/schwarzit.stackit-public/_git/audit-go
synced 2026-02-07 16:47:24 +00:00
Merged PR 716929: feat: Replace AMQP connection management
So far the SDK provided a messaging API that was not thread-safe (i.e. goroutine-safe). Additionally the SDK provided a MutexAPI which made it thread-safe at the cost of removed concurrency possibilities. The changes implemented in this commit replace both implementations with a thread-safe connection pool based solution. The api gateway is a SDK user that requires reliable high performance send capabilities with a limit amount of amqp connections. These changes in the PR try address their requirements by moving the responsibility of connection management into the SDK. From this change other SDK users will benefit as well. Security-concept-update-needed: false. JIRA Work Item: STACKITALO-62
This commit is contained in:
parent
c90ce29c51
commit
5742604629
13 changed files with 2056 additions and 291 deletions
|
|
@ -30,7 +30,11 @@ func TestDynamicLegacyAuditApi(t *testing.T) {
|
|||
defer solaceContainer.Stop()
|
||||
|
||||
// Instantiate the messaging api
|
||||
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString})
|
||||
messagingApi, err := messaging.NewAmqpApi(
|
||||
messaging.AmqpConnectionPoolConfig{
|
||||
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
|
||||
PoolSize: 1,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Validator
|
||||
|
|
|
|||
|
|
@ -32,7 +32,10 @@ func TestLegacyAuditApi(t *testing.T) {
|
|||
defer solaceContainer.Stop()
|
||||
|
||||
// Instantiate the messaging api
|
||||
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString})
|
||||
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{
|
||||
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
|
||||
PoolSize: 1,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Validator
|
||||
|
|
@ -579,7 +582,10 @@ func TestLegacyAuditApi_NewLegacyAuditApi(t *testing.T) {
|
|||
defer solaceContainer.Stop()
|
||||
|
||||
// Instantiate the messaging api
|
||||
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString})
|
||||
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{
|
||||
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
|
||||
PoolSize: 1,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Validator
|
||||
|
|
|
|||
|
|
@ -33,7 +33,10 @@ func TestRoutableAuditApi(t *testing.T) {
|
|||
defer solaceContainer.Stop()
|
||||
|
||||
// Instantiate the messaging api
|
||||
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString})
|
||||
messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{
|
||||
Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
|
||||
PoolSize: 1,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Validator
|
||||
|
|
|
|||
15
audit/messaging/amqp_config.go
Normal file
15
audit/messaging/amqp_config.go
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
package messaging
|
||||
|
||||
const AmqpTopicPrefix = "topic://"
|
||||
const connectionTimeoutSeconds = 10
|
||||
|
||||
type AmqpConnectionConfig struct {
|
||||
BrokerUrl string `json:"brokerUrl"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type AmqpConnectionPoolConfig struct {
|
||||
Parameters AmqpConnectionConfig `json:"parameters"`
|
||||
PoolSize int `json:"poolSize"`
|
||||
}
|
||||
232
audit/messaging/amqp_connection.go
Normal file
232
audit/messaging/amqp_connection.go
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
package messaging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/Azure/go-amqp"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ConnectionClosedError = errors.New("amqp connection is closed")
|
||||
|
||||
type AmqpConnection struct {
|
||||
connectionName string
|
||||
lock sync.RWMutex
|
||||
brokerUrl string
|
||||
username string
|
||||
password string
|
||||
conn amqpConn
|
||||
dialer amqpDial
|
||||
}
|
||||
|
||||
// amqpConn is an abstraction of amqp.Conn
|
||||
type amqpConn interface {
|
||||
NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error)
|
||||
Close() error
|
||||
Done() <-chan struct{}
|
||||
}
|
||||
|
||||
type defaultAmqpConn struct {
|
||||
conn *amqp.Conn
|
||||
}
|
||||
|
||||
func newDefaultAmqpConn(conn *amqp.Conn) *defaultAmqpConn {
|
||||
return &defaultAmqpConn{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (d defaultAmqpConn) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error) {
|
||||
session, err := d.conn.NewSession(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newDefaultAmqpSession(session), nil
|
||||
}
|
||||
|
||||
func (d defaultAmqpConn) Close() error {
|
||||
return d.conn.Close()
|
||||
}
|
||||
|
||||
func (d defaultAmqpConn) Done() <-chan struct{} {
|
||||
return d.conn.Done()
|
||||
}
|
||||
|
||||
var _ amqpConn = (*defaultAmqpConn)(nil)
|
||||
|
||||
type amqpDial interface {
|
||||
Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error)
|
||||
}
|
||||
|
||||
type amqpSession interface {
|
||||
NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error)
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
type defaultAmqpSession struct {
|
||||
session *amqp.Session
|
||||
}
|
||||
|
||||
func newDefaultAmqpSession(session *amqp.Session) *defaultAmqpSession {
|
||||
return &defaultAmqpSession{
|
||||
session: session,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *defaultAmqpSession) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error) {
|
||||
return s.session.NewSender(ctx, target, opts)
|
||||
}
|
||||
|
||||
func (s *defaultAmqpSession) Close(ctx context.Context) error {
|
||||
return s.session.Close(ctx)
|
||||
}
|
||||
|
||||
var _ amqpSession = (*defaultAmqpSession)(nil)
|
||||
|
||||
type defaultAmqpDialer struct{}
|
||||
|
||||
func (d *defaultAmqpDialer) Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error) {
|
||||
dial, err := amqp.Dial(ctx, addr, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newDefaultAmqpConn(dial), nil
|
||||
}
|
||||
|
||||
var _ amqpDial = (*defaultAmqpDialer)(nil)
|
||||
|
||||
func NewAmqpConnection(config AmqpConnectionConfig, connectionName string) *AmqpConnection {
|
||||
return &AmqpConnection{
|
||||
connectionName: connectionName,
|
||||
lock: sync.RWMutex{},
|
||||
brokerUrl: config.BrokerUrl,
|
||||
username: config.Username,
|
||||
password: config.Password,
|
||||
dialer: &defaultAmqpDialer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AmqpConnection) NewSender(ctx context.Context, topic string) (*AmqpSenderSession, error) {
|
||||
if c.conn == nil {
|
||||
return nil, errors.New("connection is not initialized")
|
||||
}
|
||||
|
||||
if c.internalIsClosed() {
|
||||
return nil, ConnectionClosedError
|
||||
}
|
||||
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
// new session
|
||||
newSession, err := c.conn.NewSession(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
|
||||
// new sender
|
||||
newSender, err := newSession.NewSender(ctx, topic, nil)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("new internal sender: %w", err)
|
||||
|
||||
closeErr := newSession.Close(ctx)
|
||||
if closeErr != nil {
|
||||
return nil, errors.Join(err, fmt.Errorf("close session: %w", closeErr))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AmqpSenderSession{newSession, newSender}, nil
|
||||
}
|
||||
|
||||
func As[T any](value any, err error) (*T, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
castedValue, isType := value.(*T)
|
||||
if !isType {
|
||||
return nil, fmt.Errorf("could not cast value: %T", value)
|
||||
}
|
||||
return castedValue, nil
|
||||
}
|
||||
|
||||
func (c *AmqpConnection) Connect() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
subCtx, cancel := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := c.internalConnect(subCtx); err != nil {
|
||||
return fmt.Errorf("internal connect: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AmqpConnection) internalConnect(ctx context.Context) error {
|
||||
if c.conn == nil {
|
||||
// Set credentials if specified
|
||||
auth := amqp.SASLTypeAnonymous()
|
||||
if c.username != "" && c.password != "" {
|
||||
auth = amqp.SASLTypePlain(c.username, c.password)
|
||||
} else {
|
||||
slog.Debug("amqp connection: connect: using anonymous messaging")
|
||||
}
|
||||
options := &amqp.ConnOptions{
|
||||
SASLType: auth,
|
||||
}
|
||||
|
||||
// Initialize connection
|
||||
conn, err := c.dialer.Dial(ctx, c.brokerUrl, options)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial: %w", err)
|
||||
}
|
||||
c.conn = conn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AmqpConnection) Close() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if err := c.internalClose(); err != nil {
|
||||
return fmt.Errorf("internal close: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AmqpConnection) internalClose() error {
|
||||
if c.conn != nil {
|
||||
if err := c.conn.Close(); err != nil {
|
||||
return fmt.Errorf("connection close: %w", err)
|
||||
}
|
||||
c.conn = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AmqpConnection) IsClosed() bool {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
return c.internalIsClosed()
|
||||
}
|
||||
|
||||
func (c *AmqpConnection) internalIsClosed() bool {
|
||||
if c.conn == nil {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-c.conn.Done():
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
219
audit/messaging/amqp_connection_pool.go
Normal file
219
audit/messaging/amqp_connection_pool.go
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
package messaging
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type connectionProvider interface {
|
||||
NewAmqpConnection(config AmqpConnectionConfig, connectionName string) *AmqpConnection
|
||||
}
|
||||
|
||||
type defaultAmqpConnectionProvider struct{}
|
||||
|
||||
func (p defaultAmqpConnectionProvider) NewAmqpConnection(config AmqpConnectionConfig, connectionName string) *AmqpConnection {
|
||||
return NewAmqpConnection(config, connectionName)
|
||||
}
|
||||
|
||||
var _ connectionProvider = (*defaultAmqpConnectionProvider)(nil)
|
||||
|
||||
type ConnectionPool interface {
|
||||
Close() error
|
||||
NewHandle() *ConnectionPoolHandle
|
||||
GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error)
|
||||
}
|
||||
|
||||
type AmqpConnectionPool struct {
|
||||
config AmqpConnectionPoolConfig
|
||||
connectionName string
|
||||
connections []*AmqpConnection
|
||||
connectionProvider connectionProvider
|
||||
handleOffset int
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
type ConnectionPoolHandle struct {
|
||||
connectionOffset int
|
||||
}
|
||||
|
||||
func NewAmqpConnectionPool(config AmqpConnectionPoolConfig, connectionName string) (ConnectionPool, error) {
|
||||
pool := &AmqpConnectionPool{
|
||||
config: config,
|
||||
connectionName: connectionName,
|
||||
connections: make([]*AmqpConnection, 0),
|
||||
connectionProvider: defaultAmqpConnectionProvider{},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
}
|
||||
|
||||
if err := pool.initializeConnections(); err != nil {
|
||||
if closeErr := pool.Close(); closeErr != nil {
|
||||
return nil, errors.Join(err, fmt.Errorf("initialize amqp connection: pool closed: %w", closeErr))
|
||||
}
|
||||
return nil, fmt.Errorf("initialize connections: %w", err)
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func (p *AmqpConnectionPool) initializeConnections() error {
|
||||
if len(p.connections) < p.config.PoolSize {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
numMissingConnections := p.config.PoolSize - len(p.connections)
|
||||
|
||||
for i := 0; i < numMissingConnections; i++ {
|
||||
if err := p.internalAddConnection(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *AmqpConnectionPool) internalAddConnection() error {
|
||||
newConnection, err := p.internalNewConnection()
|
||||
if err != nil {
|
||||
return fmt.Errorf("new connection: %w", err)
|
||||
}
|
||||
p.connections = append(p.connections, newConnection)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *AmqpConnectionPool) internalNewConnection() (*AmqpConnection, error) {
|
||||
conn := p.connectionProvider.NewAmqpConnection(p.config.Parameters, p.connectionName)
|
||||
if err := conn.Connect(); err != nil {
|
||||
slog.Warn("amqp connection: failed to connect to amqp broker", slog.Any("err", err))
|
||||
|
||||
// retry
|
||||
if err = conn.Connect(); err != nil {
|
||||
connectErr := fmt.Errorf("new internal connection: %w", err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
// this case should never happen as the inner connection should always be null, therefore
|
||||
// it should not have to be closed, i.e. be able to return errors.
|
||||
return nil, errors.Join(connectErr, fmt.Errorf("close connection: %w", closeErr))
|
||||
}
|
||||
return nil, connectErr
|
||||
}
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (p *AmqpConnectionPool) Close() error {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
closeErrors := make([]error, 0)
|
||||
for _, conn := range p.connections {
|
||||
if conn != nil {
|
||||
if err := conn.Close(); err != nil {
|
||||
closeErrors = append(closeErrors, fmt.Errorf("pooled connection: %w", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
p.connections = make([]*AmqpConnection, p.config.PoolSize)
|
||||
if len(closeErrors) > 0 {
|
||||
return errors.Join(closeErrors...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *AmqpConnectionPool) NewHandle() *ConnectionPoolHandle {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
offset := p.handleOffset
|
||||
p.handleOffset += 1
|
||||
|
||||
offset = offset % p.config.PoolSize
|
||||
|
||||
return &ConnectionPoolHandle{
|
||||
connectionOffset: offset,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AmqpConnectionPool) GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) {
|
||||
// get the requested connection or another one
|
||||
conn, addConnection := p.nextConnectionForHandle(handle)
|
||||
|
||||
// renew the requested connection if the request connection is closed
|
||||
if conn == nil || addConnection {
|
||||
p.lock.Lock()
|
||||
|
||||
// check that accessing the pool only with a valid index (out of bounds should only occur on shutdown)
|
||||
connectionIndex := p.connectionIndex(handle, 0)
|
||||
if p.connections[connectionIndex] == nil {
|
||||
connection, err := p.internalNewConnection()
|
||||
if err != nil {
|
||||
if conn == nil {
|
||||
// case: connection could not be renewed and no connection to return has been found
|
||||
p.lock.Unlock()
|
||||
return nil, fmt.Errorf("renew connection: %w", err)
|
||||
}
|
||||
|
||||
// case: connection could not be renewed but another connection will be returned
|
||||
slog.Warn("amqp connection pool: get connection: renew connection: ", slog.Any("err", err))
|
||||
} else {
|
||||
// case: connection could be renewed and will be added to pool
|
||||
p.connections[connectionIndex] = connection
|
||||
conn = connection
|
||||
}
|
||||
}
|
||||
p.lock.Unlock()
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("amqp connection pool: get connection: failed to obtain connection")
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (p *AmqpConnectionPool) nextConnectionForHandle(handle *ConnectionPoolHandle) (*AmqpConnection, bool) {
|
||||
// retry as long as there are remaining connections in the pool
|
||||
var conn *AmqpConnection
|
||||
var addConnection bool
|
||||
for i := 0; i < p.config.PoolSize; i++ {
|
||||
|
||||
// get the next possible connection (considering the retry index)
|
||||
idx := p.connectionIndex(handle, i)
|
||||
p.lock.RLock()
|
||||
if idx < len(p.connections) {
|
||||
conn = p.connections[idx]
|
||||
} else {
|
||||
// handle the edge case that the pool is empty on shutdown
|
||||
conn = nil
|
||||
}
|
||||
p.lock.RUnlock()
|
||||
|
||||
// remember that the requested is closed, retry with the next
|
||||
if conn == nil {
|
||||
addConnection = true
|
||||
continue
|
||||
}
|
||||
|
||||
// if the connection is closed, mark it by setting it to nil
|
||||
if conn.IsClosed() {
|
||||
p.lock.Lock()
|
||||
p.connections[idx] = nil
|
||||
p.lock.Unlock()
|
||||
|
||||
addConnection = true
|
||||
continue
|
||||
}
|
||||
|
||||
return conn, addConnection
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func (p *AmqpConnectionPool) connectionIndex(handle *ConnectionPoolHandle, iteration int) int {
|
||||
if iteration+handle.connectionOffset >= p.config.PoolSize {
|
||||
return (iteration + handle.connectionOffset) % p.config.PoolSize
|
||||
} else {
|
||||
return iteration + handle.connectionOffset
|
||||
}
|
||||
}
|
||||
578
audit/messaging/amqp_connection_pool_test.go
Normal file
578
audit/messaging/amqp_connection_pool_test.go
Normal file
|
|
@ -0,0 +1,578 @@
|
|||
package messaging
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type connectionProviderMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (p *connectionProviderMock) NewAmqpConnection(config AmqpConnectionConfig, connectionName string) *AmqpConnection {
|
||||
args := p.Called(config, connectionName)
|
||||
return args.Get(0).(*AmqpConnection)
|
||||
}
|
||||
|
||||
var _ connectionProvider = (*connectionProviderMock)(nil)
|
||||
|
||||
func Test_AmqpConnectionPool_GetHandle(t *testing.T) {
|
||||
|
||||
t.Run("next handle", func(t *testing.T) {
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
}
|
||||
|
||||
handle := pool.NewHandle()
|
||||
assert.NotNil(t, handle)
|
||||
assert.Equal(t, 0, handle.connectionOffset)
|
||||
assert.Equal(t, 1, pool.handleOffset)
|
||||
})
|
||||
|
||||
t.Run("next handle high offset", func(t *testing.T) {
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 13,
|
||||
lock: sync.RWMutex{},
|
||||
}
|
||||
|
||||
handle := pool.NewHandle()
|
||||
assert.NotNil(t, handle)
|
||||
assert.Equal(t, 3, handle.connectionOffset)
|
||||
assert.Equal(t, 14, pool.handleOffset)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpConnectionPool_internalAddConnection(t *testing.T) {
|
||||
|
||||
t.Run("internal add connection", func(t *testing.T) {
|
||||
conn := &amqpConnMock{}
|
||||
|
||||
dialer := &amqpDialMock{}
|
||||
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
|
||||
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: dialer,
|
||||
}
|
||||
|
||||
connectionProvider := &connectionProviderMock{}
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connectionProvider: connectionProvider,
|
||||
}
|
||||
|
||||
err := pool.internalAddConnection()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, len(pool.connections))
|
||||
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1)
|
||||
dialer.AssertNumberOfCalls(t, "Dial", 1)
|
||||
})
|
||||
|
||||
t.Run("dialer error", func(t *testing.T) {
|
||||
conn := &amqpConnMock{}
|
||||
|
||||
dialer := &amqpDialMock{}
|
||||
var c *amqpConnMock = nil
|
||||
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error")).Once()
|
||||
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
|
||||
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: dialer,
|
||||
}
|
||||
|
||||
connectionProvider := &connectionProviderMock{}
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connectionProvider: connectionProvider,
|
||||
}
|
||||
|
||||
err := pool.internalAddConnection()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, len(pool.connections))
|
||||
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1)
|
||||
dialer.AssertNumberOfCalls(t, "Dial", 2)
|
||||
})
|
||||
|
||||
t.Run("repetitive dialer error", func(t *testing.T) {
|
||||
dialer := &amqpDialMock{}
|
||||
var c *amqpConnMock = nil
|
||||
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error"))
|
||||
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: dialer,
|
||||
}
|
||||
|
||||
connectionProvider := &connectionProviderMock{}
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connectionProvider: connectionProvider,
|
||||
}
|
||||
|
||||
err := pool.internalAddConnection()
|
||||
assert.EqualError(t, err, "new connection: new internal connection: internal connect: dial: test error")
|
||||
|
||||
assert.Equal(t, 0, len(pool.connections))
|
||||
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1)
|
||||
dialer.AssertNumberOfCalls(t, "Dial", 2)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpConnectionPool_initializeConnections(t *testing.T) {
|
||||
|
||||
t.Run("initialize connections successfully", func(t *testing.T) {
|
||||
|
||||
conn := &amqpConnMock{}
|
||||
dialer := &amqpDialMock{}
|
||||
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
|
||||
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: dialer,
|
||||
}
|
||||
|
||||
connectionProvider := &connectionProviderMock{}
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connectionProvider: connectionProvider,
|
||||
}
|
||||
|
||||
err := pool.initializeConnections()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 5, len(pool.connections))
|
||||
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 5)
|
||||
})
|
||||
|
||||
t.Run("fail initialization of connections", func(t *testing.T) {
|
||||
|
||||
var c *amqpConnMock = nil
|
||||
failingDialer := &amqpDialMock{}
|
||||
failingDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error"))
|
||||
|
||||
failingConnection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: failingDialer,
|
||||
}
|
||||
|
||||
conn := &amqpConnMock{}
|
||||
successfulDialer := &amqpDialMock{}
|
||||
successfulDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
|
||||
|
||||
successfulConnection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: successfulDialer,
|
||||
}
|
||||
|
||||
connectionProvider := &connectionProviderMock{}
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(4)
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(failingConnection)
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connectionProvider: connectionProvider,
|
||||
}
|
||||
|
||||
err := pool.initializeConnections()
|
||||
assert.EqualError(t, err, "new connection: new internal connection: internal connect: dial: test error")
|
||||
|
||||
assert.Equal(t, 4, len(pool.connections))
|
||||
connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 5)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpConnectionPool_Close(t *testing.T) {
|
||||
|
||||
t.Run("close connection successfully", func(t *testing.T) {
|
||||
// add 5 connections to the pool
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("Close").Return(nil)
|
||||
|
||||
dialer := &amqpDialMock{}
|
||||
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil)
|
||||
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: dialer,
|
||||
}
|
||||
|
||||
connectionProvider := &connectionProviderMock{}
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connectionProvider: connectionProvider,
|
||||
}
|
||||
|
||||
err := pool.initializeConnections()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 5, len(pool.connections))
|
||||
|
||||
// close the pool
|
||||
err = pool.Close()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, len(pool.connections))
|
||||
for _, c := range pool.connections {
|
||||
assert.Nil(t, c)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("close connection fail", func(t *testing.T) {
|
||||
// add 5 connections to the pool
|
||||
failingConn := &amqpConnMock{}
|
||||
failingConn.On("Close").Return(errors.New("test error"))
|
||||
|
||||
failingDialer := &amqpDialMock{}
|
||||
failingDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(failingConn, nil)
|
||||
|
||||
failingConnection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: failingDialer,
|
||||
}
|
||||
|
||||
successfulConn := &amqpConnMock{}
|
||||
successfulConn.On("Close").Return(nil)
|
||||
successfulDialer := &amqpDialMock{}
|
||||
successfulDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(successfulConn, nil)
|
||||
|
||||
successfulConnection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: successfulDialer,
|
||||
}
|
||||
|
||||
connectionProvider := &connectionProviderMock{}
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(2)
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(failingConnection).Times(2)
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(1)
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connectionProvider: connectionProvider,
|
||||
}
|
||||
|
||||
err := pool.initializeConnections()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 5, len(pool.connections))
|
||||
|
||||
// close the pool
|
||||
err = pool.Close()
|
||||
assert.EqualError(t, err, "pooled connection: internal close: connection close: test error\npooled connection: internal close: connection close: test error")
|
||||
assert.Equal(t, 5, len(pool.connections))
|
||||
for _, c := range pool.connections {
|
||||
assert.Nil(t, c)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpConnectionPool_nextConnectionForHandle(t *testing.T) {
|
||||
channelReceiver := func(channel chan struct{}) <-chan struct{} {
|
||||
return channel
|
||||
}
|
||||
|
||||
newActiveConnection := func() *AmqpConnection {
|
||||
channel := make(chan struct{})
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
|
||||
return &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
newClosedConnection := func() *AmqpConnection {
|
||||
channel := make(chan struct{})
|
||||
close(channel)
|
||||
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
|
||||
return &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("next connection for requested handle", func(t *testing.T) {
|
||||
connections := make([]*AmqpConnection, 0)
|
||||
for i := 0; i < 5; i++ {
|
||||
connections = append(connections, newActiveConnection())
|
||||
}
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connections: connections,
|
||||
}
|
||||
|
||||
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
|
||||
assert.NotNil(t, connection)
|
||||
assert.False(t, addConnection)
|
||||
})
|
||||
|
||||
t.Run("nil connection for requested handle", func(t *testing.T) {
|
||||
connections := make([]*AmqpConnection, 0)
|
||||
connections = append(connections, newActiveConnection())
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, newActiveConnection())
|
||||
connections = append(connections, newActiveConnection())
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connections: connections,
|
||||
}
|
||||
|
||||
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
|
||||
assert.NotNil(t, connection)
|
||||
assert.True(t, addConnection)
|
||||
})
|
||||
|
||||
t.Run("closed connection for requested handle", func(t *testing.T) {
|
||||
connections := make([]*AmqpConnection, 0)
|
||||
connections = append(connections, newActiveConnection())
|
||||
connections = append(connections, newClosedConnection())
|
||||
connections = append(connections, newClosedConnection())
|
||||
connections = append(connections, newActiveConnection())
|
||||
connections = append(connections, newActiveConnection())
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connections: connections,
|
||||
}
|
||||
|
||||
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
|
||||
assert.NotNil(t, connection)
|
||||
assert.True(t, addConnection)
|
||||
})
|
||||
|
||||
t.Run("no connection for requested handle", func(t *testing.T) {
|
||||
connections := make([]*AmqpConnection, 0)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connections: connections,
|
||||
}
|
||||
|
||||
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1})
|
||||
assert.Nil(t, connection)
|
||||
assert.True(t, addConnection)
|
||||
})
|
||||
|
||||
t.Run("connection for requested handle with large index", func(t *testing.T) {
|
||||
connections := make([]*AmqpConnection, 0)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, newActiveConnection())
|
||||
connections = append(connections, nil)
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connections: connections,
|
||||
}
|
||||
|
||||
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 23})
|
||||
assert.NotNil(t, connection)
|
||||
assert.False(t, addConnection)
|
||||
})
|
||||
|
||||
t.Run("connection for requested handle nil with large index", func(t *testing.T) {
|
||||
connections := make([]*AmqpConnection, 0)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, newActiveConnection())
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connections: connections,
|
||||
}
|
||||
|
||||
connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 23})
|
||||
assert.NotNil(t, connection)
|
||||
assert.True(t, addConnection)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpConnectionPool_GetConnection(t *testing.T) {
|
||||
channelReceiver := func(channel chan struct{}) <-chan struct{} {
|
||||
return channel
|
||||
}
|
||||
|
||||
newActiveConnection := func() *AmqpConnection {
|
||||
channel := make(chan struct{})
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
|
||||
return &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("get connection for requested handle", func(t *testing.T) {
|
||||
connections := make([]*AmqpConnection, 0)
|
||||
for i := 0; i < 5; i++ {
|
||||
connections = append(connections, newActiveConnection())
|
||||
}
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connections: connections,
|
||||
}
|
||||
|
||||
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, connection)
|
||||
assert.Equal(t, connections[1], connection)
|
||||
assert.Equal(t, 5, len(connections))
|
||||
})
|
||||
|
||||
t.Run("add connection if missing", func(t *testing.T) {
|
||||
connections := make([]*AmqpConnection, 5)
|
||||
|
||||
connectionProvider := &connectionProviderMock{}
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(newActiveConnection())
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connections: connections,
|
||||
connectionProvider: connectionProvider,
|
||||
}
|
||||
|
||||
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, connection)
|
||||
assert.Equal(t, connections[1], connection)
|
||||
assert.Equal(t, 5, len(connections))
|
||||
})
|
||||
|
||||
t.Run("add connection fails returns alternative connection", func(t *testing.T) {
|
||||
connections := make([]*AmqpConnection, 0)
|
||||
connections = append(connections, newActiveConnection())
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, newActiveConnection())
|
||||
connections = append(connections, newActiveConnection())
|
||||
connections = append(connections, newActiveConnection())
|
||||
|
||||
connectionProvider := &connectionProviderMock{}
|
||||
|
||||
dialer := &amqpDialMock{}
|
||||
var c *amqpConnMock = nil
|
||||
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, fmt.Errorf("dial error"))
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: dialer,
|
||||
}
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connections: connections,
|
||||
connectionProvider: connectionProvider,
|
||||
}
|
||||
|
||||
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, connection)
|
||||
assert.Nil(t, connections[1])
|
||||
assert.Equal(t, connections[2], connection)
|
||||
assert.Equal(t, 5, len(connections))
|
||||
})
|
||||
|
||||
t.Run("add connection fails", func(t *testing.T) {
|
||||
connections := make([]*AmqpConnection, 0)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
connections = append(connections, nil)
|
||||
|
||||
connectionProvider := &connectionProviderMock{}
|
||||
|
||||
dialer := &amqpDialMock{}
|
||||
var c *amqpConnMock = nil
|
||||
dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, fmt.Errorf("dial error"))
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
dialer: dialer,
|
||||
}
|
||||
connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection)
|
||||
|
||||
pool := AmqpConnectionPool{
|
||||
config: AmqpConnectionPoolConfig{PoolSize: 5},
|
||||
handleOffset: 0,
|
||||
lock: sync.RWMutex{},
|
||||
connections: connections,
|
||||
connectionProvider: connectionProvider,
|
||||
}
|
||||
|
||||
connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1})
|
||||
assert.EqualError(t, err, "renew connection: new internal connection: internal connect: dial: dial error")
|
||||
assert.Nil(t, connection)
|
||||
assert.Equal(t, 5, len(connections))
|
||||
})
|
||||
}
|
||||
306
audit/messaging/amqp_connection_test.go
Normal file
306
audit/messaging/amqp_connection_test.go
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
package messaging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/Azure/go-amqp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type amqpConnMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *amqpConnMock) Done() <-chan struct{} {
|
||||
args := m.Called()
|
||||
return args.Get(0).(<-chan struct{})
|
||||
}
|
||||
|
||||
func (m *amqpConnMock) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error) {
|
||||
args := m.Called(ctx, opts)
|
||||
return args.Get(0).(amqpSession), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *amqpConnMock) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
var _ amqpConn = (*amqpConnMock)(nil)
|
||||
|
||||
type amqpDialMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *amqpDialMock) Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error) {
|
||||
args := m.Called(ctx, addr, opts)
|
||||
return args.Get(0).(amqpConn), args.Error(1)
|
||||
}
|
||||
|
||||
var _ amqpDial = (*amqpDialMock)(nil)
|
||||
|
||||
type amqpSessionMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *amqpSessionMock) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error) {
|
||||
args := m.Called(ctx, target, opts)
|
||||
return args.Get(0).(amqpSender), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *amqpSessionMock) Close(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
var _ amqpSession = (*amqpSessionMock)(nil)
|
||||
|
||||
func Test_AmqpConnection_IsClosed(t *testing.T) {
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
}
|
||||
|
||||
channelReceiver := func(channel chan struct{}) <-chan struct{} {
|
||||
return channel
|
||||
}
|
||||
|
||||
t.Run("is closed - connection nil", func(t *testing.T) {
|
||||
assert.True(t, connection.IsClosed())
|
||||
})
|
||||
|
||||
t.Run("is closed", func(t *testing.T) {
|
||||
channel := make(chan struct{})
|
||||
close(channel)
|
||||
amqpConnMock := &amqpConnMock{}
|
||||
amqpConnMock.On("Done").Return(channelReceiver(channel))
|
||||
connection.conn = amqpConnMock
|
||||
|
||||
assert.True(t, connection.IsClosed())
|
||||
})
|
||||
|
||||
t.Run("is not closed", func(t *testing.T) {
|
||||
channel := make(chan struct{})
|
||||
amqpConnMock := &amqpConnMock{}
|
||||
amqpConnMock.On("Done").Return(channelReceiver(channel))
|
||||
connection.conn = amqpConnMock
|
||||
|
||||
assert.False(t, connection.IsClosed())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpConnection_Close(t *testing.T) {
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
}
|
||||
|
||||
t.Run("already closed", func(t *testing.T) {
|
||||
assert.NoError(t, connection.Close())
|
||||
})
|
||||
|
||||
t.Run("close error", func(t *testing.T) {
|
||||
err := errors.New("test error")
|
||||
|
||||
amqpConnMock := &amqpConnMock{}
|
||||
amqpConnMock.On("Close").Return(err)
|
||||
connection.conn = amqpConnMock
|
||||
|
||||
assert.EqualError(t, connection.Close(), "internal close: connection close: test error")
|
||||
assert.NotNil(t, connection.conn)
|
||||
amqpConnMock.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
|
||||
t.Run("close without error", func(t *testing.T) {
|
||||
amqpConnMock := &amqpConnMock{}
|
||||
amqpConnMock.On("Close").Return(nil)
|
||||
connection.conn = amqpConnMock
|
||||
|
||||
assert.Nil(t, connection.Close())
|
||||
assert.Nil(t, connection.conn)
|
||||
amqpConnMock.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpConnection_Connect(t *testing.T) {
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
}
|
||||
|
||||
t.Run("already connected", func(t *testing.T) {
|
||||
connection.conn = &amqpConnMock{}
|
||||
assert.NoError(t, connection.Connect())
|
||||
})
|
||||
|
||||
t.Run("dial error", func(t *testing.T) {
|
||||
connection.conn = nil
|
||||
connection.username = "user"
|
||||
connection.password = "pass"
|
||||
|
||||
amqpDialMock := &amqpDialMock{}
|
||||
var c *amqpConnMock = nil
|
||||
amqpDialMock.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error"))
|
||||
connection.dialer = amqpDialMock
|
||||
|
||||
assert.EqualError(t, connection.Connect(), "internal connect: dial: test error")
|
||||
assert.Nil(t, connection.conn)
|
||||
})
|
||||
|
||||
t.Run("connect without error", func(t *testing.T) {
|
||||
connection.conn = nil
|
||||
|
||||
amqpDialMock := &amqpDialMock{}
|
||||
amqpConn := &amqpConnMock{}
|
||||
amqpDialMock.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(amqpConn, nil)
|
||||
connection.dialer = amqpDialMock
|
||||
|
||||
assert.NoError(t, connection.Connect())
|
||||
assert.Equal(t, amqpConn, connection.conn)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpConnection_NewSender(t *testing.T) {
|
||||
connection := &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
}
|
||||
|
||||
channelReceiver := func(channel chan struct{}) <-chan struct{} {
|
||||
return channel
|
||||
}
|
||||
|
||||
t.Run("connection not initialized", func(t *testing.T) {
|
||||
sender, err := connection.NewSender(context.Background(), "topic")
|
||||
assert.EqualError(t, err, "connection is not initialized")
|
||||
assert.Nil(t, sender)
|
||||
})
|
||||
|
||||
t.Run("connection is closed", func(t *testing.T) {
|
||||
channel := make(chan struct{})
|
||||
close(channel)
|
||||
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("Done").Return(channelReceiver(channel))
|
||||
connection.conn = conn
|
||||
|
||||
sender, err := connection.NewSender(context.Background(), "topic")
|
||||
assert.EqualError(t, err, "amqp connection is closed")
|
||||
assert.Nil(t, sender)
|
||||
})
|
||||
|
||||
t.Run("session error", func(t *testing.T) {
|
||||
channel := make(chan struct{})
|
||||
|
||||
var session *amqpSessionMock = nil
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, errors.New("test error"))
|
||||
conn.On("Done").Return(channelReceiver(channel))
|
||||
connection.conn = conn
|
||||
|
||||
sender, err := connection.NewSender(context.Background(), "topic")
|
||||
assert.EqualError(t, err, "new session: test error")
|
||||
assert.Nil(t, sender)
|
||||
})
|
||||
|
||||
t.Run("sender error", func(t *testing.T) {
|
||||
channel := make(chan struct{})
|
||||
|
||||
sessionMock := &amqpSessionMock{}
|
||||
var amqpSender *amqp.Sender = nil
|
||||
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, errors.New("test error"))
|
||||
sessionMock.On("Close", mock.Anything).Return(nil)
|
||||
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("Done").Return(channelReceiver(channel))
|
||||
conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil)
|
||||
connection.conn = conn
|
||||
|
||||
sender, err := connection.NewSender(context.Background(), "topic")
|
||||
assert.EqualError(t, err, "new internal sender: test error")
|
||||
assert.Nil(t, sender)
|
||||
})
|
||||
|
||||
t.Run("session close error", func(t *testing.T) {
|
||||
channel := make(chan struct{})
|
||||
|
||||
sessionMock := &amqpSessionMock{}
|
||||
var amqpSender *amqp.Sender = nil
|
||||
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, errors.New("test error"))
|
||||
sessionMock.On("Close", mock.Anything).Return(errors.New("close error"))
|
||||
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("Done").Return(channelReceiver(channel))
|
||||
conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil)
|
||||
connection.conn = conn
|
||||
|
||||
sender, err := connection.NewSender(context.Background(), "topic")
|
||||
assert.EqualError(t, err, "new internal sender: test error\nclose session: close error")
|
||||
assert.Nil(t, sender)
|
||||
})
|
||||
|
||||
t.Run("get sender", func(t *testing.T) {
|
||||
channel := make(chan struct{})
|
||||
|
||||
amqpSender := &amqp.Sender{}
|
||||
sessionMock := &amqpSessionMock{}
|
||||
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, nil)
|
||||
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("Done").Return(channelReceiver(channel))
|
||||
conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil)
|
||||
connection.conn = conn
|
||||
|
||||
sender, err := connection.NewSender(context.Background(), "topic")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, sender)
|
||||
assert.Equal(t, amqpSender, sender.sender)
|
||||
assert.Equal(t, sessionMock, sender.session)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpConnection_NewAmqpConnection(t *testing.T) {
|
||||
config := AmqpConnectionConfig{
|
||||
BrokerUrl: "brokerUrl",
|
||||
Username: "username",
|
||||
Password: "password",
|
||||
}
|
||||
connection := NewAmqpConnection(config, "connectionName")
|
||||
assert.NotNil(t, connection)
|
||||
assert.Equal(t, connection.connectionName, "connectionName")
|
||||
assert.Equal(t, connection.brokerUrl, "brokerUrl")
|
||||
assert.Equal(t, connection.username, "username")
|
||||
assert.Equal(t, connection.password, "password")
|
||||
assert.NotNil(t, connection.dialer)
|
||||
}
|
||||
|
||||
func Test_As(t *testing.T) {
|
||||
|
||||
t.Run("error", func(t *testing.T) {
|
||||
value, err := As[amqp.Message](nil, errors.New("test error"))
|
||||
assert.EqualError(t, err, "test error")
|
||||
assert.Nil(t, value)
|
||||
})
|
||||
|
||||
t.Run("value nil", func(t *testing.T) {
|
||||
value, err := As[amqp.Message](nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, value)
|
||||
})
|
||||
|
||||
t.Run("value not not type", func(t *testing.T) {
|
||||
value, err := As[amqp.Message](struct{}{}, nil)
|
||||
assert.EqualError(t, err, "could not cast value: struct {}")
|
||||
assert.Nil(t, value)
|
||||
})
|
||||
|
||||
t.Run("cast", func(t *testing.T) {
|
||||
var sessionAny any = &amqpSessionMock{}
|
||||
value, err := As[amqpSessionMock](sessionAny, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, value)
|
||||
})
|
||||
}
|
||||
78
audit/messaging/amqp_sender_session.go
Normal file
78
audit/messaging/amqp_sender_session.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
package messaging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/Azure/go-amqp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type amqpSender interface {
|
||||
Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
type AmqpSenderSession struct {
|
||||
session amqpSession
|
||||
sender amqpSender
|
||||
}
|
||||
|
||||
func (s *AmqpSenderSession) Send(
|
||||
topic string,
|
||||
data [][]byte,
|
||||
contentType string,
|
||||
applicationProperties map[string]any,
|
||||
) error {
|
||||
// check topic name
|
||||
if !strings.HasPrefix(topic, AmqpTopicPrefix) {
|
||||
return fmt.Errorf(
|
||||
"topic %q name lacks mandatory prefix %q",
|
||||
topic,
|
||||
AmqpTopicPrefix,
|
||||
)
|
||||
}
|
||||
|
||||
if contentType == "" {
|
||||
return errors.New("content-type is required")
|
||||
}
|
||||
|
||||
// prepare the amqp message
|
||||
message := amqp.Message{
|
||||
Header: &amqp.MessageHeader{
|
||||
Durable: true,
|
||||
},
|
||||
Properties: &amqp.MessageProperties{
|
||||
To: &topic,
|
||||
ContentType: &contentType,
|
||||
},
|
||||
ApplicationProperties: applicationProperties,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
// send
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancelFn()
|
||||
return s.sender.Send(ctx, &message, nil)
|
||||
}
|
||||
|
||||
func (s *AmqpSenderSession) Close() error {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancelFn()
|
||||
|
||||
var closeErrors []error
|
||||
senderErr := s.sender.Close(ctx)
|
||||
if senderErr != nil {
|
||||
closeErrors = append(closeErrors, senderErr)
|
||||
}
|
||||
sessionErr := s.session.Close(ctx)
|
||||
if sessionErr != nil {
|
||||
closeErrors = append(closeErrors, sessionErr)
|
||||
}
|
||||
|
||||
if len(closeErrors) > 0 {
|
||||
return errors.Join(closeErrors...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
186
audit/messaging/amqp_sender_session_test.go
Normal file
186
audit/messaging/amqp_sender_session_test.go
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
package messaging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/Azure/go-amqp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type amqpSenderMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *amqpSenderMock) Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error {
|
||||
return m.Called(ctx, msg, opts).Error(0)
|
||||
}
|
||||
|
||||
func (m *amqpSenderMock) Close(ctx context.Context) error {
|
||||
return m.Called(ctx).Error(0)
|
||||
}
|
||||
|
||||
var _ amqpSender = (*amqpSenderMock)(nil)
|
||||
|
||||
func Test_AmqpSenderSession_Close(t *testing.T) {
|
||||
|
||||
t.Run("close without errors", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Close", mock.Anything).Return(nil)
|
||||
session := &amqpSessionMock{}
|
||||
session.On("Close", mock.Anything).Return(nil)
|
||||
|
||||
senderSession := &AmqpSenderSession{
|
||||
sender: sender,
|
||||
session: session,
|
||||
}
|
||||
|
||||
err := senderSession.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
sender.AssertNumberOfCalls(t, "Close", 1)
|
||||
session.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
|
||||
t.Run("close with sender error", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Close", mock.Anything).Return(errors.New("sender error"))
|
||||
session := &amqpSessionMock{}
|
||||
session.On("Close", mock.Anything).Return(nil)
|
||||
|
||||
senderSession := &AmqpSenderSession{
|
||||
sender: sender,
|
||||
session: session,
|
||||
}
|
||||
|
||||
err := senderSession.Close()
|
||||
assert.EqualError(t, err, "sender error")
|
||||
|
||||
sender.AssertNumberOfCalls(t, "Close", 1)
|
||||
session.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
|
||||
t.Run("close with session error", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Close", mock.Anything).Return(nil)
|
||||
session := &amqpSessionMock{}
|
||||
session.On("Close", mock.Anything).Return(errors.New("session error"))
|
||||
|
||||
senderSession := &AmqpSenderSession{
|
||||
sender: sender,
|
||||
session: session,
|
||||
}
|
||||
|
||||
err := senderSession.Close()
|
||||
assert.EqualError(t, err, "session error")
|
||||
|
||||
sender.AssertNumberOfCalls(t, "Close", 1)
|
||||
session.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
|
||||
t.Run("close with sender and session error", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Close", mock.Anything).Return(errors.New("sender error"))
|
||||
session := &amqpSessionMock{}
|
||||
session.On("Close", mock.Anything).Return(errors.New("session error"))
|
||||
|
||||
senderSession := &AmqpSenderSession{
|
||||
sender: sender,
|
||||
session: session,
|
||||
}
|
||||
|
||||
err := senderSession.Close()
|
||||
assert.EqualError(t, err, "sender error\nsession error")
|
||||
|
||||
sender.AssertNumberOfCalls(t, "Close", 1)
|
||||
session.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpSenderSession_Send(t *testing.T) {
|
||||
|
||||
t.Run("invalid topic name", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
session := &amqpSessionMock{}
|
||||
|
||||
senderSession := &AmqpSenderSession{
|
||||
sender: sender,
|
||||
session: session,
|
||||
}
|
||||
|
||||
data := [][]byte{[]byte("data")}
|
||||
err := senderSession.Send("invalid", data, "application/json", map[string]interface{}{})
|
||||
assert.EqualError(t, err, "topic \"invalid\" name lacks mandatory prefix \"topic://\"")
|
||||
})
|
||||
|
||||
t.Run("content type missing", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
session := &amqpSessionMock{}
|
||||
|
||||
senderSession := &AmqpSenderSession{
|
||||
sender: sender,
|
||||
session: session,
|
||||
}
|
||||
|
||||
data := [][]byte{[]byte("data")}
|
||||
err := senderSession.Send("topic://some/name", data, "", map[string]interface{}{})
|
||||
assert.EqualError(t, err, "content-type is required")
|
||||
})
|
||||
|
||||
t.Run("send", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
session := &amqpSessionMock{}
|
||||
|
||||
senderSession := &AmqpSenderSession{
|
||||
sender: sender,
|
||||
session: session,
|
||||
}
|
||||
|
||||
data := [][]byte{[]byte("data")}
|
||||
applicationProperties := map[string]interface{}{}
|
||||
applicationProperties["key"] = "value"
|
||||
err := senderSession.Send("topic://some/name", data, "application/json", applicationProperties)
|
||||
assert.NoError(t, err)
|
||||
|
||||
sender.AssertNumberOfCalls(t, "Send", 1)
|
||||
calls := sender.Calls
|
||||
assert.Equal(t, 1, len(calls))
|
||||
|
||||
ctx, isCtx := calls[0].Arguments[0].(context.Context)
|
||||
assert.True(t, isCtx)
|
||||
assert.NotNil(t, ctx)
|
||||
|
||||
message, isMsg := calls[0].Arguments[1].(*amqp.Message)
|
||||
assert.True(t, isMsg)
|
||||
assert.True(t, message.Header.Durable)
|
||||
assert.Equal(t, "topic://some/name", *message.Properties.To)
|
||||
assert.Equal(t, "application/json", *message.Properties.ContentType)
|
||||
assert.Equal(t, applicationProperties, message.ApplicationProperties)
|
||||
assert.Equal(t, data, message.Data)
|
||||
|
||||
senderOptions, isSenderOptions := calls[0].Arguments[2].(*amqp.SendOptions)
|
||||
assert.True(t, isSenderOptions)
|
||||
assert.Nil(t, senderOptions)
|
||||
})
|
||||
|
||||
t.Run("send fails", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("send fail"))
|
||||
session := &amqpSessionMock{}
|
||||
|
||||
senderSession := &AmqpSenderSession{
|
||||
sender: sender,
|
||||
session: session,
|
||||
}
|
||||
|
||||
data := [][]byte{[]byte("data")}
|
||||
applicationProperties := map[string]interface{}{}
|
||||
applicationProperties["key"] = "value"
|
||||
|
||||
err := senderSession.Send("topic://some/name", data, "application/json", applicationProperties)
|
||||
assert.EqualError(t, err, "send fail")
|
||||
sender.AssertNumberOfCalls(t, "Send", 1)
|
||||
})
|
||||
}
|
||||
|
|
@ -2,19 +2,13 @@ package messaging
|
|||
|
||||
import (
|
||||
"context"
|
||||
"dev.azure.com/schwarzit/schwarzit.stackit-public/audit-go.git/log"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dev.azure.com/schwarzit/schwarzit.stackit-public/audit-go.git/log"
|
||||
"github.com/Azure/go-amqp"
|
||||
)
|
||||
|
||||
// Default connection timeout for the AMQP connection
|
||||
const connectionTimeoutSeconds = 10
|
||||
|
||||
// Api is an abstraction for a messaging system that can be used to send
|
||||
// audit logs to the audit log system.
|
||||
type Api interface {
|
||||
|
|
@ -38,203 +32,122 @@ type Api interface {
|
|||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
// MutexApi is wrapper around an API implementation that controls mutual exclusive access to the api.
|
||||
type MutexApi struct {
|
||||
mutex sync.Mutex
|
||||
api Api
|
||||
}
|
||||
|
||||
var _ Api = &MutexApi{}
|
||||
|
||||
func NewMutexApi(api Api) (Api, error) {
|
||||
if api == nil {
|
||||
return nil, errors.New("api is nil")
|
||||
}
|
||||
mutexApi := MutexApi{
|
||||
mutex: sync.Mutex{},
|
||||
api: api,
|
||||
}
|
||||
|
||||
var genericApi Api = &mutexApi
|
||||
return genericApi, nil
|
||||
}
|
||||
|
||||
// Send implements Api.Send
|
||||
func (m *MutexApi) Send(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return m.api.Send(ctx, topic, data, contentType, applicationProperties)
|
||||
}
|
||||
|
||||
func (m *MutexApi) Close(ctx context.Context) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return m.api.Close(ctx)
|
||||
}
|
||||
|
||||
// AmqpConfig provides AMQP connection related parameters.
|
||||
type AmqpConfig struct {
|
||||
URL string
|
||||
User string
|
||||
Password string
|
||||
}
|
||||
|
||||
// AmqpSession is an abstraction providing a subset of the methods of amqp.Session
|
||||
type AmqpSession interface {
|
||||
NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error)
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
type AmqpSessionWrapper struct {
|
||||
session *amqp.Session
|
||||
}
|
||||
|
||||
func (w AmqpSessionWrapper) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error) {
|
||||
return w.session.NewSender(ctx, target, opts)
|
||||
}
|
||||
|
||||
func (w AmqpSessionWrapper) Close(ctx context.Context) error {
|
||||
return w.session.Close(ctx)
|
||||
}
|
||||
|
||||
// AmqpSender is an abstraction providing a subset of the methods of amqp.Sender
|
||||
type AmqpSender interface {
|
||||
Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
// AmqpApi implements Api.
|
||||
type AmqpApi struct {
|
||||
config AmqpConfig
|
||||
connection *amqp.Conn
|
||||
session AmqpSession
|
||||
config AmqpConnectionPoolConfig
|
||||
connection *AmqpConnection
|
||||
connectionPool ConnectionPool
|
||||
connectionPoolHandle *ConnectionPoolHandle
|
||||
senderCache map[string]*AmqpSenderSession
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
var _ Api = &AmqpApi{}
|
||||
|
||||
func NewAmqpApi(amqpConfig AmqpConfig) (Api, error) {
|
||||
amqpApi := &AmqpApi{config: amqpConfig}
|
||||
|
||||
if err := amqpApi.connect(); err != nil {
|
||||
return nil, fmt.Errorf("connect to broker: %w", err)
|
||||
}
|
||||
|
||||
return amqpApi, nil
|
||||
}
|
||||
|
||||
// connect opens a new connection and session to the AMQP messaging system.
|
||||
// The connection attempt will be cancelled after connectionTimeoutSeconds.
|
||||
func (a *AmqpApi) connect() error {
|
||||
log.AuditLogger.Info("connecting to audit messaging system")
|
||||
|
||||
// Set credentials if specified
|
||||
auth := amqp.SASLTypeAnonymous()
|
||||
|
||||
if a.config.User != "" && a.config.Password != "" {
|
||||
auth = amqp.SASLTypePlain(a.config.User, a.config.Password)
|
||||
log.AuditLogger.Info("using username and password for messaging")
|
||||
} else {
|
||||
log.AuditLogger.Warn("using anonymous messaging!")
|
||||
}
|
||||
|
||||
options := &amqp.ConnOptions{
|
||||
SASLType: auth,
|
||||
}
|
||||
|
||||
// Create new context with timeout for the connection initialization
|
||||
subCtx, cancel := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Initialize connection
|
||||
conn, err := amqp.Dial(subCtx, a.config.URL, options)
|
||||
func NewAmqpApi(amqpConfig AmqpConnectionPoolConfig) (Api, error) {
|
||||
connectionPool, err := NewAmqpConnectionPool(amqpConfig, "sdk")
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial connection to broker: %w", err)
|
||||
}
|
||||
a.connection = conn
|
||||
|
||||
// Initialize session
|
||||
session, err := conn.NewSession(context.Background(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
return nil, fmt.Errorf("new amqp connection pool: %w", err)
|
||||
}
|
||||
|
||||
var amqpSession AmqpSession = &AmqpSessionWrapper{session: session}
|
||||
a.session = amqpSession
|
||||
amqpApi := &AmqpApi{config: amqpConfig,
|
||||
connectionPool: connectionPool,
|
||||
connectionPoolHandle: connectionPool.NewHandle(),
|
||||
senderCache: make(map[string]*AmqpSenderSession),
|
||||
}
|
||||
|
||||
return nil
|
||||
var messagingApi Api = amqpApi
|
||||
return messagingApi, nil
|
||||
}
|
||||
|
||||
// Send implements Api.Send.
|
||||
// If errors occur the connection to the messaging system will be closed and re-established.
|
||||
func (a *AmqpApi) Send(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
|
||||
err := a.trySend(ctx, topic, data, contentType, applicationProperties)
|
||||
if err == nil {
|
||||
func (a *AmqpApi) Send(_ context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
|
||||
|
||||
// create or get sender from cache
|
||||
var sender = a.senderFromCache(topic)
|
||||
if sender == nil {
|
||||
if err := a.newSender(topic); err != nil {
|
||||
return err
|
||||
}
|
||||
sender = a.senderFromCache(topic)
|
||||
}
|
||||
|
||||
// first attempt to send
|
||||
var sendErr error
|
||||
wrappedData := [][]byte{data}
|
||||
if err := sender.Send(topic, wrappedData, contentType, applicationProperties); err != nil {
|
||||
sendErr = fmt.Errorf("send: %w", err)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Drop the current sender, as it cannot connect to the broker anymore
|
||||
log.AuditLogger.Error("message sender error, recreating", err)
|
||||
// renew sender
|
||||
if err := a.newSender(topic); err != nil {
|
||||
return errors.Join(sendErr, err)
|
||||
}
|
||||
sender = a.senderFromCache(topic)
|
||||
|
||||
err = a.resetConnection(ctx)
|
||||
// retry send
|
||||
if err := sender.Send(topic, wrappedData, contentType, applicationProperties); err != nil {
|
||||
return errors.Join(sendErr, fmt.Errorf("retry send: %w", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AmqpApi) senderFromCache(topic string) *AmqpSenderSession {
|
||||
a.lock.RLock()
|
||||
defer a.lock.RUnlock()
|
||||
return a.senderCache[topic]
|
||||
}
|
||||
|
||||
func (a *AmqpApi) newSender(topic string) error {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
connectionIsClosed := a.connection == nil || a.connection.IsClosed()
|
||||
if connectionIsClosed {
|
||||
connection, err := a.connectionPool.GetConnection(a.connectionPoolHandle)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reset connection: %w", err)
|
||||
return fmt.Errorf("get connection: %w", err)
|
||||
}
|
||||
a.connection = connection
|
||||
}
|
||||
|
||||
return a.trySend(ctx, topic, data, contentType, applicationProperties)
|
||||
}
|
||||
|
||||
// trySend actually sends the given data as amqp.Message to the messaging system.
|
||||
func (a *AmqpApi) trySend(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error {
|
||||
if !strings.HasPrefix(topic, AmqpTopicPrefix) {
|
||||
return fmt.Errorf(
|
||||
"topic %q name lacks mandatory prefix %q",
|
||||
topic,
|
||||
AmqpTopicPrefix,
|
||||
)
|
||||
}
|
||||
|
||||
sender, err := a.session.NewSender(ctx, topic, nil)
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
sender, err := a.connection.NewSender(ctx, topic)
|
||||
cancelFn()
|
||||
if err != nil {
|
||||
return fmt.Errorf("new sender: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := sender.Close(ctx); err != nil {
|
||||
log.AuditLogger.Error("failed to close session sender", err)
|
||||
}
|
||||
}()
|
||||
|
||||
bytes := [][]byte{data}
|
||||
message := amqp.Message{
|
||||
Header: &amqp.MessageHeader{
|
||||
Durable: true,
|
||||
},
|
||||
Properties: &amqp.MessageProperties{
|
||||
To: &topic,
|
||||
ContentType: &contentType,
|
||||
},
|
||||
ApplicationProperties: applicationProperties,
|
||||
Data: bytes,
|
||||
}
|
||||
|
||||
err = sender.Send(ctx, &message, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send message: %w", err)
|
||||
}
|
||||
|
||||
a.senderCache[topic] = sender
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetConnection closes the current session and connection and reconnects to the messaging system.
|
||||
func (a *AmqpApi) resetConnection(ctx context.Context) error {
|
||||
if err := a.Close(ctx); err != nil {
|
||||
log.AuditLogger.Error("failed to close audit messaging connection", err)
|
||||
}
|
||||
|
||||
return a.connect()
|
||||
}
|
||||
|
||||
// Close implements Api.Close
|
||||
func (a *AmqpApi) Close(ctx context.Context) error {
|
||||
log.AuditLogger.Info("close audit messaging connection")
|
||||
return errors.Join(a.session.Close(ctx), a.connection.Close())
|
||||
func (a *AmqpApi) Close(_ context.Context) error {
|
||||
log.AuditLogger.Info("close audit amqp connection pool")
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
// cached senders
|
||||
var closeErrors []error
|
||||
for _, session := range a.senderCache {
|
||||
if err := session.Close(); err != nil {
|
||||
closeErrors = append(closeErrors, fmt.Errorf("close session: %w", err))
|
||||
}
|
||||
}
|
||||
clear(a.senderCache)
|
||||
|
||||
// pool
|
||||
if err := a.connectionPool.Close(); err != nil {
|
||||
closeErrors = append(closeErrors, fmt.Errorf("close pool: %w", err))
|
||||
}
|
||||
|
||||
if len(closeErrors) > 0 {
|
||||
return fmt.Errorf("close: %w", errors.Join(closeErrors...))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,50 +4,38 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/go-amqp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AmqpSessionMock struct {
|
||||
type connectionPoolMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *AmqpSessionMock) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error) {
|
||||
args := m.Called(ctx, target, opts)
|
||||
var sender AmqpSender = nil
|
||||
if args.Get(0) != nil {
|
||||
sender = args.Get(0).(AmqpSender)
|
||||
}
|
||||
err := args.Error(1)
|
||||
return sender, err
|
||||
func (m *connectionPoolMock) Close() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
func (m *AmqpSessionMock) Close(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
func (m *connectionPoolMock) NewHandle() *ConnectionPoolHandle {
|
||||
return m.Called().Get(0).(*ConnectionPoolHandle)
|
||||
}
|
||||
|
||||
type AmqpSenderMock struct {
|
||||
mock.Mock
|
||||
func (m *connectionPoolMock) GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) {
|
||||
return m.Called(handle).Get(0).(*AmqpConnection), m.Called(handle).Error(1)
|
||||
}
|
||||
|
||||
func (m *AmqpSenderMock) Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error {
|
||||
args := m.Called(ctx, msg, opts)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *AmqpSenderMock) Close(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
var _ ConnectionPool = (*connectionPoolMock)(nil)
|
||||
|
||||
func Test_NewAmqpMessagingApi(t *testing.T) {
|
||||
_, err := NewAmqpApi(AmqpConfig{URL: "not-handled-protocol://localhost:5672"})
|
||||
assert.EqualError(t, err, "connect to broker: dial connection to broker: unsupported scheme \"not-handled-protocol\"")
|
||||
_, err := NewAmqpApi(
|
||||
AmqpConnectionPoolConfig{
|
||||
Parameters: AmqpConnectionConfig{BrokerUrl: "not-handled-protocol://localhost:5672"},
|
||||
PoolSize: 1,
|
||||
})
|
||||
assert.EqualError(t, err, "new amqp connection pool: initialize connections: new connection: new internal connection: internal connect: dial: unsupported scheme \"not-handled-protocol\"")
|
||||
}
|
||||
|
||||
func Test_AmqpMessagingApi_Send(t *testing.T) {
|
||||
|
|
@ -63,121 +51,359 @@ func Test_AmqpMessagingApi_Send(t *testing.T) {
|
|||
t.Run("Missing topic prefix", func(t *testing.T) {
|
||||
defer solaceContainer.StopOnError()
|
||||
|
||||
api, err := NewAmqpApi(AmqpConfig{URL: solaceContainer.AmqpConnectionString})
|
||||
api, err := NewAmqpApi(AmqpConnectionPoolConfig{
|
||||
Parameters: AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
|
||||
PoolSize: 1,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = api.Send(ctx, "topic-name", []byte{}, "application/json", make(map[string]any))
|
||||
assert.EqualError(t, err, "topic \"topic-name\" name lacks mandatory prefix \"topic://\"")
|
||||
assert.EqualError(t, err, "send: topic \"topic-name\" name lacks mandatory prefix \"topic://\"\nretry send: topic \"topic-name\" name lacks mandatory prefix \"topic://\"")
|
||||
})
|
||||
|
||||
t.Run("Close connection without errors", func(t *testing.T) {
|
||||
t.Run("send successfully", func(t *testing.T) {
|
||||
defer solaceContainer.StopOnError()
|
||||
|
||||
// Initialize the solace queue
|
||||
topicSubscriptionTopicPattern := "auditlog/>"
|
||||
queueName := "close-connection-without-error"
|
||||
queueName := "send-successfully"
|
||||
assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName))
|
||||
assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern))
|
||||
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-close-connection")
|
||||
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-send-successfully")
|
||||
assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName))
|
||||
|
||||
api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}}
|
||||
err := api.connect()
|
||||
api, err := NewAmqpApi(AmqpConnectionPoolConfig{
|
||||
Parameters: AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString},
|
||||
PoolSize: 1,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
data := []byte("data")
|
||||
applicationProperties := make(map[string]interface{})
|
||||
applicationProperties["key"] = "value"
|
||||
|
||||
err = api.Send(ctx, topicName, data, "application/json", applicationProperties)
|
||||
assert.NoError(t, err)
|
||||
|
||||
message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "data", string(message.Data[0]))
|
||||
assert.Equal(t, topicName, *message.Properties.To)
|
||||
assert.Equal(t, "application/json", *message.Properties.ContentType)
|
||||
assert.Equal(t, applicationProperties, message.ApplicationProperties)
|
||||
|
||||
err = api.Close(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("New sender call returns error", func(t *testing.T) {
|
||||
defer solaceContainer.StopOnError()
|
||||
func Test_AmqpMessagingApi_Send_Special_Cases(t *testing.T) {
|
||||
|
||||
// Initialize the solace queue
|
||||
topicSubscriptionTopicPattern := "auditlog/>"
|
||||
queueName := "messaging-new-sender"
|
||||
assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName))
|
||||
assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern))
|
||||
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-no-new-sender")
|
||||
assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName))
|
||||
channelReceiver := func(channel chan struct{}) <-chan struct{} {
|
||||
return channel
|
||||
}
|
||||
|
||||
api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}}
|
||||
err := api.connect()
|
||||
newActiveConnection := func() *AmqpConnection {
|
||||
channel := make(chan struct{})
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
|
||||
return &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
newClosedConnection := func() *AmqpConnection {
|
||||
channel := make(chan struct{})
|
||||
close(channel)
|
||||
|
||||
conn := &amqpConnMock{}
|
||||
conn.On("Done", mock.Anything).Return(channelReceiver(channel))
|
||||
return &AmqpConnection{
|
||||
connectionName: "test",
|
||||
lock: sync.RWMutex{},
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("connection nil sender nil", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
session := &amqpSessionMock{}
|
||||
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
|
||||
|
||||
connection := newActiveConnection()
|
||||
conn := connection.conn.(*amqpConnMock)
|
||||
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
|
||||
|
||||
pool := &connectionPoolMock{}
|
||||
pool.On("GetConnection", mock.Anything).Return(connection, nil)
|
||||
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connectionPool: pool,
|
||||
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
|
||||
senderCache: make(map[string]*AmqpSenderSession),
|
||||
}
|
||||
|
||||
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedError := errors.New("expected error")
|
||||
|
||||
// Set mock session
|
||||
sessionMock := AmqpSessionMock{}
|
||||
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(nil, expectedError)
|
||||
sessionMock.On("Close", mock.Anything).Return(nil)
|
||||
|
||||
var amqpSession AmqpSession = &sessionMock
|
||||
api.session = amqpSession
|
||||
|
||||
// It's expected that the test succeeds.
|
||||
// First the session is closed as it returns the expected error
|
||||
// Then the retry mechanism restarts the connection and successfully sends the data
|
||||
value := "test"
|
||||
err = api.Send(ctx, topicName, []byte(value), "application/json", make(map[string]any))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check that the mock was called
|
||||
assert.True(t, sessionMock.AssertNumberOfCalls(t, "NewSender", 1))
|
||||
assert.True(t, sessionMock.AssertNumberOfCalls(t, "Close", 1))
|
||||
|
||||
message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, value, string(message.Data[0]))
|
||||
assert.Equal(t, topicName, *message.Properties.To)
|
||||
sender.AssertNumberOfCalls(t, "Send", 1)
|
||||
session.AssertNumberOfCalls(t, "NewSender", 1)
|
||||
pool.AssertNumberOfCalls(t, "GetConnection", 2)
|
||||
})
|
||||
|
||||
t.Run("Send call on sender returns error", func(t *testing.T) {
|
||||
defer solaceContainer.StopOnError()
|
||||
t.Run("connection closed sender nil", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
// Initialize the solace queue
|
||||
topicSubscriptionTopicPattern := "auditlog/>"
|
||||
queueName := "messaging-sender-error"
|
||||
assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName))
|
||||
assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern))
|
||||
topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-sender-error")
|
||||
assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName))
|
||||
session := &amqpSessionMock{}
|
||||
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
|
||||
|
||||
api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}}
|
||||
err := api.connect()
|
||||
connection := newActiveConnection()
|
||||
conn := connection.conn.(*amqpConnMock)
|
||||
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
|
||||
|
||||
pool := &connectionPoolMock{}
|
||||
pool.On("GetConnection", mock.Anything).Return(connection, nil)
|
||||
|
||||
closedConnection := newClosedConnection()
|
||||
closedConnMock := closedConnection.conn.(*amqpConnMock)
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connection: closedConnection,
|
||||
connectionPool: pool,
|
||||
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
|
||||
senderCache: make(map[string]*AmqpSenderSession),
|
||||
}
|
||||
|
||||
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedError := errors.New("expected error")
|
||||
sender.AssertNumberOfCalls(t, "Send", 1)
|
||||
session.AssertNumberOfCalls(t, "NewSender", 1)
|
||||
pool.AssertNumberOfCalls(t, "GetConnection", 2)
|
||||
closedConnMock.AssertNumberOfCalls(t, "Done", 1)
|
||||
})
|
||||
|
||||
// Instantiate mock sender
|
||||
senderMock := AmqpSenderMock{}
|
||||
senderMock.On("Send", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedError)
|
||||
senderMock.On("Close", mock.Anything).Return(nil)
|
||||
var amqpSender AmqpSender = &senderMock
|
||||
t.Run("connection nil get connection fail", func(t *testing.T) {
|
||||
var connection *AmqpConnection = nil
|
||||
|
||||
// Set mock session
|
||||
sessionMock := AmqpSessionMock{}
|
||||
sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(&amqpSender, nil)
|
||||
sessionMock.On("Close", mock.Anything).Return(nil)
|
||||
pool := &connectionPoolMock{}
|
||||
pool.On("GetConnection", mock.Anything).Return(connection, errors.New("connection error"))
|
||||
|
||||
var amqpSession AmqpSession = &sessionMock
|
||||
api.session = amqpSession
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connectionPool: pool,
|
||||
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
|
||||
senderCache: make(map[string]*AmqpSenderSession),
|
||||
}
|
||||
|
||||
// It's expected that the test succeeds.
|
||||
// First the sender and session are closed as the sender returns the expected error
|
||||
// Then the retry mechanism restarts the connection and successfully sends the data
|
||||
value := "test"
|
||||
err = api.Send(ctx, topicName, []byte(value), "application/json", make(map[string]any))
|
||||
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
|
||||
assert.EqualError(t, err, "get connection: connection error")
|
||||
|
||||
pool.AssertNumberOfCalls(t, "GetConnection", 2)
|
||||
})
|
||||
|
||||
t.Run("connection active sender nil", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
session := &amqpSessionMock{}
|
||||
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
|
||||
|
||||
connection := newActiveConnection()
|
||||
conn := connection.conn.(*amqpConnMock)
|
||||
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
|
||||
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connection: connection,
|
||||
senderCache: make(map[string]*AmqpSenderSession),
|
||||
}
|
||||
|
||||
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check that the mocks were called
|
||||
assert.True(t, sessionMock.AssertNumberOfCalls(t, "NewSender", 1))
|
||||
assert.True(t, sessionMock.AssertNumberOfCalls(t, "Close", 1))
|
||||
assert.True(t, senderMock.AssertNumberOfCalls(t, "Send", 1))
|
||||
assert.True(t, senderMock.AssertNumberOfCalls(t, "Close", 1))
|
||||
sender.AssertNumberOfCalls(t, "Send", 1)
|
||||
session.AssertNumberOfCalls(t, "NewSender", 1)
|
||||
})
|
||||
|
||||
message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true)
|
||||
t.Run("connection active new sender fail", func(t *testing.T) {
|
||||
var sender *amqpSenderMock = nil
|
||||
|
||||
session := &amqpSessionMock{}
|
||||
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, errors.New("new sender error"))
|
||||
session.On("Close", mock.Anything).Return(nil)
|
||||
|
||||
connection := newActiveConnection()
|
||||
conn := connection.conn.(*amqpConnMock)
|
||||
conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil)
|
||||
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connection: connection,
|
||||
senderCache: make(map[string]*AmqpSenderSession),
|
||||
}
|
||||
|
||||
err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any))
|
||||
assert.EqualError(t, err, "new sender: new internal sender: new sender error")
|
||||
|
||||
session.AssertNumberOfCalls(t, "NewSender", 1)
|
||||
session.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
|
||||
t.Run("connection active sender set", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
topic := "topic://some-topic"
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connection: newActiveConnection(),
|
||||
senderCache: map[string]*AmqpSenderSession{topic: {sender: sender}},
|
||||
}
|
||||
|
||||
err := amqpApi.Send(context.Background(), topic, []byte("data"), "application/json", make(map[string]any))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, value, string(message.Data[0]))
|
||||
assert.Equal(t, topicName, *message.Properties.To)
|
||||
|
||||
sender.AssertNumberOfCalls(t, "Send", 1)
|
||||
})
|
||||
|
||||
t.Run("send fail", func(t *testing.T) {
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("send error"))
|
||||
|
||||
session := &amqpSessionMock{}
|
||||
session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil)
|
||||
|
||||
topic := "topic://some-topic"
|
||||
connection := newActiveConnection()
|
||||
connection.conn.(*amqpConnMock).On("NewSession", mock.Anything, mock.Anything, mock.Anything).Return(session, nil)
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connection: connection,
|
||||
senderCache: map[string]*AmqpSenderSession{topic: {sender: sender}},
|
||||
}
|
||||
|
||||
err := amqpApi.Send(context.Background(), topic, []byte("data"), "application/json", make(map[string]any))
|
||||
assert.EqualError(t, err, "send: send error\nretry send: send error")
|
||||
|
||||
sender.AssertNumberOfCalls(t, "Send", 2)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AmqpMessagingApi_Close(t *testing.T) {
|
||||
|
||||
t.Run("close without cached senders", func(t *testing.T) {
|
||||
pool := &connectionPoolMock{}
|
||||
pool.On("Close").Return(nil)
|
||||
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connectionPool: pool,
|
||||
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
|
||||
senderCache: make(map[string]*AmqpSenderSession),
|
||||
}
|
||||
|
||||
err := amqpApi.Close(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
pool.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
|
||||
t.Run("close fail without cached senders", func(t *testing.T) {
|
||||
pool := &connectionPoolMock{}
|
||||
pool.On("Close").Return(errors.New("close error"))
|
||||
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connectionPool: pool,
|
||||
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
|
||||
senderCache: make(map[string]*AmqpSenderSession),
|
||||
}
|
||||
|
||||
err := amqpApi.Close(context.Background())
|
||||
assert.EqualError(t, err, "close: close pool: close error")
|
||||
|
||||
pool.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
|
||||
t.Run("close with cached senders", func(t *testing.T) {
|
||||
pool := &connectionPoolMock{}
|
||||
pool.On("Close").Return(nil)
|
||||
|
||||
session := &amqpSessionMock{}
|
||||
session.On("Close", mock.Anything).Return(nil)
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Close", mock.Anything).Return(nil)
|
||||
senderSession := &AmqpSenderSession{
|
||||
session: session,
|
||||
sender: sender,
|
||||
}
|
||||
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connectionPool: pool,
|
||||
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
|
||||
senderCache: map[string]*AmqpSenderSession{"key": senderSession},
|
||||
}
|
||||
|
||||
err := amqpApi.Close(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(amqpApi.senderCache))
|
||||
|
||||
pool.AssertNumberOfCalls(t, "Close", 1)
|
||||
session.AssertNumberOfCalls(t, "Close", 1)
|
||||
sender.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
|
||||
t.Run("close fail with cached senders", func(t *testing.T) {
|
||||
pool := &connectionPoolMock{}
|
||||
pool.On("Close").Return(nil)
|
||||
|
||||
session := &amqpSessionMock{}
|
||||
session.On("Close", mock.Anything).Return(nil)
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Close", mock.Anything).Return(errors.New("close sender error"))
|
||||
senderSession := &AmqpSenderSession{
|
||||
session: session,
|
||||
sender: sender,
|
||||
}
|
||||
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connectionPool: pool,
|
||||
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
|
||||
senderCache: map[string]*AmqpSenderSession{"key": senderSession},
|
||||
}
|
||||
|
||||
err := amqpApi.Close(context.Background())
|
||||
assert.EqualError(t, err, "close: close session: close sender error")
|
||||
assert.Equal(t, 0, len(amqpApi.senderCache))
|
||||
|
||||
pool.AssertNumberOfCalls(t, "Close", 1)
|
||||
session.AssertNumberOfCalls(t, "Close", 1)
|
||||
sender.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
|
||||
t.Run("close fail", func(t *testing.T) {
|
||||
pool := &connectionPoolMock{}
|
||||
pool.On("Close").Return(errors.New("close pool error"))
|
||||
|
||||
session := &amqpSessionMock{}
|
||||
session.On("Close", mock.Anything).Return(errors.New("close session error"))
|
||||
sender := &amqpSenderMock{}
|
||||
sender.On("Close", mock.Anything).Return(errors.New("close sender error"))
|
||||
senderSession := &AmqpSenderSession{
|
||||
session: session,
|
||||
sender: sender,
|
||||
}
|
||||
|
||||
amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{},
|
||||
connectionPool: pool,
|
||||
connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0},
|
||||
senderCache: map[string]*AmqpSenderSession{"key": senderSession},
|
||||
}
|
||||
|
||||
err := amqpApi.Close(context.Background())
|
||||
assert.EqualError(t, err, "close: close session: close sender error\nclose session error\nclose pool: close pool error")
|
||||
assert.Equal(t, 0, len(amqpApi.senderCache))
|
||||
|
||||
pool.AssertNumberOfCalls(t, "Close", 1)
|
||||
session.AssertNumberOfCalls(t, "Close", 1)
|
||||
sender.AssertNumberOfCalls(t, "Close", 1)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
AmqpTopicPrefix = "topic://"
|
||||
AmqpQueuePrefix = "queue://"
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue