package api

import (
	"context"
	"errors"
	"fmt"
	"testing"
	"time"

	"dev.azure.com/schwarzit/schwarzit.stackit-core-platform/common-audit.git/audit/messaging"
	auditV1 "dev.azure.com/schwarzit/schwarzit.stackit-core-platform/common-audit.git/gen/go/audit/v1"
	"github.com/bufbuild/protovalidate-go"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/timestamppb"
	"google.golang.org/protobuf/types/known/wrapperspb"
)

// MessagingApiMock is a testify mock with the Send method used as a
// messaging.MessagingApi in these tests; it records every Send call.
type MessagingApiMock struct {
	mock.Mock
}

// Send records the call and returns the error configured via On("Send", ...).
func (m *MessagingApiMock) Send(ctx context.Context, topic string, data []byte, contentType string) error {
	args := m.Called(ctx, topic, data, contentType)
	return args.Error(0)
}

// ProtobufValidatorMock is a testify mock providing a Validate method.
type ProtobufValidatorMock struct {
	mock.Mock
}

// Validate records the call and returns the error configured via On("Validate", ...).
func (m *ProtobufValidatorMock) Validate(msg proto.Message) error {
	args := m.Called(msg)
	return args.Error(0)
}

// TopicNameResolverMock is a testify mock used as a TopicNameResolver in these tests.
type TopicNameResolverMock struct {
	mock.Mock
}

// Resolve records the call and returns the topic name and error configured
// via On("Resolve", ...).
func (m *TopicNameResolverMock) Resolve(routingIdentifier *RoutingIdentifier) (string, error) {
	args := m.Called(routingIdentifier)
	return args.String(0), args.Error(1)
}

// NewValidator builds a real protovalidate validator wrapped as a
// ProtobufValidator. If construction fails the calling test is stopped
// immediately, so callers never receive an unusable validator.
func NewValidator(t *testing.T) ProtobufValidator {
	t.Helper()
	validator, err := protovalidate.New()
	if !assert.NoError(t, err) {
		// Continuing would only produce nil-dereference panics further down.
		t.FailNow()
	}
	var protoValidator ProtobufValidator = validator
	return protoValidator
}

func Test_ValidateAndSerializePartially_EventNil(t *testing.T) {
	validator := NewValidator(t)

	_, err := validateAndSerializePartially(
		&validator, nil, auditV1.Visibility_VISIBILITY_PUBLIC, nil, nil)

	assert.ErrorIs(t, err, ErrEventNil)
}

func Test_ValidateAndSerializePartially_AuditEventSequenceNumber(t *testing.T) {
	validator := NewValidator(t)

	t.Run("Sequence number too low", func(t *testing.T) {
		event, routingIdentifier, objectIdentifier := NewOrganizationAuditEvent(nil)
		event.SequenceNumber = wrapperspb.Int64(-2)

		_, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, routingIdentifier, objectIdentifier)

		assert.EqualError(t, err, "validation error:\n - sequence_number: value must be greater than or equal to -1 [int64.gte]")
	})

	t.Run("Sequence number is minimum", func(t *testing.T) {
		event, routingIdentifier, objectIdentifier := NewOrganizationAuditEvent(nil)
		event.SequenceNumber = wrapperspb.Int64(-1)

		e, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, routingIdentifier, objectIdentifier)

		assert.NoError(t, err)
		validateSequenceNumber(t, e, -1)
	})

	t.Run("Sequence number is default", func(t *testing.T) {
		event, routingIdentifier, objectIdentifier := NewOrganizationAuditEvent(nil)
		event.SequenceNumber = wrapperspb.Int64(0)

		e, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, routingIdentifier, objectIdentifier)

		assert.NoError(t, err)
		validateSequenceNumber(t, e, 0)
	})

	t.Run("Sequence number is greater than default", func(t *testing.T) {
		event, routingIdentifier, objectIdentifier := NewOrganizationAuditEvent(nil)
		event.SequenceNumber = wrapperspb.Int64(1)

		e, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, routingIdentifier, objectIdentifier)

		assert.NoError(t, err)
		validateSequenceNumber(t, e, 1)
	})

	t.Run("Sequence number not set", func(t *testing.T) {
		// Built by hand (not via NewOrganizationAuditEvent) so that
		// SequenceNumber is left unset.
		event := &auditV1.AuditEvent{
			EventName:      "ORGANIZATION_CREATED",
			EventTimeStamp: timestamppb.New(time.Now()),
			EventTrigger:   auditV1.EventTrigger_EVENT_TRIGGER_EVENT,
			Initiator: &auditV1.Principal{
				Id: uuid.NewString(),
			},
		}

		identifier := uuid.New()
		routingIdentifier := &RoutingIdentifier{
			Identifier: identifier,
			Type:       RoutingIdentifierTypeOrganization,
		}
		objectIdentifier := &auditV1.ObjectIdentifier{
			Identifier: identifier.String(),
			Type:       auditV1.ObjectType_OBJECT_TYPE_ORGANIZATION,
		}

		_, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, routingIdentifier, objectIdentifier)

		assert.EqualError(t, err, "validation error:\n - sequence_number: value is required [required]")
	})
}

// validateSequenceNumber asserts that event carries unencrypted data whose
// serialized AuditEvent has the expected sequence number. Missing data, an
// unmarshal error or an unset sequence number are reported as test failures
// rather than panics.
func validateSequenceNumber(t *testing.T, event *auditV1.RoutableAuditEvent, expectedNumber int64) {
	t.Helper()
	switch data := event.GetData().(type) {
	case *auditV1.RoutableAuditEvent_UnencryptedData:
		var auditEvent auditV1.AuditEvent
		if !assert.NoError(t, proto.Unmarshal(data.UnencryptedData.GetData(), &auditEvent)) {
			return
		}
		// Check presence explicitly: GetValue() on a nil wrapper yields 0,
		// which would silently satisfy an expectation of 0.
		if assert.NotNil(t, auditEvent.GetSequenceNumber(), "sequence number not set") {
			assert.Equal(t, expectedNumber, auditEvent.GetSequenceNumber().GetValue())
		}
	default:
		assert.Fail(t, "expected unencrypted data")
	}
}

func Test_ValidateAndSerializePartially_AuditEventValidationFailed(t *testing.T) {
	validator := NewValidator(t)

	event, routingIdentifier, objectIdentifier := NewOrganizationAuditEvent(nil)
	event.EventName = ""

	_, err := validateAndSerializePartially(
		&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, routingIdentifier, objectIdentifier)

	assert.EqualError(t, err, "validation error:\n - event_name: value is required [required]")
}

func Test_ValidateAndSerializePartially_RoutableEventValidationFailed(t *testing.T) {
	validator := NewValidator(t)

	event, routingIdentifier, objectIdentifier := NewOrganizationAuditEvent(nil)

	// 3 is deliberately not a defined Visibility enum value.
	_, err := validateAndSerializePartially(&validator, event, 3, routingIdentifier, objectIdentifier)

	assert.EqualError(t, err, "validation error:\n - visibility: value must be one of the defined enum values [enum.defined_only]")
}

func Test_ValidateAndSerializePartially_CheckVisibility(t *testing.T) {
	validator := NewValidator(t)

	// NOTE(review): event is shared across subtests; this assumes
	// validateAndSerializePartially does not mutate it.
	event, routingIdentifier, objectIdentifier := NewOrganizationAuditEvent(nil)

	t.Run("Visibility public - object identifier nil - routing identifier nil", func(t *testing.T) {
		_, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, nil, nil)

		assert.ErrorIs(t, err, ErrObjectIdentifierVisibilityMismatch)
	})

	t.Run("Visibility public - object identifier nil - routing identifier set", func(t *testing.T) {
		_, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, routingIdentifier, nil)

		assert.ErrorIs(t, err, ErrObjectIdentifierVisibilityMismatch)
	})

	t.Run("Visibility public - object identifier set - routing identifier nil", func(t *testing.T) {
		_, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, nil, objectIdentifier)

		assert.ErrorIs(t, err, ErrRoutableIdentifierMissing)
	})

	t.Run("Visibility public - object identifier set - routing identifier set", func(t *testing.T) {
		routableEvent, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, routingIdentifier, objectIdentifier)

		assert.NoError(t, err)
		assert.NotNil(t, routableEvent)
	})

	t.Run("Visibility private - object identifier nil - routing identifier nil", func(t *testing.T) {
		routableEvent, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PRIVATE, nil, nil)

		assert.NoError(t, err)
		assert.NotNil(t, routableEvent)
	})

	t.Run("Visibility private - object identifier nil - routing identifier set", func(t *testing.T) {
		routableEvent, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PRIVATE, routingIdentifier, nil)

		assert.NoError(t, err)
		assert.NotNil(t, routableEvent)
	})

	t.Run("Visibility private - object identifier set - routing identifier nil", func(t *testing.T) {
		routableEvent, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PRIVATE, nil, objectIdentifier)

		assert.NoError(t, err)
		assert.NotNil(t, routableEvent)
	})

	t.Run("Visibility private - object identifier set - routing identifier set", func(t *testing.T) {
		routableEvent, err := validateAndSerializePartially(
			&validator, event, auditV1.Visibility_VISIBILITY_PRIVATE, routingIdentifier, objectIdentifier)

		assert.NoError(t, err)
		assert.NotNil(t, routableEvent)
	})
}

func Test_ValidateAndSerializePartially_IdentifierTypeMismatch(t *testing.T) {
	validator := NewValidator(t)

	event, routingIdentifier, objectIdentifier := NewFolderAuditEvent(nil)
	routingIdentifier.Type = RoutingIdentifierTypeProject

	_, err := validateAndSerializePartially(
		&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, routingIdentifier, objectIdentifier)

	assert.ErrorIs(t, err, ErrRoutableIdentifierTypeMismatch)
}

func Test_ValidateAndSerializePartially_IdentifierMismatch(t *testing.T) {
	validator := NewValidator(t)

	event, routingIdentifier, objectIdentifier := NewProjectAuditEvent(nil)
	routingIdentifier.Identifier = uuid.New()

	_, err := validateAndSerializePartially(
		&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, routingIdentifier, objectIdentifier)

	assert.ErrorIs(t, err, ErrRoutableIdentifierMismatch)
}

func Test_ValidateAndSerializePartially_SystemEvent(t *testing.T) {
	validator := NewValidator(t)

	event := NewSystemAuditEvent(nil)

	routableEvent, err := validateAndSerializePartially(
		&validator, event, auditV1.Visibility_VISIBILITY_PRIVATE, nil, nil)
	if !assert.NoError(t, err) {
		// routableEvent is unusable on error; avoid a nil dereference below.
		return
	}

	switch reference := routableEvent.ResourceReference.(type) {
	case *auditV1.RoutableAuditEvent_ObjectName:
		assert.Equal(t, auditV1.ObjectName_OBJECT_NAME_SYSTEM, reference.ObjectName)
	default:
		assert.Fail(t, "unexpected resource reference")
	}
}

func Test_SerializeToProtobufMessage(t *testing.T) {
	validator := NewValidator(t)

	// Create test data
	event, identifier, objectIdentifier := NewOrganizationAuditEventWithDetails()

	// Serialize to routable event
	routableEvent, err := validateAndSerializePartially(
		&validator, event, auditV1.Visibility_VISIBILITY_PUBLIC, identifier, objectIdentifier)
	if !assert.NoError(t, err) {
		return
	}

	// Serialize to protobuf message
	serializedPayload, err := serializeToProtobufMessage(routableEvent)
	if !assert.NoError(t, err) {
		return
	}
	// testify convention: expected value first, actual second.
	assert.Equal(t, ContentTypeProtobuf, serializedPayload.GetContentType())

	// Deserialize the envelope and check it names the wrapped message type.
	var deserializedEvent auditV1.ProtobufMessage
	if !assert.NoError(t, proto.Unmarshal(serializedPayload.GetPayload(), &deserializedEvent)) {
		return
	}

	expectedProtobufType := fmt.Sprintf("%v", routableEvent.ProtoReflect().Descriptor().FullName())
	assert.Equal(t, expectedProtobufType, deserializedEvent.ProtobufType)

	// The wrapped bytes must round-trip to the original routable event.
	var deserializedRoutableEvent auditV1.RoutableAuditEvent
	if !assert.NoError(t, proto.Unmarshal(deserializedEvent.Value, &deserializedRoutableEvent)) {
		return
	}
	assert.True(t, proto.Equal(routableEvent, &deserializedRoutableEvent))
}

func Test_Send_TopicNameResolverNil(t *testing.T) {
	err := send(nil, nil, context.Background(), nil, nil)
	assert.ErrorIs(t, err, ErrTopicNameResolverNil)
}

func Test_Send_TopicNameResolutionError(t *testing.T) {
	expectedError := errors.New("expected error")

	topicNameResolverMock := TopicNameResolverMock{}
	topicNameResolverMock.On("Resolve", mock.Anything).Return("topic", expectedError)
	var topicNameResolver TopicNameResolver = &topicNameResolverMock

	var serializedPayload SerializedPayload = &routablePayload{}
	var messagingApi messaging.MessagingApi = &messaging.AmqpMessagingApi{}

	err := send(&topicNameResolver, &messagingApi, context.Background(), nil, &serializedPayload)

	assert.ErrorIs(t, err, expectedError)
	topicNameResolverMock.AssertExpectations(t)
}

func Test_Send_MessagingApiNil(t *testing.T) {
	var topicNameResolver TopicNameResolver = &LegacyTopicNameResolver{topicName: "test"}

	err := send(&topicNameResolver, nil, context.Background(), nil, nil)

	assert.ErrorIs(t, err, ErrMessagingApiNil)
}

func Test_Send_SerializedPayloadNil(t *testing.T) {
	var topicNameResolver TopicNameResolver = &LegacyTopicNameResolver{topicName: "test"}
	var messagingApi messaging.MessagingApi = &messaging.AmqpMessagingApi{}

	err := send(&topicNameResolver, &messagingApi, context.Background(), nil, nil)

	assert.ErrorIs(t, err, ErrSerializedPayloadNil)
}

func Test_Send(t *testing.T) {
	topicNameResolverMock := TopicNameResolverMock{}
	topicNameResolverMock.On("Resolve", mock.Anything).Return("topic", nil)
	var topicNameResolver TopicNameResolver = &topicNameResolverMock

	var serializedPayload SerializedPayload = &routablePayload{}

	messagingApiMock := MessagingApiMock{}
	messagingApiMock.On("Send", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
	var messagingApi messaging.MessagingApi = &messagingApiMock

	assert.NoError(t, send(&topicNameResolver, &messagingApi, context.Background(), nil, &serializedPayload))

	topicNameResolverMock.AssertExpectations(t)
	// AssertNumberOfCalls already reports its own failure; wrapping it in
	// assert.True would report the same failure twice.
	messagingApiMock.AssertNumberOfCalls(t, "Send", 1)
}