package messaging

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"sync"
	"time"

	"github.com/Azure/go-amqp"
)

// ConnectionClosedError is returned when an operation needs an open AMQP
// connection but the underlying connection has been closed.
var ConnectionClosedError = errors.New("amqp connection is closed")

// AmqpConnection wraps a single *amqp.Conn together with the settings needed
// to (re)dial it. All access to conn is guarded by lock: operations that use
// the connection take the read lock, operations that replace or close it take
// the write lock.
type AmqpConnection struct {
	connectionName string
	lock           sync.RWMutex
	brokerUrl      string
	username       string
	password       string
	conn           *amqp.Conn
}

// NewAmqpConnection returns an unconnected AmqpConnection configured from
// config; call Connect before opening senders or receivers. connectionName is
// only used for logging. config must be non-nil.
func NewAmqpConnection(config *AmqpConnectionConfig, connectionName string) *AmqpConnection {
	return &AmqpConnection{
		connectionName: connectionName,
		brokerUrl:      config.BrokerUrl,
		username:       config.Username,
		password:       config.Password,
	}
}

// internalCheckOpen reports whether c.conn can be used to open new sessions.
// It returns a "connection is not initialized" error when no connection has
// been dialed yet and ConnectionClosedError when the connection is closed.
// The caller must hold c.lock (read or write).
func (c *AmqpConnection) internalCheckOpen() error {
	if c.conn == nil {
		return errors.New("connection is not initialized")
	}
	if c.internalIsClosed() {
		return ConnectionClosedError
	}
	return nil
}

// NewReceiver opens a new session on the current connection and attaches a
// receiver link for source to it.
//
// If the receiver cannot be attached, the freshly created session is closed
// again; a failure to close it is joined to the returned error.
func (c *AmqpConnection) NewReceiver(ctx context.Context, source string) (*AmqpReceiverSession, error) {
	// Take the read lock before inspecting c.conn so ResetConnection/Close
	// (which hold the write lock) cannot replace or close it concurrently.
	c.lock.RLock()
	defer c.lock.RUnlock()

	if err := c.internalCheckOpen(); err != nil {
		return nil, err
	}

	session, err := c.conn.NewSession(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("new session: %w", err)
	}

	receiver, err := session.NewReceiver(ctx, source, nil)
	if err != nil {
		err = fmt.Errorf("new internal receiver: %w", err)
		if closeErr := session.Close(ctx); closeErr != nil {
			return nil, errors.Join(err, fmt.Errorf("close session: %w", closeErr))
		}
		return nil, err
	}

	return &AmqpReceiverSession{session, receiver}, nil
}

// NewSender opens a new session on the current connection and attaches a
// sender link for topic to it.
//
// If the sender cannot be attached, the freshly created session is closed
// again; a failure to close it is joined to the returned error.
func (c *AmqpConnection) NewSender(ctx context.Context, topic string) (*AmqpSenderSession, error) {
	// Take the read lock before inspecting c.conn so ResetConnection/Close
	// (which hold the write lock) cannot replace or close it concurrently.
	c.lock.RLock()
	defer c.lock.RUnlock()

	if err := c.internalCheckOpen(); err != nil {
		return nil, err
	}

	newSession, err := c.conn.NewSession(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("new session: %w", err)
	}

	newSender, err := newSession.NewSender(ctx, topic, nil)
	if err != nil {
		err = fmt.Errorf("new internal sender: %w", err)
		if closeErr := newSession.Close(ctx); closeErr != nil {
			return nil, errors.Join(err, fmt.Errorf("close session: %w", closeErr))
		}
		return nil, err
	}

	return &AmqpSenderSession{newSession, newSender}, nil
}

// ResetConnectionAndRetryIfErrorWithReturnValue runs fn with a 10 second
// timeout and returns its result on success.
//
// If fn fails, the failure is logged under opName, the connection is reset
// via ResetConnection and fn is run exactly once more with a fresh 10 second
// timeout; the outcome of that second attempt is returned as-is. If the reset
// itself fails, fn is not retried and the reset error is returned.
func (c *AmqpConnection) ResetConnectionAndRetryIfErrorWithReturnValue(opName string, fn func(ctx context.Context) (any, error)) (any, error) {
	const attemptTimeout = 10 * time.Second

	// attempt runs fn with its own deadline and always releases the context.
	attempt := func() (any, error) {
		ctx, cancel := context.WithTimeout(context.Background(), attemptTimeout)
		defer cancel()
		return fn(ctx)
	}

	result, err := attempt()
	if err == nil {
		return result, nil
	}

	slog.Info(fmt.Sprintf("amqp connection: %s", opName), slog.Any("connection", c.connectionName), slog.Any("err", err))

	if resetErr := c.ResetConnection(context.Background()); resetErr != nil {
		return nil, fmt.Errorf("reset connection: %w", resetErr)
	}

	// Retry once on the fresh connection.
	return attempt()
}

// ResetConnection closes the current connection (if any) and dials a new one
// while holding the write lock. A failure to close is only logged. Dialing is
// attempted twice before the error is returned.
func (c *AmqpConnection) ResetConnection(ctx context.Context) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	subCtx, cancel := context.WithTimeout(ctx, disconnectionTimeoutSeconds*time.Second)
	err := c.internalClose(subCtx)
	cancel()
	if err != nil {
		slog.Warn("amqp connection: reset: failed to close amqp connection", slog.Any("err", err))
	}

	subCtx, cancel = context.WithTimeout(ctx, connectionTimeoutSeconds*time.Second)
	err = c.internalConnect(subCtx)
	cancel()
	if err != nil {
		slog.Warn("amqp connection: reset: failed to connect to amqp server, retry..", slog.Any("err", err))
		subCtx, cancel = context.WithTimeout(ctx, connectionTimeoutSeconds*time.Second)
		err = c.internalConnect(subCtx)
		cancel()
		if err != nil {
			return fmt.Errorf("connect: %w", err)
		}
	}
	return nil
}

// As converts the (value, err) pair returned by
// ResetConnectionAndRetryIfErrorWithReturnValue into a typed *T.
// A non-nil err is returned unchanged, a nil value yields (nil, nil), and a
// value that is not a *T yields a "could not cast value" error.
func As[T any](value any, err error) (*T, error) {
	if err != nil {
		return nil, err
	}
	if value == nil {
		return nil, nil
	}
	castedValue, isType := value.(*T)
	if !isType {
		return nil, fmt.Errorf("could not cast value: %T", value)
	}
	return castedValue, nil
}

// Connect dials the broker if there is no usable connection yet, bounded by
// connectionTimeoutSeconds. It is a no-op while the current connection is open.
func (c *AmqpConnection) Connect(ctx context.Context) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	subCtx, cancel := context.WithTimeout(ctx, connectionTimeoutSeconds*time.Second)
	defer cancel()
	return c.internalConnect(subCtx)
}

// internalConnect dials a new connection unless the current one is still open.
// It uses SASL PLAIN when both username and password are set, and SASL
// ANONYMOUS otherwise. The caller must hold the write lock.
func (c *AmqpConnection) internalConnect(ctx context.Context) error {
	// Only an open connection may be kept. A closed one (explicitly via Close
	// or ResetConnection, or dropped by the broker) must be replaced; checking
	// only for nil would make every reconnect after a close a silent no-op.
	if !c.internalIsClosed() {
		return nil
	}

	auth := amqp.SASLTypeAnonymous()
	if c.username != "" && c.password != "" {
		auth = amqp.SASLTypePlain(c.username, c.password)
	} else {
		slog.Debug("amqp connection: connect: using anonymous messaging")
	}

	conn, err := amqp.Dial(ctx, c.brokerUrl, &amqp.ConnOptions{SASLType: auth})
	if err != nil {
		// Keep the previous (closed) connection so callers keep seeing
		// ConnectionClosedError rather than "not initialized".
		return fmt.Errorf("dial: %w", err)
	}
	c.conn = conn
	return nil
}

// Close closes the current connection, if any, while holding the write lock.
func (c *AmqpConnection) Close(ctx context.Context) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	subCtx, cancel := context.WithTimeout(ctx, disconnectionTimeoutSeconds*time.Second)
	defer cancel()
	return c.internalClose(subCtx)
}

// internalClose closes c.conn if it exists and returns any close error.
// The caller must hold the write lock.
//
// NOTE: amqp.Conn.Close takes no context, so ctx is currently not consulted
// and the caller's timeout does not bound the close.
func (c *AmqpConnection) internalClose(ctx context.Context) error {
	if c.conn == nil {
		return nil
	}
	// errors.Join returns nil when no errors were collected.
	return errors.Join(c.internalCloseConnection(nil)...)
}

// internalCloseConnection closes c.conn and appends the close error, if any,
// to closeErrors, returning the resulting slice.
func (c *AmqpConnection) internalCloseConnection(closeErrors []error) []error {
	if err := c.conn.Close(); err != nil {
		closeErrors = append(closeErrors, err)
	}
	return closeErrors
}

// IsClosed reports whether there is no open connection, i.e. it was never
// dialed or its Done channel has been closed.
func (c *AmqpConnection) IsClosed() bool {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.internalIsClosed()
}

// internalIsClosed is IsClosed without locking; the caller must hold c.lock.
// A nil connection counts as closed.
func (c *AmqpConnection) internalIsClosed() bool {
	if c.conn == nil {
		return true
	}
	select {
	case <-c.conn.Done():
		return true
	default:
		return false
	}
}