package messaging

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"sync"
	"time"

	"github.com/Azure/go-amqp"
)

// ConnectionClosedError is returned by AmqpConnection.NewSender when the
// underlying AMQP connection exists but has already been closed.
var ConnectionClosedError = errors.New("amqp connection is closed")

// AmqpConnection wraps a single AMQP connection to a broker.
//
// Connect, Close and ResetConnection take the write lock; NewSender and
// IsClosed take the read lock. The conn field is nil until a successful
// Connect/ResetConnection and is cleared again by Close/ResetConnection.
type AmqpConnection struct {
	connectionName string // used only to label log output
	lock           sync.RWMutex
	brokerUrl      string
	username       string
	password       string
	conn           amqpConn // nil when not connected
	dialer         amqpDial // creates new connections in internalConnect
}

// amqpConn is an abstraction of amqp.Conn covering the methods this file uses.
type amqpConn interface {
	NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error)
	Close() error
	Done() <-chan struct{}
}

// defaultAmqpConn adapts a *amqp.Conn to the amqpConn interface.
type defaultAmqpConn struct {
	conn *amqp.Conn
}

// newDefaultAmqpConn wraps conn in a defaultAmqpConn.
func newDefaultAmqpConn(conn *amqp.Conn) *defaultAmqpConn {
	return &defaultAmqpConn{
		conn: conn,
	}
}

// NewSession opens a session on the wrapped connection and wraps it as an
// amqpSession. On error it returns an untyped nil session.
func (d defaultAmqpConn) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpSession, error) {
	session, err := d.conn.NewSession(ctx, opts)
	if err != nil {
		return nil, err
	}
	return newDefaultAmqpSession(session), nil
}

// Close closes the wrapped connection.
func (d defaultAmqpConn) Close() error {
	return d.conn.Close()
}

// Done returns the wrapped connection's Done channel; internalIsClosed treats
// a closed channel as "connection closed".
func (d defaultAmqpConn) Done() <-chan struct{} {
	return d.conn.Done()
}

var _ amqpConn = (*defaultAmqpConn)(nil)

// amqpDial creates new amqpConn instances.
type amqpDial interface {
	Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error)
}

// amqpSession is an abstraction of amqp.Session covering the methods this file uses.
type amqpSession interface {
	NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error)
	Close(ctx context.Context) error
}

// defaultAmqpSession adapts a *amqp.Session to the amqpSession interface.
type defaultAmqpSession struct {
	session *amqp.Session
}

// newDefaultAmqpSession wraps session in a defaultAmqpSession.
func newDefaultAmqpSession(session *amqp.Session) *defaultAmqpSession {
	return &defaultAmqpSession{
		session: session,
	}
}

// NewSender opens a sender link to target on the wrapped session.
func (s *defaultAmqpSession) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpSender, error) {
	sender, err := s.session.NewSender(ctx, target, opts)
	if err != nil {
		// Return an untyped nil: returning the *amqp.Sender directly would
		// yield a non-nil interface wrapping a nil pointer.
		return nil, err
	}
	return sender, nil
}

// Close closes the wrapped session.
func (s *defaultAmqpSession) Close(ctx context.Context) error {
	return s.session.Close(ctx)
}

var _ amqpSession = (*defaultAmqpSession)(nil)

// defaultAmqpDialer dials real connections via amqp.Dial.
type defaultAmqpDialer struct{}

// Dial connects to addr with opts and wraps the result as an amqpConn.
func (d *defaultAmqpDialer) Dial(ctx context.Context, addr string, opts *amqp.ConnOptions) (amqpConn, error) {
	dial, err := amqp.Dial(ctx, addr, opts)
	if err != nil {
		return nil, err
	}
	return newDefaultAmqpConn(dial), nil
}

var _ amqpDial = (*defaultAmqpDialer)(nil)

// NewAmqpConnection builds an unconnected AmqpConnection from config.
// Call Connect before NewSender. connectionName is only used in log output.
func NewAmqpConnection(config *AmqpConnectionConfig, connectionName string) *AmqpConnection {
	return &AmqpConnection{
		connectionName: connectionName,
		lock:           sync.RWMutex{},
		brokerUrl:      config.BrokerUrl,
		username:       config.Username,
		password:       config.Password,
		dialer:         &defaultAmqpDialer{},
	}
}

// NewSender opens a new session on the current connection and a sender for
// topic on that session.
//
// It returns an error if the connection has not been established, and
// ConnectionClosedError if it has been closed. If the sender cannot be
// created, the just-opened session is closed again and any close error is
// joined to the returned error.
func (c *AmqpConnection) NewSender(ctx context.Context, topic string) (*AmqpSenderSession, error) {
	// Take the read lock before touching c.conn: Connect, Close and
	// ResetConnection replace it under the write lock, so reading it
	// unlocked is a data race.
	c.lock.RLock()
	defer c.lock.RUnlock()

	if c.conn == nil {
		return nil, errors.New("connection is not initialized")
	}
	if c.internalIsClosed() {
		return nil, ConnectionClosedError
	}

	// new session
	newSession, err := c.conn.NewSession(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("new session: %w", err)
	}

	// new sender
	newSender, err := newSession.NewSender(ctx, topic, nil)
	if err != nil {
		err = fmt.Errorf("new internal sender: %w", err)
		if closeErr := newSession.Close(ctx); closeErr != nil {
			return nil, errors.Join(err, fmt.Errorf("close session: %w", closeErr))
		}
		return nil, err
	}

	return &AmqpSenderSession{newSession, newSender}, nil
}

// ResetConnectionAndRetryIfErrorWithReturnValue runs fn with a 10-second
// timeout. If fn fails, the error is logged, the connection is reset via
// ResetConnection, and fn is run once more with a fresh 10-second timeout.
// The retry's result is returned as-is. If the reset itself fails, the reset
// error is returned and fn is not retried.
func (c *AmqpConnection) ResetConnectionAndRetryIfErrorWithReturnValue(opName string, fn func(ctx context.Context) (any, error)) (any, error) {
	ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second)
	result, err := fn(ctx)
	cancelFn()
	if err == nil {
		return result, nil
	}

	slog.Info(fmt.Sprintf("amqp connection: %s", opName), slog.Any("connection", c.connectionName), slog.Any("err", err))

	if resetErr := c.ResetConnection(context.Background()); resetErr != nil {
		return nil, fmt.Errorf("reset connection: %w", resetErr)
	}

	retryCtx, retryCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer retryCancel()
	// Retry
	return fn(retryCtx)
}

// ResetConnection closes the current connection (logging, not returning, any
// close error) and dials a new one. A failed dial is logged and retried once;
// if the second attempt also fails, its error is returned. Each dial attempt
// gets its own connectionTimeoutSeconds timeout derived from ctx.
func (c *AmqpConnection) ResetConnection(ctx context.Context) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if err := c.internalClose(); err != nil {
		slog.Warn("amqp connection: reset: failed to close amqp connection", slog.Any("err", err))
	}

	subCtx, cancel := context.WithTimeout(ctx, connectionTimeoutSeconds*time.Second)
	err := c.internalConnect(subCtx)
	cancel()
	if err == nil {
		return nil
	}

	slog.Warn("amqp connection: reset: failed to connect to amqp server, retry..", slog.Any("err", err))
	subCtx, cancel = context.WithTimeout(ctx, connectionTimeoutSeconds*time.Second)
	err = c.internalConnect(subCtx)
	cancel()
	if err != nil {
		return fmt.Errorf("connect: %w", err)
	}
	return nil
}

// As type-asserts value to *T and passes err through. It is shaped to wrap
// (any, error) results such as those of
// ResetConnectionAndRetryIfErrorWithReturnValue.
//
// A non-nil err is returned unchanged. A nil value yields (nil, nil). A value
// of any other type yields an error naming the actual type.
func As[T any](value any, err error) (*T, error) {
	if err != nil {
		return nil, err
	}
	if value == nil {
		return nil, nil
	}
	castedValue, isType := value.(*T)
	if !isType {
		return nil, fmt.Errorf("could not cast value: %T", value)
	}
	return castedValue, nil
}

// Connect dials the broker if no connection is currently held, using a
// connectionTimeoutSeconds timeout. It does nothing if c.conn is already set.
func (c *AmqpConnection) Connect() error {
	c.lock.Lock()
	defer c.lock.Unlock()

	subCtx, cancel := context.WithTimeout(context.Background(), connectionTimeoutSeconds*time.Second)
	defer cancel()
	if err := c.internalConnect(subCtx); err != nil {
		return fmt.Errorf("internal connection connect: %w", err)
	}
	return nil
}

// internalConnect dials the broker and stores the connection when c.conn is
// nil. If a username and a password are both set it uses SASL PLAIN; otherwise
// it uses SASL ANONYMOUS. Caller must hold the write lock.
func (c *AmqpConnection) internalConnect(ctx context.Context) error {
	if c.conn != nil {
		return nil
	}

	// Set credentials if specified
	auth := amqp.SASLTypeAnonymous()
	if c.username != "" && c.password != "" {
		auth = amqp.SASLTypePlain(c.username, c.password)
	} else {
		slog.Debug("amqp connection: connect: using anonymous messaging")
	}

	options := &amqp.ConnOptions{
		SASLType: auth,
	}

	// Initialize connection
	conn, err := c.dialer.Dial(ctx, c.brokerUrl, options)
	if err != nil {
		return fmt.Errorf("dial: %w", err)
	}
	c.conn = conn
	return nil
}

// Close closes the current connection, if any, and clears it.
func (c *AmqpConnection) Close() error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if err := c.internalClose(); err != nil {
		return fmt.Errorf("close: %w", err)
	}
	return nil
}

// internalClose closes and clears c.conn if it is set. Caller must hold the
// write lock.
func (c *AmqpConnection) internalClose() error {
	if c.conn == nil {
		return nil
	}
	// Clear the reference even when Close fails. internalConnect only dials
	// when c.conn is nil, so keeping the failed connection would make
	// ResetConnection report success without reconnecting.
	conn := c.conn
	c.conn = nil
	if err := conn.Close(); err != nil {
		return fmt.Errorf("internal connection close: %w", err)
	}
	return nil
}

// IsClosed reports whether there is no connection or the current connection's
// Done channel has been closed.
func (c *AmqpConnection) IsClosed() bool {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.internalIsClosed()
}

// internalIsClosed is IsClosed without locking; caller must hold the lock.
// It polls the Done channel without blocking.
func (c *AmqpConnection) internalIsClosed() bool {
	if c.conn == nil {
		return true
	}
	select {
	case <-c.conn.Done():
		return true
	default:
		return false
	}
}