diff --git a/audit/api/api_legacy_dynamic_test.go b/audit/api/api_legacy_dynamic_test.go index cea7df5..7a26d6b 100644 --- a/audit/api/api_legacy_dynamic_test.go +++ b/audit/api/api_legacy_dynamic_test.go @@ -30,7 +30,11 @@ func TestDynamicLegacyAuditApi(t *testing.T) { defer solaceContainer.Stop() // Instantiate the messaging api - messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString}) + messagingApi, err := messaging.NewAmqpApi( + messaging.AmqpConnectionPoolConfig{ + Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString}, + PoolSize: 1, + }) assert.NoError(t, err) // Validator diff --git a/audit/api/api_legacy_test.go b/audit/api/api_legacy_test.go index ffff969..ee45a97 100644 --- a/audit/api/api_legacy_test.go +++ b/audit/api/api_legacy_test.go @@ -32,7 +32,10 @@ func TestLegacyAuditApi(t *testing.T) { defer solaceContainer.Stop() // Instantiate the messaging api - messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString}) + messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{ + Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString}, + PoolSize: 1, + }) assert.NoError(t, err) // Validator @@ -579,7 +582,10 @@ func TestLegacyAuditApi_NewLegacyAuditApi(t *testing.T) { defer solaceContainer.Stop() // Instantiate the messaging api - messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString}) + messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{ + Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString}, + PoolSize: 1, + }) assert.NoError(t, err) // Validator diff --git a/audit/api/api_routable_test.go b/audit/api/api_routable_test.go index f8881df..78048e0 100644 --- a/audit/api/api_routable_test.go +++ b/audit/api/api_routable_test.go @@ -33,7 +33,10 @@ func TestRoutableAuditApi(t *testing.T) { defer solaceContainer.Stop() // Instantiate the messaging api - messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConfig{URL: solaceContainer.AmqpConnectionString}) + messagingApi, err := messaging.NewAmqpApi(messaging.AmqpConnectionPoolConfig{ + Parameters: messaging.AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString}, + PoolSize: 1, + }) assert.NoError(t, err) // Validator diff --git a/audit/messaging/amqp_config.go b/audit/messaging/amqp_config.go new file mode 100644 index 0000000..11a571f --- /dev/null +++ b/audit/messaging/amqp_config.go @@ -0,0 +1,15 @@ +package messaging + +const AmqpTopicPrefix = "topic://" +const connectionTimeoutSeconds = 10 + +type AmqpConnectionConfig struct { + BrokerUrl string `json:"brokerUrl"` + Username string `json:"username"` + Password string `json:"password"` +} + +type AmqpConnectionPoolConfig struct { + Parameters AmqpConnectionConfig `json:"parameters"` + PoolSize int `json:"poolSize"` +} diff --git a/audit/messaging/amqp_connection.go b/audit/messaging/amqp_connection.go new file mode 100644 index 0000000..11d24bf --- /dev/null +++ b/audit/messaging/amqp_connection.go @@ -0,0 +1,277 @@ +package messaging + +import ( + "context" + "errors" + "fmt" + "github.com/Azure/go-amqp" + "log/slog" + "sync" + "time" +) + +var ConnectionClosedError = errors.New("amqp connection is closed") + +type AmqpConnection struct { + connectionName string + lock sync.RWMutex + brokerUrl string + username string + password string + conn amqpConn + dialer amqpDial +} + +// amqpConn is an abstraction of amqp.Conn +type amqpConn interface { + NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error) + Close() error + Done() <-chan struct{} +} + +type defaultAmqpConn struct { + conn *amqp.Conn +} + +func newDefaultAmqpConn(conn *amqp.Conn) *defaultAmqpConn { + return &defaultAmqpConn{ + conn: conn, + } +} + +func (d defaultAmqpConn) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error) { + session, err := d.conn.NewSession(ctx, opts) + if err != nil { + return nil, err + } + return newDefaultAmqpSession(session), nil +} + +func (d defaultAmqpConn) Close() error { + return d.conn.Close() +} + +func (d defaultAmqpConn) Done() <-chan struct{} { + return d.conn.Done() +} + +var _ amqpConn = (*defaultAmqpConn)(nil) + +type amqpDial interface { + Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error) +} + +type amqpSession interface { + NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error) + Close(ctx context.Context) error +} + +type defaultAmqpSession struct { + session *amqp.Session +} + +func newDefaultAmqpSession(session *amqp.Session) *defaultAmqpSession { + return &defaultAmqpSession{ + session: session, + } +} + +func (s *defaultAmqpSession) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error) { + return s.session.NewSender(ctx, target, opts) +} + +func (s *defaultAmqpSession) Close(ctx context.Context) error { + return s.session.Close(ctx) +} + +var _ amqpSession = (*defaultAmqpSession)(nil) + +type defaultAmqpDialer struct{} + +func (d *defaultAmqpDialer) Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error) { + dial, err := amqp.Dial(ctx, addr, opts) + if err != nil { + return nil, err + } + return newDefaultAmqpConn(dial), nil +} + +var _ amqpDial = (*defaultAmqpDialer)(nil) + +func NewAmqpConnection(config *AmqpConnectionConfig, connectionName string) *AmqpConnection { + return &AmqpConnection{ + connectionName: connectionName, + lock: sync.RWMutex{}, + brokerUrl: config.BrokerUrl, + username: config.Username, + password: config.Password, + dialer: &defaultAmqpDialer{}, + } +} + +func (c *AmqpConnection) NewSender(ctx context.Context, topic string) (*AmqpSenderSession, error) { + if c.conn == nil { + return nil, errors.New("connection is not initialized") + } + + if c.internalIsClosed() { + return nil, ConnectionClosedError + } + + c.lock.RLock() + defer c.lock.RUnlock() + + // new session + newSession, err := c.conn.NewSession(ctx, nil) + if err != nil { + return nil, fmt.Errorf("new session: %w", err) + } + + // new sender + newSender, err := newSession.NewSender(ctx, topic, nil) + if err != nil { + err = fmt.Errorf("new internal sender: %w", err) + + closeErr := newSession.Close(ctx) + if closeErr != nil { + return nil, errors.Join(err, fmt.Errorf("close session: %w", closeErr)) + } + return nil, err + } + + return &AmqpSenderSession{newSession, newSender}, nil +} + +func (c *AmqpConnection) ResetConnectionAndRetryIfErrorWithReturnValue(opName string, fn func(ctx context.Context) (any, error)) (any, error) { + ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second) + result, err := fn(ctx) + cancelFn() + + if err != nil { + slog.Info(fmt.Sprintf("amqp connection: %s", opName), slog.Any("connection", c.connectionName), slog.Any("err", err)) + err := c.ResetConnection(context.Background()) + if err != nil { + return nil, fmt.Errorf("reset connection: %w", err) + } + + newCtx, closeFn := context.WithTimeout(context.Background(), 10*time.Second) + defer closeFn() + + // Retry + return fn(newCtx) + } + return result, nil +} + +func (c *AmqpConnection) ResetConnection(ctx context.Context) error { + c.lock.Lock() + defer c.lock.Unlock() + + err := c.internalClose() + if err != nil { + slog.Warn("amqp connection: reset: failed to close amqp connection", slog.Any("err", err)) + } + + subCtx, cancel := context.WithTimeout(ctx, connectionTimeoutSeconds*time.Second) + err = c.internalConnect(subCtx) + cancel() + if err != nil { + slog.Warn("amqp connection: reset: failed to connect to amqp server, retry..", slog.Any("err", err)) + subCtx, cancel = context.WithTimeout(ctx, connectionTimeoutSeconds*time.Second) + err = c.internalConnect(subCtx) + cancel() + if err != nil { + return fmt.Errorf("connect: %w", err) + } + } + return nil +} + +func As[T any](value any, err error) (*T, error) { + if err != nil { + return nil, err + } + if value == nil { + return nil, nil + } + castedValue, isType := value.(*T) + if !isType { + return nil, fmt.Errorf("could not cast value: %T", value) + } + return castedValue, nil +} + +func (c *AmqpConnection) Connect() error { + c.lock.Lock() + defer c.lock.Unlock() + + subCtx, cancel := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second) + defer cancel() + + if err := c.internalConnect(subCtx); err != nil { + return fmt.Errorf("internal connection connect: %w", err) + } + return nil +} + +func (c *AmqpConnection) internalConnect(ctx context.Context) error { + if c.conn == nil { + // Set credentials if specified + auth := amqp.SASLTypeAnonymous() + if c.username != "" && c.password != "" { + auth = amqp.SASLTypePlain(c.username, c.password) + } else { + slog.Debug("amqp connection: connect: using anonymous messaging") + } + options := &amqp.ConnOptions{ + SASLType: auth, + } + + // Initialize connection + conn, err := c.dialer.Dial(ctx, c.brokerUrl, options) + if err != nil { + return fmt.Errorf("dial: %w", err) + } + c.conn = conn + } + return nil +} + +func (c *AmqpConnection) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + + if err := c.internalClose(); err != nil { + return fmt.Errorf("close: %w", err) + } + return nil +} + +func (c *AmqpConnection) internalClose() error { + if c.conn != nil { + if err := c.conn.Close(); err != nil { + return fmt.Errorf("internal connection close: %w", err) + } + c.conn = nil + } + return nil +} + +func (c *AmqpConnection) IsClosed() bool { + c.lock.RLock() + defer c.lock.RUnlock() + + return c.internalIsClosed() +} + +func (c *AmqpConnection) internalIsClosed() bool { + if c.conn == nil { + return true + } + select { + case <-c.conn.Done(): + return true + default: + return false + } +} diff --git a/audit/messaging/amqp_connection_pool.go b/audit/messaging/amqp_connection_pool.go new file mode 100644 index 0000000..2bb2064 --- /dev/null +++ b/audit/messaging/amqp_connection_pool.go @@ -0,0 +1,213 @@ +package messaging + +import ( + "errors" + "fmt" + "log/slog" + "sync" +) + +type connectionProvider interface { + NewAmqpConnection(config *AmqpConnectionConfig, connectionName string) *AmqpConnection +} + +type defaultAmqpConnectionProvider struct{} + +func (p defaultAmqpConnectionProvider) NewAmqpConnection(config *AmqpConnectionConfig, connectionName string) *AmqpConnection { + return NewAmqpConnection(config, connectionName) +} + +var _ connectionProvider = (*defaultAmqpConnectionProvider)(nil) + +type ConnectionPool interface { + Close() error + NewHandle() *ConnectionPoolHandle + GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) +} + +type AmqpConnectionPool struct { + config AmqpConnectionPoolConfig + connectionName string + connections []*AmqpConnection + connectionProvider connectionProvider + handleOffset int + lock sync.RWMutex +} + +type ConnectionPoolHandle struct { + connectionOffset int +} + +func NewAmqpConnectionPool(config AmqpConnectionPoolConfig, connectionName string) (ConnectionPool, error) { + pool := &AmqpConnectionPool{ + config: config, + connectionName: connectionName, + connections: make([]*AmqpConnection, 0), + connectionProvider: defaultAmqpConnectionProvider{}, + handleOffset: 0, + lock: sync.RWMutex{}, + } + + if err := pool.initializeConnections(); err != nil { + if closeErr := pool.Close(); closeErr != nil { + return nil, errors.Join(err, fmt.Errorf("initialize amqp connection: pool closed: %w", closeErr)) + } + return nil, fmt.Errorf("initialize connections: %w", err) + } + + return pool, nil +} + +func (p *AmqpConnectionPool) initializeConnections() error { + if len(p.connections) < p.config.PoolSize { + p.lock.Lock() + defer p.lock.Unlock() + + numMissingConnections := p.config.PoolSize - len(p.connections) + + for i := 0; i < numMissingConnections; i++ { + if err := p.internalAddConnection(); err != nil { + return err + } + } + } + return nil +} + +func (p *AmqpConnectionPool) internalAddConnection() error { + newConnection, err := p.internalNewConnection() + if err != nil { + return fmt.Errorf("new connection: %w", err) + } + p.connections = append(p.connections, newConnection) + return nil +} + +func (p *AmqpConnectionPool) internalNewConnection() (*AmqpConnection, error) { + conn := p.connectionProvider.NewAmqpConnection(&p.config.Parameters, p.connectionName) + if err := conn.Connect(); err != nil { + slog.Warn("amqp connection: failed to connect to amqp broker", slog.Any("err", err)) + + // retry + if err = conn.Connect(); err != nil { + connectErr := fmt.Errorf("failed to connect to amqp broker: %w", err) + if closeErr := conn.Close(); closeErr != nil { + // this case should never happen as the inner connection should always be null, therefore + // it should not have to be closed, i.e. be able to return errors. + return nil, errors.Join(connectErr, fmt.Errorf("close connection: %w", closeErr)) + } + return nil, connectErr + } + } + return conn, nil +} + +func (p *AmqpConnectionPool) Close() error { + p.lock.Lock() + defer p.lock.Unlock() + + closeErrors := make([]error, 0) + for _, conn := range p.connections { + if err := conn.Close(); err != nil { + closeErrors = append(closeErrors, fmt.Errorf("connection: close: %w", err)) + } + } + p.connections = make([]*AmqpConnection, 0) + if len(closeErrors) > 0 { + return errors.Join(closeErrors...) + } + return nil +} + +func (p *AmqpConnectionPool) NewHandle() *ConnectionPoolHandle { + p.lock.Lock() + defer p.lock.Unlock() + + offset := p.handleOffset + p.handleOffset += 1 + + offset = offset % p.config.PoolSize + + return &ConnectionPoolHandle{ + connectionOffset: offset, + } +} + +func (p *AmqpConnectionPool) GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) { + // get the requested connection or another one + conn, addConnection := p.nextConnectionForHandle(handle) + + // renew the requested connection if the request connection is closed + if conn == nil || addConnection { + p.lock.Lock() + + // check that accessing the pool only with a valid index (out of bounds should only occur on shutdown) + if handle.connectionOffset < len(p.connections) && p.connections[handle.connectionOffset] == nil { + connection, err := p.internalNewConnection() + if err != nil { + if conn == nil { + // case: connection could not be renewed and no connection to return has been found + p.lock.Unlock() + return nil, fmt.Errorf("renew connection: %w", err) + } + + // case: connection could not be renewed but another connection will be returned + slog.Warn("amqp connection pool: get connection: renew connection: ", slog.Any("err", err)) + } else { + // case: connection could be renewed and will be added to pool + p.connections[handle.connectionOffset] = connection + conn = connection + } + } + p.lock.Unlock() + } + + return conn, nil +} + +func (p *AmqpConnectionPool) nextConnectionForHandle(handle *ConnectionPoolHandle) (*AmqpConnection, bool) { + // return the next possible index (including the retry offset) + nextIndex := func(idx int) int { + if idx+handle.connectionOffset >= p.config.PoolSize { + return (idx + handle.connectionOffset) % p.config.PoolSize + } else { + return idx + handle.connectionOffset + } + } + + // retry as long as there are remaining connections in the pool + var conn *AmqpConnection + var addConnection bool + for i := 0; i < p.config.PoolSize; i++ { + + // get the next possible connection (considering the retry index) + idx := nextIndex(i) + p.lock.RLock() + if idx < len(p.connections) { + conn = p.connections[idx] + } else { + // handle the edge case that the pool is empty on shutdown + conn = nil + } + p.lock.RUnlock() + + // remember that the requested is closed, retry with the next + if conn == nil { + addConnection = true + continue + } + + // if the connection is closed, mark it by setting it to nil + if conn.IsClosed() { + p.lock.Lock() + p.connections[idx] = nil + p.lock.Unlock() + + addConnection = true + continue + } + + return conn, addConnection + } + return nil, true +} diff --git a/audit/messaging/amqp_connection_pool_test.go b/audit/messaging/amqp_connection_pool_test.go new file mode 100644 index 0000000..07316de --- /dev/null +++ b/audit/messaging/amqp_connection_pool_test.go @@ -0,0 +1,572 @@ +package messaging + +import ( + "errors" + "fmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "sync" + "testing" +) + +type connectionProviderMock struct { + mock.Mock +} + +func (p *connectionProviderMock) NewAmqpConnection(config *AmqpConnectionConfig, connectionName string) *AmqpConnection { + args := p.Called(config, connectionName) + return args.Get(0).(*AmqpConnection) +} + +var _ connectionProvider = (*connectionProviderMock)(nil) + +func Test_AmqpConnectionPool_GetHandle(t *testing.T) { + + t.Run("next handle", func(t *testing.T) { + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + } + + handle := pool.NewHandle() + assert.NotNil(t, handle) + assert.Equal(t, 0, handle.connectionOffset) + assert.Equal(t, 1, pool.handleOffset) + }) + + t.Run("next handle high offset", func(t *testing.T) { + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 13, + lock: sync.RWMutex{}, + } + + handle := pool.NewHandle() + assert.NotNil(t, handle) + assert.Equal(t, 3, handle.connectionOffset) + assert.Equal(t, 14, pool.handleOffset) + }) +} + +func Test_AmqpConnectionPool_internalAddConnection(t *testing.T) { + + t.Run("internal add connection", func(t *testing.T) { + conn := &amqpConnMock{} + + dialer := &amqpDialMock{} + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil) + + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: dialer, + } + + connectionProvider := &connectionProviderMock{} + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection) + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connectionProvider: connectionProvider, + } + + err := pool.internalAddConnection() + assert.NoError(t, err) + + assert.Equal(t, 1, len(pool.connections)) + connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1) + dialer.AssertNumberOfCalls(t, "Dial", 1) + }) + + t.Run("dialer error", func(t *testing.T) { + conn := &amqpConnMock{} + + dialer := &amqpDialMock{} + var c *amqpConnMock = nil + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error")).Once() + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil) + + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: dialer, + } + + connectionProvider := &connectionProviderMock{} + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection) + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connectionProvider: connectionProvider, + } + + err := pool.internalAddConnection() + assert.NoError(t, err) + + assert.Equal(t, 1, len(pool.connections)) + connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1) + dialer.AssertNumberOfCalls(t, "Dial", 2) + }) + + t.Run("repetitive dialer error", func(t *testing.T) { + dialer := &amqpDialMock{} + var c *amqpConnMock = nil + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error")) + + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: dialer, + } + + connectionProvider := &connectionProviderMock{} + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection) + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connectionProvider: connectionProvider, + } + + err := pool.internalAddConnection() + assert.EqualError(t, err, "new connection: failed to connect to amqp broker: internal connection connect: dial: test error") + + assert.Equal(t, 0, len(pool.connections)) + connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 1) + dialer.AssertNumberOfCalls(t, "Dial", 2) + }) +} + +func Test_AmqpConnectionPool_initializeConnections(t *testing.T) { + + t.Run("initialize connections successfully", func(t *testing.T) { + + conn := &amqpConnMock{} + dialer := &amqpDialMock{} + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil) + + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: dialer, + } + + connectionProvider := &connectionProviderMock{} + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection) + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connectionProvider: connectionProvider, + } + + err := pool.initializeConnections() + assert.NoError(t, err) + + assert.Equal(t, 5, len(pool.connections)) + connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 5) + }) + + t.Run("fail initialization of connections", func(t *testing.T) { + + var c *amqpConnMock = nil + failingDialer := &amqpDialMock{} + failingDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error")) + + failingConnection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: failingDialer, + } + + conn := &amqpConnMock{} + successfulDialer := &amqpDialMock{} + successfulDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil) + + successfulConnection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: successfulDialer, + } + + connectionProvider := &connectionProviderMock{} + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(4) + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(failingConnection) + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connectionProvider: connectionProvider, + } + + err := pool.initializeConnections() + assert.EqualError(t, err, "new connection: failed to connect to amqp broker: internal connection connect: dial: test error") + + assert.Equal(t, 4, len(pool.connections)) + connectionProvider.AssertNumberOfCalls(t, "NewAmqpConnection", 5) + }) +} + +func Test_AmqpConnectionPool_Close(t *testing.T) { + + t.Run("close connection successfully", func(t *testing.T) { + // add 5 connections to the pool + conn := &amqpConnMock{} + conn.On("Close").Return(nil) + + dialer := &amqpDialMock{} + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil) + + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: dialer, + } + + connectionProvider := &connectionProviderMock{} + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection) + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connectionProvider: connectionProvider, + } + + err := pool.initializeConnections() + assert.NoError(t, err) + + assert.Equal(t, 5, len(pool.connections)) + + // close the pool + err = pool.Close() + assert.NoError(t, err) + assert.Equal(t, 0, len(pool.connections)) + }) + + t.Run("close connection fail", func(t *testing.T) { + // add 5 connections to the pool + failingConn := &amqpConnMock{} + failingConn.On("Close").Return(errors.New("test error")) + + failingDialer := &amqpDialMock{} + failingDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(failingConn, nil) + + failingConnection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: failingDialer, + } + + successfulConn := &amqpConnMock{} + successfulConn.On("Close").Return(nil) + successfulDialer := &amqpDialMock{} + successfulDialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(successfulConn, nil) + + successfulConnection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: successfulDialer, + } + + connectionProvider := &connectionProviderMock{} + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(2) + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(failingConnection).Times(2) + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(successfulConnection).Times(1) + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connectionProvider: connectionProvider, + } + + err := pool.initializeConnections() + assert.NoError(t, err) + + assert.Equal(t, 5, len(pool.connections)) + + // close the pool + err = pool.Close() + assert.EqualError(t, err, "connection: close: close: internal connection close: test error\nconnection: close: close: internal connection close: test error") + assert.Equal(t, 0, len(pool.connections)) + }) +} + +func Test_AmqpConnectionPool_nextConnectionForHandle(t *testing.T) { + channelReceiver := func(channel chan struct{}) <-chan struct{} { + return channel + } + + newActiveConnection := func() *AmqpConnection { + channel := make(chan struct{}) + conn := &amqpConnMock{} + conn.On("Done", mock.Anything).Return(channelReceiver(channel)) + return &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + conn: conn, + } + } + + newClosedConnection := func() *AmqpConnection { + channel := make(chan struct{}) + close(channel) + + conn := &amqpConnMock{} + conn.On("Done", mock.Anything).Return(channelReceiver(channel)) + return &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + conn: conn, + } + } + + t.Run("next connection for requested handle", func(t *testing.T) { + connections := make([]*AmqpConnection, 0) + for i := 0; i < 5; i++ { + connections = append(connections, newActiveConnection()) + } + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connections: connections, + } + + connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1}) + assert.NotNil(t, connection) + assert.False(t, addConnection) + }) + + t.Run("nil connection for requested handle", func(t *testing.T) { + connections := make([]*AmqpConnection, 0) + connections = append(connections, newActiveConnection()) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, newActiveConnection()) + connections = append(connections, newActiveConnection()) + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connections: connections, + } + + connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1}) + assert.NotNil(t, connection) + assert.True(t, addConnection) + }) + + t.Run("closed connection for requested handle", func(t *testing.T) { + connections := make([]*AmqpConnection, 0) + connections = append(connections, newActiveConnection()) + connections = append(connections, newClosedConnection()) + connections = append(connections, newClosedConnection()) + connections = append(connections, newActiveConnection()) + connections = append(connections, newActiveConnection()) + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connections: connections, + } + + connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1}) + assert.NotNil(t, connection) + assert.True(t, addConnection) + }) + + t.Run("no connection for requested handle", func(t *testing.T) { + connections := make([]*AmqpConnection, 0) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, nil) + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connections: connections, + } + + connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 1}) + assert.Nil(t, connection) + assert.True(t, addConnection) + }) + + t.Run("connection for requested handle with large index", func(t *testing.T) { + connections := make([]*AmqpConnection, 0) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, newActiveConnection()) + connections = append(connections, nil) + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connections: connections, + } + + connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 23}) + assert.NotNil(t, connection) + assert.False(t, addConnection) + }) + + t.Run("connection for requested handle nil with large index", func(t *testing.T) { + connections := make([]*AmqpConnection, 0) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, newActiveConnection()) + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connections: connections, + } + + connection, addConnection := pool.nextConnectionForHandle(&ConnectionPoolHandle{connectionOffset: 23}) + assert.NotNil(t, connection) + assert.True(t, addConnection) + }) +} + +func Test_AmqpConnectionPool_GetConnection(t *testing.T) { + channelReceiver := func(channel chan struct{}) <-chan struct{} { + return channel + } + + newActiveConnection := func() *AmqpConnection { + channel := make(chan struct{}) + conn := &amqpConnMock{} + conn.On("Done", mock.Anything).Return(channelReceiver(channel)) + return &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + conn: conn, + } + } + + t.Run("get connection for requested handle", func(t *testing.T) { + connections := make([]*AmqpConnection, 0) + for i := 0; i < 5; i++ { + connections = append(connections, newActiveConnection()) + } + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connections: connections, + } + + connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1}) + assert.NoError(t, err) + assert.NotNil(t, connection) + assert.Equal(t, connections[1], connection) + assert.Equal(t, 5, len(connections)) + }) + + t.Run("add connection if missing", func(t *testing.T) { + connections := make([]*AmqpConnection, 5) + + connectionProvider := &connectionProviderMock{} + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(newActiveConnection()) + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connections: connections, + connectionProvider: connectionProvider, + } + + connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1}) + assert.NoError(t, err) + assert.NotNil(t, connection) + assert.Equal(t, connections[1], connection) + assert.Equal(t, 5, len(connections)) + }) + + t.Run("add connection fails returns alternative connection", func(t *testing.T) { + connections := make([]*AmqpConnection, 0) + connections = append(connections, newActiveConnection()) + connections = append(connections, nil) + connections = append(connections, newActiveConnection()) + connections = append(connections, newActiveConnection()) + connections = append(connections, newActiveConnection()) + + connectionProvider := &connectionProviderMock{} + + dialer := &amqpDialMock{} + var c *amqpConnMock = nil + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, fmt.Errorf("dial error")) + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: dialer, + } + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection) + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connections: connections, + connectionProvider: connectionProvider, + } + + connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1}) + assert.NoError(t, err) + assert.NotNil(t, connection) + assert.Nil(t, connections[1]) + assert.Equal(t, connections[2], connection) + assert.Equal(t, 5, len(connections)) + }) + + t.Run("add connection fails", func(t *testing.T) { + connections := make([]*AmqpConnection, 0) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, nil) + connections = append(connections, nil) + + connectionProvider := &connectionProviderMock{} + + dialer := &amqpDialMock{} + var c *amqpConnMock = nil + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, fmt.Errorf("dial error")) + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + dialer: dialer, + } + connectionProvider.On("NewAmqpConnection", mock.Anything, mock.Anything).Return(connection) + + pool := AmqpConnectionPool{ + config: AmqpConnectionPoolConfig{PoolSize: 5}, + handleOffset: 0, + lock: sync.RWMutex{}, + connections: connections, + connectionProvider: connectionProvider, + } + + connection, err := pool.GetConnection(&ConnectionPoolHandle{connectionOffset: 1}) + assert.EqualError(t, err, "renew connection: failed to connect to amqp broker: internal connection connect: dial: dial error") + assert.Nil(t, connection) + assert.Equal(t, 5, len(connections)) + }) +} diff --git a/audit/messaging/amqp_connection_test.go b/audit/messaging/amqp_connection_test.go new file mode 100644 index 0000000..197d964 --- /dev/null +++ b/audit/messaging/amqp_connection_test.go @@ -0,0 +1,415 @@ +package messaging + +import ( + "context" + "errors" + "github.com/Azure/go-amqp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "sync" + "testing" +) + +type amqpConnMock struct { + mock.Mock +} + +func (m *amqpConnMock) Done() <-chan struct{} { + args := m.Called() + return args.Get(0).(<-chan struct{}) +} + +func (m *amqpConnMock) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error) { + args := m.Called(ctx, opts) + return args.Get(0).(amqpSession), args.Error(1) +} + +func (m *amqpConnMock) Close() error { + args := m.Called() + return args.Error(0) +} + +var _ amqpConn = (*amqpConnMock)(nil) + +type amqpDialMock struct { + mock.Mock +} + +func (m *amqpDialMock) Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error) { + args := m.Called(ctx, addr, opts) + return args.Get(0).(amqpConn), args.Error(1) +} + +var _ amqpDial = (*amqpDialMock)(nil) + +type amqpSessionMock struct { + mock.Mock +} + +func (m *amqpSessionMock) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error) { + args := m.Called(ctx, target, opts) + return args.Get(0).(amqpSender), args.Error(1) +} + +func (m *amqpSessionMock) Close(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +var _ amqpSession = (*amqpSessionMock)(nil) + +func Test_AmqpConnection_IsClosed(t *testing.T) { + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + } + + channelReceiver := func(channel chan struct{}) <-chan struct{} { + return channel + } + + t.Run("is closed - connection nil", func(t *testing.T) { + assert.True(t, connection.IsClosed()) + }) + + t.Run("is closed", func(t *testing.T) { + channel := make(chan struct{}) + close(channel) + amqpConnMock := &amqpConnMock{} + amqpConnMock.On("Done").Return(channelReceiver(channel)) + connection.conn = amqpConnMock + + assert.True(t, connection.IsClosed()) + }) + + t.Run("is not closed", func(t *testing.T) { + channel := make(chan struct{}) + amqpConnMock := &amqpConnMock{} + amqpConnMock.On("Done").Return(channelReceiver(channel)) + connection.conn = amqpConnMock + + assert.False(t, connection.IsClosed()) + }) +} + +func Test_AmqpConnection_Close(t *testing.T) { + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + } + + t.Run("already closed", func(t *testing.T) { + assert.NoError(t, connection.Close()) + }) + + t.Run("close error", func(t *testing.T) { + err := errors.New("test error") + + amqpConnMock := &amqpConnMock{} + amqpConnMock.On("Close").Return(err) + connection.conn = amqpConnMock + + assert.EqualError(t, connection.Close(), "close: internal connection close: test error") + assert.NotNil(t, connection.conn) + amqpConnMock.AssertNumberOfCalls(t, "Close", 1) + }) + + t.Run("close without error", func(t *testing.T) { + amqpConnMock := &amqpConnMock{} + amqpConnMock.On("Close").Return(nil) + connection.conn = amqpConnMock + + assert.Nil(t, connection.Close()) + assert.Nil(t, connection.conn) + amqpConnMock.AssertNumberOfCalls(t, "Close", 1) + }) +} + +func Test_AmqpConnection_Connect(t *testing.T) { + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + } + + t.Run("already connected", func(t *testing.T) { + connection.conn = &amqpConnMock{} + assert.NoError(t, connection.Connect()) + }) + + t.Run("dial error", func(t *testing.T) { + connection.conn = nil + connection.username = "user" + connection.password = "pass" + + amqpDialMock := &amqpDialMock{} + var c *amqpConnMock = nil + amqpDialMock.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(c, errors.New("test error")) + connection.dialer = amqpDialMock + + assert.EqualError(t, connection.Connect(), "internal connection connect: dial: test error") + assert.Nil(t, connection.conn) + }) + + t.Run("connect without error", func(t *testing.T) { + connection.conn = nil + + amqpDialMock := &amqpDialMock{} + amqpConn := &amqpConnMock{} + amqpDialMock.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(amqpConn, nil) + connection.dialer = amqpDialMock + + assert.NoError(t, connection.Connect()) + assert.Equal(t, amqpConn, connection.conn) + }) +} + +func Test_AmqpConnection_ResetConnection(t *testing.T) { + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + } + + t.Run("reset connection - connect on first attempt", func(t *testing.T) { + connection.conn = nil + + dialer := &amqpDialMock{} + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(&amqpConnMock{}, nil) + connection.dialer = dialer + + assert.NoError(t, connection.ResetConnection(context.Background())) + assert.NotNil(t, connection.conn) + + dialer.AssertNumberOfCalls(t, "Dial", 1) + }) + + t.Run("reset connection - connect on second attempt", func(t *testing.T) { + connection.conn = nil + + dialer := &amqpDialMock{} + var conn *amqpConnMock = nil + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, errors.New("test error")).Once() + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(&amqpConnMock{}, nil) + connection.dialer = dialer + + assert.NoError(t, connection.ResetConnection(context.Background())) + assert.NotNil(t, connection.conn) + + dialer.AssertNumberOfCalls(t, "Dial", 2) + }) + + t.Run("reset connection - fail continuously", func(t *testing.T) { + connection.conn = nil + + dialer := &amqpDialMock{} + var conn *amqpConnMock = nil + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, errors.New("test error")) + connection.dialer = dialer + + assert.EqualError(t, connection.ResetConnection(context.Background()), "connect: dial: test error") + assert.Nil(t, connection.conn) + + dialer.AssertNumberOfCalls(t, "Dial", 2) + }) +} + +type ResettableFunction struct { + mock.Mock +} + +func (f *ResettableFunction) Run(ctx context.Context) (any, error) { + args := f.Called(ctx) + return args.Get(0), args.Error(1) +} + +func Test_AmqpConnection_ResetConnectionAndRetryIfErrorWithReturnValue(t *testing.T) { + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + } + + t.Run("no error", func(t *testing.T) { + value, err := connection.ResetConnectionAndRetryIfErrorWithReturnValue("test", func(ctx context.Context) (any, error) { + return struct{}{}, nil + }) + assert.NoError(t, err) + assert.Equal(t, struct{}{}, value) + }) + + t.Run("reset - success", func(t *testing.T) { + connection.conn = nil + + dialer := &amqpDialMock{} + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(&amqpConnMock{}, nil) + connection.dialer = dialer + + resettableFunction := &ResettableFunction{} + resettableFunction.On("Run", mock.Anything).Return(struct{}{}, errors.New("test error")).Once() + resettableFunction.On("Run", mock.Anything).Return(struct{}{}, nil) + + value, err := connection.ResetConnectionAndRetryIfErrorWithReturnValue("test", resettableFunction.Run) + assert.NoError(t, err) + assert.Equal(t, struct{}{}, value) + + resettableFunction.AssertNumberOfCalls(t, "Run", 2) + }) + + t.Run("reset - fail", func(t *testing.T) { + connection.conn = nil + + dialer := &amqpDialMock{} + var conn *amqpConnMock = nil + dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, errors.New("test error")) + connection.dialer = dialer + + resettableFunction := &ResettableFunction{} + resettableFunction.On("Run", mock.Anything).Return(struct{}{}, errors.New("test error")) + + value, err := connection.ResetConnectionAndRetryIfErrorWithReturnValue("test", resettableFunction.Run) + assert.EqualError(t, err, "reset connection: connect: dial: test error") + assert.Nil(t, value) + + resettableFunction.AssertNumberOfCalls(t, "Run", 1) + }) +} + +func Test_AmqpConnection_NewSender(t *testing.T) { + connection := &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + } + + channelReceiver := func(channel chan struct{}) <-chan struct{} { + return channel + } + + t.Run("connection not initialized", func(t *testing.T) { + sender, err := connection.NewSender(context.Background(), "topic") + assert.EqualError(t, err, "connection is not initialized") + assert.Nil(t, sender) + }) + + t.Run("connection is closed", func(t *testing.T) { + channel := make(chan struct{}) + close(channel) + + conn := &amqpConnMock{} + conn.On("Done").Return(channelReceiver(channel)) + connection.conn = conn + + sender, err := connection.NewSender(context.Background(), "topic") + assert.EqualError(t, err, "amqp connection is closed") + assert.Nil(t, sender) + }) + + t.Run("session error", func(t *testing.T) { + channel := make(chan struct{}) + + var session *amqpSessionMock = nil + conn := &amqpConnMock{} + conn.On("NewSession", mock.Anything, mock.Anything).Return(session, errors.New("test error")) + conn.On("Done").Return(channelReceiver(channel)) + connection.conn = conn + + sender, err := connection.NewSender(context.Background(), "topic") + assert.EqualError(t, err, "new session: test error") + assert.Nil(t, sender) + }) + + t.Run("sender error", func(t *testing.T) { + channel := make(chan struct{}) + + sessionMock := &amqpSessionMock{} + var amqpSender *amqp.Sender = nil + sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, errors.New("test error")) + sessionMock.On("Close", mock.Anything).Return(nil) + + conn := &amqpConnMock{} + conn.On("Done").Return(channelReceiver(channel)) + conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil) + connection.conn = conn + + sender, err := connection.NewSender(context.Background(), "topic") + assert.EqualError(t, err, "new internal sender: test error") + assert.Nil(t, sender) + }) + + t.Run("session close error", func(t *testing.T) { + channel := make(chan struct{}) + + sessionMock := &amqpSessionMock{} + var amqpSender *amqp.Sender = nil + sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, errors.New("test error")) + sessionMock.On("Close", mock.Anything).Return(errors.New("close error")) + + conn := &amqpConnMock{} + conn.On("Done").Return(channelReceiver(channel)) + conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil) + connection.conn = conn + + sender, err := connection.NewSender(context.Background(), "topic") + assert.EqualError(t, err, "new internal sender: test error\nclose session: close error") + assert.Nil(t, sender) + }) + + t.Run("get sender", func(t *testing.T) { + channel := make(chan struct{}) + + amqpSender := &amqp.Sender{} + sessionMock := &amqpSessionMock{} + sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(amqpSender, nil) + + conn := &amqpConnMock{} + conn.On("Done").Return(channelReceiver(channel)) + conn.On("NewSession", mock.Anything, mock.Anything).Return(sessionMock, nil) + connection.conn = conn + + sender, err := connection.NewSender(context.Background(), "topic") + assert.NoError(t, err) + assert.NotNil(t, sender) + assert.Equal(t, amqpSender, sender.sender) + assert.Equal(t, sessionMock, sender.session) + }) +} + +func Test_AmqpConnection_NewAmqpConnection(t *testing.T) { + config := AmqpConnectionConfig{ + BrokerUrl: "brokerUrl", + Username: "username", + Password: "password", + } + connection := NewAmqpConnection(&config, "connectionName") + assert.NotNil(t, connection) + assert.Equal(t, connection.connectionName, "connectionName") + assert.Equal(t, connection.brokerUrl, "brokerUrl") + assert.Equal(t, connection.username, "username") + assert.Equal(t, connection.password, "password") + assert.NotNil(t, connection.dialer) +} + +func Test_As(t *testing.T) { + + t.Run("error", func(t *testing.T) { + value, err := As[amqp.Message](nil, errors.New("test error")) + assert.EqualError(t, err, "test error") + assert.Nil(t, value) + }) + + t.Run("value nil", func(t *testing.T) { + value, err := As[amqp.Message](nil, nil) + assert.NoError(t, err) + assert.Nil(t, value) + }) + + t.Run("value not not type", func(t *testing.T) { + value, err := As[amqp.Message](struct{}{}, nil) + assert.EqualError(t, err, "could not cast value: struct {}") + assert.Nil(t, value) + }) + + t.Run("cast", func(t *testing.T) { + var sessionAny any = &amqpSessionMock{} + value, err := As[amqpSessionMock](sessionAny, nil) + assert.NoError(t, err) + assert.NotNil(t, value) + }) +} diff --git a/audit/messaging/amqp_sender_session.go b/audit/messaging/amqp_sender_session.go new file mode 100644 index 0000000..92d3e61 --- /dev/null +++ b/audit/messaging/amqp_sender_session.go @@ -0,0 +1,86 @@ +package messaging + +import ( + "context" + "errors" + "fmt" + "github.com/Azure/go-amqp" + "log/slog" + "strings" + "time" +) + +type amqpSender interface { + Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error + Close(ctx context.Context) error +} + +type AmqpSenderSession struct { + session amqpSession + sender amqpSender +} + +func (s *AmqpSenderSession) Send( + topic string, + data [][]byte, + contentType string, + applicationProperties map[string]any, +) error { + // check topic name + if !strings.HasPrefix(topic, AmqpTopicPrefix) { + return fmt.Errorf( + "topic %q name lacks mandatory prefix %q", + topic, + AmqpTopicPrefix, + ) + } + + if contentType == "" { + return errors.New("content-type is required") + } + + // prepare the amqp message + message := amqp.Message{ + Header: &amqp.MessageHeader{ + Durable: true, + }, + Properties: &amqp.MessageProperties{ + To: &topic, + ContentType: &contentType, + }, + ApplicationProperties: applicationProperties, + Data: data, + } + + // send + ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelFn() + return s.sender.Send(ctx, &message, nil) +} + +func (s *AmqpSenderSession) Close() error { + ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelFn() + + var closeErrors []error + senderErr := s.sender.Close(ctx) + if senderErr != nil { + closeErrors = append(closeErrors, senderErr) + } + sessionErr := s.session.Close(ctx) + if sessionErr != nil { + closeErrors = append(closeErrors, sessionErr) + } + + if len(closeErrors) > 0 { + return errors.Join(closeErrors...) + } + return nil +} + +func (s *AmqpSenderSession) CloseSilently() { + err := s.Close() + if err != nil { + slog.Error("error closing sender session", slog.Any("err", err)) + } +} diff --git a/audit/messaging/amqp_sender_session_test.go b/audit/messaging/amqp_sender_session_test.go new file mode 100644 index 0000000..9d0409f --- /dev/null +++ b/audit/messaging/amqp_sender_session_test.go @@ -0,0 +1,186 @@ +package messaging + +import ( + "context" + "errors" + "github.com/Azure/go-amqp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "testing" +) + +type amqpSenderMock struct { + mock.Mock +} + +func (m *amqpSenderMock) Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error { + return m.Called(ctx, msg, opts).Error(0) +} + +func (m *amqpSenderMock) Close(ctx context.Context) error { + return m.Called(ctx).Error(0) +} + +var _ amqpSender = (*amqpSenderMock)(nil) + +func Test_AmqpSenderSession_Close(t *testing.T) { + + t.Run("close without errors", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Close", mock.Anything).Return(nil) + session := &amqpSessionMock{} + session.On("Close", mock.Anything).Return(nil) + + senderSession := &AmqpSenderSession{ + sender: sender, + session: session, + } + + err := senderSession.Close() + assert.NoError(t, err) + + sender.AssertNumberOfCalls(t, "Close", 1) + session.AssertNumberOfCalls(t, "Close", 1) + }) + + t.Run("close with sender error", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Close", mock.Anything).Return(errors.New("sender error")) + session := &amqpSessionMock{} + session.On("Close", mock.Anything).Return(nil) + + senderSession := &AmqpSenderSession{ + sender: sender, + session: session, + } + + err := senderSession.Close() + assert.EqualError(t, err, "sender error") + + sender.AssertNumberOfCalls(t, "Close", 1) + session.AssertNumberOfCalls(t, "Close", 1) + }) + + t.Run("close with session error", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Close", mock.Anything).Return(nil) + session := &amqpSessionMock{} + session.On("Close", mock.Anything).Return(errors.New("session error")) + + senderSession := &AmqpSenderSession{ + sender: sender, + session: session, + } + + err := senderSession.Close() + assert.EqualError(t, err, "session error") + + sender.AssertNumberOfCalls(t, "Close", 1) + session.AssertNumberOfCalls(t, "Close", 1) + }) + + t.Run("close with sender and session error", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Close", mock.Anything).Return(errors.New("sender error")) + session := &amqpSessionMock{} + session.On("Close", mock.Anything).Return(errors.New("session error")) + + senderSession := &AmqpSenderSession{ + sender: sender, + session: session, + } + + err := senderSession.Close() + assert.EqualError(t, err, "sender error\nsession error") + + sender.AssertNumberOfCalls(t, "Close", 1) + session.AssertNumberOfCalls(t, "Close", 1) + }) +} + +func Test_AmqpSenderSession_Send(t *testing.T) { + + t.Run("invalid topic name", func(t *testing.T) { + sender := &amqpSenderMock{} + session := &amqpSessionMock{} + + senderSession := &AmqpSenderSession{ + sender: sender, + session: session, + } + + data := [][]byte{[]byte("data")} + err := senderSession.Send("invalid", data, "application/json", map[string]interface{}{}) + assert.EqualError(t, err, "topic \"invalid\" name lacks mandatory prefix \"topic://\"") + }) + + t.Run("content type missing", func(t *testing.T) { + sender := &amqpSenderMock{} + session := &amqpSessionMock{} + + senderSession := &AmqpSenderSession{ + sender: sender, + session: session, + } + + data := [][]byte{[]byte("data")} + err := senderSession.Send("topic://some/name", data, "", map[string]interface{}{}) + assert.EqualError(t, err, "content-type is required") + }) + + t.Run("send", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil) + session := &amqpSessionMock{} + + senderSession := &AmqpSenderSession{ + sender: sender, + session: session, + } + + data := [][]byte{[]byte("data")} + applicationProperties := map[string]interface{}{} + applicationProperties["key"] = "value" + err := senderSession.Send("topic://some/name", data, "application/json", applicationProperties) + assert.NoError(t, err) + + sender.AssertNumberOfCalls(t, "Send", 1) + calls := sender.Calls + assert.Equal(t, 1, len(calls)) + + ctx, isCtx := calls[0].Arguments[0].(context.Context) + assert.True(t, isCtx) + assert.NotNil(t, ctx) + + message, isMsg := calls[0].Arguments[1].(*amqp.Message) + assert.True(t, isMsg) + assert.True(t, message.Header.Durable) + assert.Equal(t, "topic://some/name", *message.Properties.To) + assert.Equal(t, "application/json", *message.Properties.ContentType) + assert.Equal(t, applicationProperties, message.ApplicationProperties) + assert.Equal(t, data, message.Data) + + senderOptions, isSenderOptions := calls[0].Arguments[2].(*amqp.SendOptions) + assert.True(t, isSenderOptions) + assert.Nil(t, senderOptions) + }) + + t.Run("send fails", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("send fail")) + session := &amqpSessionMock{} + + senderSession := &AmqpSenderSession{ + sender: sender, + session: session, + } + + data := [][]byte{[]byte("data")} + applicationProperties := map[string]interface{}{} + applicationProperties["key"] = "value" + + err := senderSession.Send("topic://some/name", data, "application/json", applicationProperties) + assert.EqualError(t, err, "send fail") + sender.AssertNumberOfCalls(t, "Send", 1) + }) +} diff --git a/audit/messaging/messaging.go b/audit/messaging/messaging.go index d94b349..d86c66a 100644 --- a/audit/messaging/messaging.go +++ b/audit/messaging/messaging.go @@ -2,19 +2,13 @@ package messaging import ( "context" + "dev.azure.com/schwarzit/schwarzit.stackit-public/audit-go.git/log" "errors" "fmt" - "strings" "sync" "time" - - "dev.azure.com/schwarzit/schwarzit.stackit-public/audit-go.git/log" - "github.com/Azure/go-amqp" ) -// Default connection timeout for the AMQP connection -const connectionTimeoutSeconds = 10 - // Api is an abstraction for a messaging system that can be used to send // audit logs to the audit log system. type Api interface { @@ -38,203 +32,99 @@ type Api interface { Close(ctx context.Context) error } -// MutexApi is wrapper around an API implementation that controls mutual exclusive access to the api. -type MutexApi struct { - mutex sync.Mutex - api Api -} - -var _ Api = &MutexApi{} - -func NewMutexApi(api Api) (Api, error) { - if api == nil { - return nil, errors.New("api is nil") - } - mutexApi := MutexApi{ - mutex: sync.Mutex{}, - api: api, - } - - var genericApi Api = &mutexApi - return genericApi, nil -} - -// Send implements Api.Send -func (m *MutexApi) Send(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error { - m.mutex.Lock() - defer m.mutex.Unlock() - return m.api.Send(ctx, topic, data, contentType, applicationProperties) -} - -func (m *MutexApi) Close(ctx context.Context) error { - m.mutex.Lock() - defer m.mutex.Unlock() - return m.api.Close(ctx) -} - -// AmqpConfig provides AMQP connection related parameters. -type AmqpConfig struct { - URL string - User string - Password string -} - -// AmqpSession is an abstraction providing a subset of the methods of amqp.Session -type AmqpSession interface { - NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error) - Close(ctx context.Context) error -} - -type AmqpSessionWrapper struct { - session *amqp.Session -} - -func (w AmqpSessionWrapper) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error) { - return w.session.NewSender(ctx, target, opts) -} - -func (w AmqpSessionWrapper) Close(ctx context.Context) error { - return w.session.Close(ctx) -} - -// AmqpSender is an abstraction providing a subset of the methods of amqp.Sender -type AmqpSender interface { - Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error - Close(ctx context.Context) error -} - // AmqpApi implements Api. type AmqpApi struct { - config AmqpConfig - connection *amqp.Conn - session AmqpSession + config AmqpConnectionPoolConfig + connection *AmqpConnection + connectionPool ConnectionPool + connectionPoolHandle *ConnectionPoolHandle + senderCache map[string]*AmqpSenderSession + lock sync.RWMutex } var _ Api = &AmqpApi{} -func NewAmqpApi(amqpConfig AmqpConfig) (Api, error) { - amqpApi := &AmqpApi{config: amqpConfig} - - if err := amqpApi.connect(); err != nil { - return nil, fmt.Errorf("connect to broker: %w", err) - } - - return amqpApi, nil -} - -// connect opens a new connection and session to the AMQP messaging system. -// The connection attempt will be cancelled after connectionTimeoutSeconds. -func (a *AmqpApi) connect() error { - log.AuditLogger.Info("connecting to audit messaging system") - - // Set credentials if specified - auth := amqp.SASLTypeAnonymous() - - if a.config.User != "" && a.config.Password != "" { - auth = amqp.SASLTypePlain(a.config.User, a.config.Password) - log.AuditLogger.Info("using username and password for messaging") - } else { - log.AuditLogger.Warn("using anonymous messaging!") - } - - options := &amqp.ConnOptions{ - SASLType: auth, - } - - // Create new context with timeout for the connection initialization - subCtx, cancel := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second) - defer cancel() - - // Initialize connection - conn, err := amqp.Dial(subCtx, a.config.URL, options) +func NewAmqpApi(amqpConfig AmqpConnectionPoolConfig) (Api, error) { + connectionPool, err := NewAmqpConnectionPool(amqpConfig, "sdk") if err != nil { - return fmt.Errorf("dial connection to broker: %w", err) - } - a.connection = conn - - // Initialize session - session, err := conn.NewSession(context.Background(), nil) - if err != nil { - return fmt.Errorf("create session: %w", err) + return nil, fmt.Errorf("new amqp connection pool: %w", err) } - var amqpSession AmqpSession = &AmqpSessionWrapper{session: session} - a.session = amqpSession + amqpApi := &AmqpApi{config: amqpConfig, + connectionPool: connectionPool, + connectionPoolHandle: connectionPool.NewHandle(), + senderCache: make(map[string]*AmqpSenderSession), + } - return nil + var messagingApi Api = amqpApi + return messagingApi, nil } // Send implements Api.Send. // If errors occur the connection to the messaging system will be closed and re-established. func (a *AmqpApi) Send(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error { - err := a.trySend(ctx, topic, data, contentType, applicationProperties) - if err == nil { - return nil - } - // Drop the current sender, as it cannot connect to the broker anymore - log.AuditLogger.Error("message sender error, recreating", err) - - err = a.resetConnection(ctx) - if err != nil { - return fmt.Errorf("reset connection: %w", err) - } - - return a.trySend(ctx, topic, data, contentType, applicationProperties) -} - -// trySend actually sends the given data as amqp.Message to the messaging system. -func (a *AmqpApi) trySend(ctx context.Context, topic string, data []byte, contentType string, applicationProperties map[string]any) error { - if !strings.HasPrefix(topic, AmqpTopicPrefix) { - return fmt.Errorf( - "topic %q name lacks mandatory prefix %q", - topic, - AmqpTopicPrefix, - ) - } - - sender, err := a.session.NewSender(ctx, topic, nil) - if err != nil { - return fmt.Errorf("new sender: %w", err) - } - defer func() { - if err := sender.Close(ctx); err != nil { - log.AuditLogger.Error("failed to close session sender", err) + a.lock.RLock() + connectionIsClosed := a.connection == nil || a.connection.IsClosed() + a.lock.RUnlock() + if connectionIsClosed { + connection, err := a.connectionPool.GetConnection(a.connectionPoolHandle) + if err != nil { + return fmt.Errorf("get connection: %w", err) } - }() - - bytes := [][]byte{data} - message := amqp.Message{ - Header: &amqp.MessageHeader{ - Durable: true, - }, - Properties: &amqp.MessageProperties{ - To: &topic, - ContentType: &contentType, - }, - ApplicationProperties: applicationProperties, - Data: bytes, + a.lock.Lock() + a.connection = connection + a.lock.Unlock() } - err = sender.Send(ctx, &message, nil) - if err != nil { - return fmt.Errorf("send message: %w", err) + a.lock.RLock() + var sender = a.senderCache[topic] + a.lock.RUnlock() + if sender == nil { + a.lock.RLock() + ctx, cancelFn := context.WithTimeout(ctx, 10*time.Second) + senderSession, err := a.connection.NewSender(ctx, topic) + cancelFn() + a.lock.RUnlock() + if err != nil { + return fmt.Errorf("new sender: %w", err) + } + a.lock.Lock() + a.senderCache[topic] = senderSession + a.lock.Unlock() + sender = senderSession } + wrappedData := [][]byte{data} + if err := sender.Send(topic, wrappedData, contentType, applicationProperties); err != nil { + return fmt.Errorf("send: %w", err) + } return nil } -// resetConnection closes the current session and connection and reconnects to the messaging system. -func (a *AmqpApi) resetConnection(ctx context.Context) error { - if err := a.Close(ctx); err != nil { - log.AuditLogger.Error("failed to close audit messaging connection", err) +// Close implements Api.Close +func (a *AmqpApi) Close(_ context.Context) error { + log.AuditLogger.Info("close audit amqp connection pool") + + a.lock.Lock() + defer a.lock.Unlock() + + // cached senders + var closeErrors []error + for _, session := range a.senderCache { + if err := session.Close(); err != nil { + closeErrors = append(closeErrors, fmt.Errorf("close session: %w", err)) + } + } + clear(a.senderCache) + + // pool + if err := a.connectionPool.Close(); err != nil { + closeErrors = append(closeErrors, fmt.Errorf("close pool: %w", err)) } - return a.connect() -} - -// Close implements Api.Close -func (a *AmqpApi) Close(ctx context.Context) error { - log.AuditLogger.Info("close audit messaging connection") - return errors.Join(a.session.Close(ctx), a.connection.Close()) + if len(closeErrors) > 0 { + return fmt.Errorf("close: %w", errors.Join(closeErrors...)) + } + return nil } diff --git a/audit/messaging/messaging_test.go b/audit/messaging/messaging_test.go index d6b1a5b..c009a89 100644 --- a/audit/messaging/messaging_test.go +++ b/audit/messaging/messaging_test.go @@ -4,50 +4,38 @@ import ( "context" "errors" "fmt" - "testing" - "time" - - "github.com/Azure/go-amqp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "sync" + "testing" + "time" ) -type AmqpSessionMock struct { +type connectionPoolMock struct { mock.Mock } -func (m *AmqpSessionMock) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (AmqpSender, error) { - args := m.Called(ctx, target, opts) - var sender AmqpSender = nil - if args.Get(0) != nil { - sender = args.Get(0).(AmqpSender) - } - err := args.Error(1) - return sender, err +func (m *connectionPoolMock) Close() error { + return m.Called().Error(0) } -func (m *AmqpSessionMock) Close(ctx context.Context) error { - args := m.Called(ctx) - return args.Error(0) +func (m *connectionPoolMock) NewHandle() *ConnectionPoolHandle { + return m.Called().Get(0).(*ConnectionPoolHandle) } -type AmqpSenderMock struct { - mock.Mock +func (m *connectionPoolMock) GetConnection(handle *ConnectionPoolHandle) (*AmqpConnection, error) { + return m.Called(handle).Get(0).(*AmqpConnection), m.Called(handle).Error(1) } -func (m *AmqpSenderMock) Send(ctx context.Context, msg *amqp.Message, opts *amqp.SendOptions) error { - args := m.Called(ctx, msg, opts) - return args.Error(0) -} - -func (m *AmqpSenderMock) Close(ctx context.Context) error { - args := m.Called(ctx) - return args.Error(0) -} +var _ ConnectionPool = (*connectionPoolMock)(nil) func Test_NewAmqpMessagingApi(t *testing.T) { - _, err := NewAmqpApi(AmqpConfig{URL: "not-handled-protocol://localhost:5672"}) - assert.EqualError(t, err, "connect to broker: dial connection to broker: unsupported scheme \"not-handled-protocol\"") + _, err := NewAmqpApi( + AmqpConnectionPoolConfig{ + Parameters: AmqpConnectionConfig{BrokerUrl: "not-handled-protocol://localhost:5672"}, + PoolSize: 1, + }) + assert.EqualError(t, err, "new amqp connection pool: initialize connections: new connection: failed to connect to amqp broker: internal connection connect: dial: unsupported scheme \"not-handled-protocol\"") } func Test_AmqpMessagingApi_Send(t *testing.T) { @@ -63,121 +51,354 @@ func Test_AmqpMessagingApi_Send(t *testing.T) { t.Run("Missing topic prefix", func(t *testing.T) { defer solaceContainer.StopOnError() - api, err := NewAmqpApi(AmqpConfig{URL: solaceContainer.AmqpConnectionString}) + api, err := NewAmqpApi(AmqpConnectionPoolConfig{ + Parameters: AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString}, + PoolSize: 1, + }) assert.NoError(t, err) err = api.Send(ctx, "topic-name", []byte{}, "application/json", make(map[string]any)) - assert.EqualError(t, err, "topic \"topic-name\" name lacks mandatory prefix \"topic://\"") + assert.EqualError(t, err, "send: topic \"topic-name\" name lacks mandatory prefix \"topic://\"") }) - t.Run("Close connection without errors", func(t *testing.T) { + t.Run("send successfully", func(t *testing.T) { defer solaceContainer.StopOnError() // Initialize the solace queue topicSubscriptionTopicPattern := "auditlog/>" - queueName := "close-connection-without-error" + queueName := "send-successfully" assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName)) assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern)) - topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-close-connection") + topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-send-successfully") assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName)) - api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}} - err := api.connect() + api, err := NewAmqpApi(AmqpConnectionPoolConfig{ + Parameters: AmqpConnectionConfig{BrokerUrl: solaceContainer.AmqpConnectionString}, + PoolSize: 1, + }) assert.NoError(t, err) + data := []byte("data") + applicationProperties := make(map[string]interface{}) + applicationProperties["key"] = "value" + + err = api.Send(ctx, topicName, data, "application/json", applicationProperties) + assert.NoError(t, err) + + message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true) + assert.NoError(t, err) + assert.Equal(t, "data", string(message.Data[0])) + assert.Equal(t, topicName, *message.Properties.To) + assert.Equal(t, "application/json", *message.Properties.ContentType) + assert.Equal(t, applicationProperties, message.ApplicationProperties) + err = api.Close(ctx) assert.NoError(t, err) }) +} - t.Run("New sender call returns error", func(t *testing.T) { - defer solaceContainer.StopOnError() +func Test_AmqpMessagingApi_Send_Special_Cases(t *testing.T) { - // Initialize the solace queue - topicSubscriptionTopicPattern := "auditlog/>" - queueName := "messaging-new-sender" - assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName)) - assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern)) - topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-no-new-sender") - assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName)) + channelReceiver := func(channel chan struct{}) <-chan struct{} { + return channel + } - api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}} - err := api.connect() + newActiveConnection := func() *AmqpConnection { + channel := make(chan struct{}) + conn := &amqpConnMock{} + conn.On("Done", mock.Anything).Return(channelReceiver(channel)) + return &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + conn: conn, + } + } + + newClosedConnection := func() *AmqpConnection { + channel := make(chan struct{}) + close(channel) + + conn := &amqpConnMock{} + conn.On("Done", mock.Anything).Return(channelReceiver(channel)) + return &AmqpConnection{ + connectionName: "test", + lock: sync.RWMutex{}, + conn: conn, + } + } + + t.Run("connection nil sender nil", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + session := &amqpSessionMock{} + session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil) + + connection := newActiveConnection() + conn := connection.conn.(*amqpConnMock) + conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil) + + pool := &connectionPoolMock{} + pool.On("GetConnection", mock.Anything).Return(connection, nil) + + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connectionPool: pool, + connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0}, + senderCache: make(map[string]*AmqpSenderSession), + } + + err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any)) assert.NoError(t, err) - expectedError := errors.New("expected error") - - // Set mock session - sessionMock := AmqpSessionMock{} - sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(nil, expectedError) - sessionMock.On("Close", mock.Anything).Return(nil) - - var amqpSession AmqpSession = &sessionMock - api.session = amqpSession - - // It's expected that the test succeeds. - // First the session is closed as it returns the expected error - // Then the retry mechanism restarts the connection and successfully sends the data - value := "test" - err = api.Send(ctx, topicName, []byte(value), "application/json", make(map[string]any)) - assert.NoError(t, err) - - // Check that the mock was called - assert.True(t, sessionMock.AssertNumberOfCalls(t, "NewSender", 1)) - assert.True(t, sessionMock.AssertNumberOfCalls(t, "Close", 1)) - - message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true) - assert.NoError(t, err) - assert.Equal(t, value, string(message.Data[0])) - assert.Equal(t, topicName, *message.Properties.To) + sender.AssertNumberOfCalls(t, "Send", 1) + session.AssertNumberOfCalls(t, "NewSender", 1) + pool.AssertNumberOfCalls(t, "GetConnection", 2) }) - t.Run("Send call on sender returns error", func(t *testing.T) { - defer solaceContainer.StopOnError() + t.Run("connection closed sender nil", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil) - // Initialize the solace queue - topicSubscriptionTopicPattern := "auditlog/>" - queueName := "messaging-sender-error" - assert.NoError(t, solaceContainer.QueueCreate(ctx, queueName)) - assert.NoError(t, solaceContainer.TopicSubscriptionCreate(ctx, queueName, topicSubscriptionTopicPattern)) - topicName := fmt.Sprintf("topic://auditlog/%s", "amqp-sender-error") - assert.NoError(t, solaceContainer.ValidateTopicName(topicSubscriptionTopicPattern, topicName)) + session := &amqpSessionMock{} + session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil) - api := &AmqpApi{config: AmqpConfig{URL: solaceContainer.AmqpConnectionString}} - err := api.connect() + connection := newActiveConnection() + conn := connection.conn.(*amqpConnMock) + conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil) + + pool := &connectionPoolMock{} + pool.On("GetConnection", mock.Anything).Return(connection, nil) + + closedConnection := newClosedConnection() + closedConnMock := closedConnection.conn.(*amqpConnMock) + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connection: closedConnection, + connectionPool: pool, + connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0}, + senderCache: make(map[string]*AmqpSenderSession), + } + + err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any)) assert.NoError(t, err) - expectedError := errors.New("expected error") + sender.AssertNumberOfCalls(t, "Send", 1) + session.AssertNumberOfCalls(t, "NewSender", 1) + pool.AssertNumberOfCalls(t, "GetConnection", 2) + closedConnMock.AssertNumberOfCalls(t, "Done", 1) + }) - // Instantiate mock sender - senderMock := AmqpSenderMock{} - senderMock.On("Send", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedError) - senderMock.On("Close", mock.Anything).Return(nil) - var amqpSender AmqpSender = &senderMock + t.Run("connection nil get connection fail", func(t *testing.T) { + var connection *AmqpConnection = nil - // Set mock session - sessionMock := AmqpSessionMock{} - sessionMock.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(&amqpSender, nil) - sessionMock.On("Close", mock.Anything).Return(nil) + pool := &connectionPoolMock{} + pool.On("GetConnection", mock.Anything).Return(connection, errors.New("connection error")) - var amqpSession AmqpSession = &sessionMock - api.session = amqpSession + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connectionPool: pool, + connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0}, + senderCache: make(map[string]*AmqpSenderSession), + } - // It's expected that the test succeeds. - // First the sender and session are closed as the sender returns the expected error - // Then the retry mechanism restarts the connection and successfully sends the data - value := "test" - err = api.Send(ctx, topicName, []byte(value), "application/json", make(map[string]any)) + err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any)) + assert.EqualError(t, err, "get connection: connection error") + + pool.AssertNumberOfCalls(t, "GetConnection", 2) + }) + + t.Run("connection active sender nil", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + session := &amqpSessionMock{} + session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, nil) + + connection := newActiveConnection() + conn := connection.conn.(*amqpConnMock) + conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil) + + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connection: connection, + senderCache: make(map[string]*AmqpSenderSession), + } + + err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any)) assert.NoError(t, err) - // Check that the mocks were called - assert.True(t, sessionMock.AssertNumberOfCalls(t, "NewSender", 1)) - assert.True(t, sessionMock.AssertNumberOfCalls(t, "Close", 1)) - assert.True(t, senderMock.AssertNumberOfCalls(t, "Send", 1)) - assert.True(t, senderMock.AssertNumberOfCalls(t, "Close", 1)) + sender.AssertNumberOfCalls(t, "Send", 1) + session.AssertNumberOfCalls(t, "NewSender", 1) + }) - message, err := solaceContainer.NextMessage(ctx, fmt.Sprintf("queue://%s", queueName), true) + t.Run("connection active new sender fail", func(t *testing.T) { + var sender *amqpSenderMock = nil + + session := &amqpSessionMock{} + session.On("NewSender", mock.Anything, mock.Anything, mock.Anything).Return(sender, errors.New("new sender error")) + session.On("Close", mock.Anything).Return(nil) + + connection := newActiveConnection() + conn := connection.conn.(*amqpConnMock) + conn.On("NewSession", mock.Anything, mock.Anything).Return(session, nil) + + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connection: connection, + senderCache: make(map[string]*AmqpSenderSession), + } + + err := amqpApi.Send(context.Background(), "topic://some-topic", []byte("data"), "application/json", make(map[string]any)) + assert.EqualError(t, err, "new sender: new internal sender: new sender error") + + session.AssertNumberOfCalls(t, "NewSender", 1) + session.AssertNumberOfCalls(t, "Close", 1) + }) + + t.Run("connection active sender set", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + topic := "topic://some-topic" + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connection: newActiveConnection(), + senderCache: map[string]*AmqpSenderSession{topic: {sender: sender}}, + } + + err := amqpApi.Send(context.Background(), topic, []byte("data"), "application/json", make(map[string]any)) assert.NoError(t, err) - assert.Equal(t, value, string(message.Data[0])) - assert.Equal(t, topicName, *message.Properties.To) + + sender.AssertNumberOfCalls(t, "Send", 1) + }) + + t.Run("send fail", func(t *testing.T) { + sender := &amqpSenderMock{} + sender.On("Send", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("send error")) + + topic := "topic://some-topic" + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connection: newActiveConnection(), + senderCache: map[string]*AmqpSenderSession{topic: {sender: sender}}, + } + + err := amqpApi.Send(context.Background(), topic, []byte("data"), "application/json", make(map[string]any)) + assert.EqualError(t, err, "send: send error") + + sender.AssertNumberOfCalls(t, "Send", 1) + }) +} + +func Test_AmqpMessagingApi_Close(t *testing.T) { + + t.Run("close without cached senders", func(t *testing.T) { + pool := &connectionPoolMock{} + pool.On("Close").Return(nil) + + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connectionPool: pool, + connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0}, + senderCache: make(map[string]*AmqpSenderSession), + } + + err := amqpApi.Close(context.Background()) + assert.NoError(t, err) + + pool.AssertNumberOfCalls(t, "Close", 1) + }) + + t.Run("close fail without cached senders", func(t *testing.T) { + pool := &connectionPoolMock{} + pool.On("Close").Return(errors.New("close error")) + + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connectionPool: pool, + connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0}, + senderCache: make(map[string]*AmqpSenderSession), + } + + err := amqpApi.Close(context.Background()) + assert.EqualError(t, err, "close: close pool: close error") + + pool.AssertNumberOfCalls(t, "Close", 1) + }) + + t.Run("close with cached senders", func(t *testing.T) { + pool := &connectionPoolMock{} + pool.On("Close").Return(nil) + + session := &amqpSessionMock{} + session.On("Close", mock.Anything).Return(nil) + sender := &amqpSenderMock{} + sender.On("Close", mock.Anything).Return(nil) + senderSession := &AmqpSenderSession{ + session: session, + sender: sender, + } + + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connectionPool: pool, + connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0}, + senderCache: map[string]*AmqpSenderSession{"key": senderSession}, + } + + err := amqpApi.Close(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(amqpApi.senderCache)) + + pool.AssertNumberOfCalls(t, "Close", 1) + session.AssertNumberOfCalls(t, "Close", 1) + sender.AssertNumberOfCalls(t, "Close", 1) + }) + + t.Run("close fail with cached senders", func(t *testing.T) { + pool := &connectionPoolMock{} + pool.On("Close").Return(nil) + + session := &amqpSessionMock{} + session.On("Close", mock.Anything).Return(nil) + sender := &amqpSenderMock{} + sender.On("Close", mock.Anything).Return(errors.New("close sender error")) + senderSession := &AmqpSenderSession{ + session: session, + sender: sender, + } + + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connectionPool: pool, + connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0}, + senderCache: map[string]*AmqpSenderSession{"key": senderSession}, + } + + err := amqpApi.Close(context.Background()) + assert.EqualError(t, err, "close: close session: close sender error") + assert.Equal(t, 0, len(amqpApi.senderCache)) + + pool.AssertNumberOfCalls(t, "Close", 1) + session.AssertNumberOfCalls(t, "Close", 1) + sender.AssertNumberOfCalls(t, "Close", 1) + }) + + t.Run("close fail", func(t *testing.T) { + pool := &connectionPoolMock{} + pool.On("Close").Return(errors.New("close pool error")) + + session := &amqpSessionMock{} + session.On("Close", mock.Anything).Return(errors.New("close session error")) + sender := &amqpSenderMock{} + sender.On("Close", mock.Anything).Return(errors.New("close sender error")) + senderSession := &AmqpSenderSession{ + session: session, + sender: sender, + } + + amqpApi := &AmqpApi{config: AmqpConnectionPoolConfig{}, + connectionPool: pool, + connectionPoolHandle: &ConnectionPoolHandle{connectionOffset: 0}, + senderCache: map[string]*AmqpSenderSession{"key": senderSession}, + } + + err := amqpApi.Close(context.Background()) + assert.EqualError(t, err, "close: close session: close sender error\nclose session error\nclose pool: close pool error") + assert.Equal(t, 0, len(amqpApi.senderCache)) + + pool.AssertNumberOfCalls(t, "Close", 1) + session.AssertNumberOfCalls(t, "Close", 1) + sender.AssertNumberOfCalls(t, "Close", 1) }) } diff --git a/audit/messaging/solace.go b/audit/messaging/solace.go index b9b6fe8..36e5922 100644 --- a/audit/messaging/solace.go +++ b/audit/messaging/solace.go @@ -18,7 +18,6 @@ import ( ) const ( - AmqpTopicPrefix = "topic://" AmqpQueuePrefix = "queue://" )