Skip to content

Instantly share code, notes, and snippets.

@miguelmota
Created June 13, 2018 22:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save miguelmota/c595892d243fd9e3c210239fb3a53310 to your computer and use it in GitHub Desktop.
Save miguelmota/c595892d243fd9e3c210239fb3a53310 to your computer and use it in GitHub Desktop.
Go AWS SQS simple queue wrapper
package queue
import (
"encoding/json"
"errors"
"log"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
)
// Service wraps an SQS client that is bound to a single queue URL.
type Service struct {
	queue    *sqs.SQS
	queueURL string
	isDev    bool // NOTE(review): never read or set in the visible code
}

// NewInput holds the settings used by New.
type NewInput struct {
	// Region is the AWS region passed to the session config.
	// NOTE(review): presumably an empty Region falls back to the SDK's
	// environment/shared-config resolution — confirm.
	Region string
	// QueueURL is the URL of the queue every Service operation targets.
	QueueURL string
}

// New builds a Service for input.QueueURL on top of a fresh AWS session
// configured with input.Region and shared config loading enabled
// (session.SharedConfigEnable).
//
// New calls log.Fatal, which terminates the process, if the session cannot be
// created. A nil input panics.
func New(input *NewInput) *Service {
	opts := session.Options{
		Config:            aws.Config{Region: aws.String(input.Region)},
		SharedConfigState: session.SharedConfigEnable,
	}
	sess, err := session.NewSessionWithOptions(opts)
	if err != nil {
		log.Fatal(err)
	}
	return &Service{
		queue:    sqs.New(sess),
		queueURL: input.QueueURL,
	}
}
// SendMessageInput holds the parameters for Service.SendMessage.
type SendMessageInput struct {
	// GroupID is sent as the SQS MessageGroupId when non-empty. Messages that
	// share a group ID are processed one by one (FIFO queues only).
	GroupID string
	// DeduplicationID is sent as the SQS MessageDeduplicationId when non-empty
	// (FIFO queues only).
	DeduplicationID string
	// Attributes are sent as String-typed message attributes and, JSON-encoded,
	// as the message body.
	Attributes map[string]string
}

// SendMessageOutput is the result of Service.SendMessage.
type SendMessageOutput struct {
	// MessageID is the ID SQS assigned to the sent message.
	MessageID string
}

// SendMessage sends one message to the service's queue.
//
// Every entry of input.Attributes becomes a String-typed SQS message
// attribute, and the whole map is JSON-encoded into the message body.
// GroupID and DeduplicationID are forwarded only when non-empty. Those
// parameters apply only to FIFO queues, so omitting empty values also lets
// the method work with standard queues.
//
// It returns an error if input is nil, if the attributes cannot be encoded,
// if the SQS call fails, or if SQS does not return a message ID.
func (s *Service) SendMessage(input *SendMessageInput) (*SendMessageOutput, error) {
	if input == nil {
		return nil, errors.New("input is required")
	}

	dataType := "String"
	// keys must not contain special characters
	attrs := make(map[string]*sqs.MessageAttributeValue, len(input.Attributes))
	for k, v := range input.Attributes {
		v := v // copy so each attribute points at its own value (pre-Go 1.22 loop semantics)
		attrs[k] = &sqs.MessageAttributeValue{
			DataType:    &dataType,
			StringValue: &v,
		}
	}

	b, err := json.Marshal(input.Attributes)
	if err != nil {
		return nil, err
	}
	body := string(b)

	req := &sqs.SendMessageInput{
		MessageAttributes: attrs,
		MessageBody:       &body,
		QueueUrl:          &s.queueURL,
	}
	if input.GroupID != "" {
		// messages with the same group ID get processed one-by-one
		groupID := input.GroupID
		req.MessageGroupId = &groupID
	}
	if input.DeduplicationID != "" {
		dedupID := input.DeduplicationID
		req.MessageDeduplicationId = &dedupID
	}

	output, err := s.queue.SendMessage(req)
	if err != nil {
		return nil, err
	}
	// Check for a missing ID before using it; dereferencing first would panic.
	if output.MessageId == nil {
		return nil, errors.New("message ID not received")
	}
	msgID := *output.MessageId
	log.Println(msgID)

	return &SendMessageOutput{
		MessageID: msgID,
	}, nil
}
// ReceiveMessageOutput is the result of Service.ReceiveMessage.
type ReceiveMessageOutput struct {
	// Messages holds the received messages; it is nil when SQS returned none.
	Messages []*Message
}

// ReceiveMessage performs a single SQS ReceiveMessage request for up to 10
// messages. It does not wait for messages to arrive (WaitTimeSeconds 0), asks
// for all message attributes, and requests a 20-second visibility timeout.
// Missing MessageId, ReceiptHandle, or Body fields are mapped to "". Each
// returned Message keeps a reference to s so SendHeartbeat can use it.
func (s *Service) ReceiveMessage() (*ReceiveMessageOutput, error) {
	log.Println("receive message called")

	maxMessages, visibility, wait := int64(10), int64(20), int64(0)
	all := "All"
	resp, err := s.queue.ReceiveMessage(&sqs.ReceiveMessageInput{
		QueueUrl:              &s.queueURL,
		MaxNumberOfMessages:   &maxMessages, // SQS accepts 1-10
		VisibilityTimeout:     &visibility,  // seconds
		WaitTimeSeconds:       &wait,
		MessageAttributeNames: []*string{&all},
	})
	if err != nil {
		return nil, err
	}
	log.Println("messages", resp.Messages)

	// str dereferences an optional SDK string, treating nil as "".
	str := func(p *string) string {
		if p == nil {
			return ""
		}
		return *p
	}

	var messages []*Message
	for _, raw := range resp.Messages {
		messages = append(messages, &Message{
			MessageID:     str(raw.MessageId),
			ReceiptHandle: str(raw.ReceiptHandle),
			Body:          str(raw.Body),
			Attributes:    raw.MessageAttributes,
			service:       s,
		})
	}
	return &ReceiveMessageOutput{Messages: messages}, nil
}
// DeleteMessageInput identifies the message to delete.
type DeleteMessageInput struct {
	// ReceiptHandle is the handle obtained when the message was received
	// (see Message.ReceiptHandle).
	ReceiptHandle string
}

// DeleteMessageOutput is the result of Service.DeleteMessage.
type DeleteMessageOutput struct {
	// Response is the SDK's String() rendering of the DeleteMessage response.
	Response string
}

// DeleteMessage removes a received message from the queue by its receipt
// handle. It returns an error if input is nil or if the SQS call fails.
func (s *Service) DeleteMessage(input *DeleteMessageInput) (*DeleteMessageOutput, error) {
	if input == nil {
		return nil, errors.New("input is required")
	}
	req := &sqs.DeleteMessageInput{
		QueueUrl:      &s.queueURL,
		ReceiptHandle: &input.ReceiptHandle,
	}
	resp, err := s.queue.DeleteMessage(req)
	if err != nil {
		return nil, err
	}
	out := &DeleteMessageOutput{Response: resp.String()}
	return out, nil
}

// ReadQueueURL reports the URL of the queue this Service operates on.
func (s *Service) ReadQueueURL() string { return s.queueURL }
// Message is a message received from the queue by Service.ReceiveMessage.
type Message struct {
	// MessageID is the SQS-assigned message ID.
	MessageID string
	// ReceiptHandle identifies this receipt of the message; it is what
	// DeleteMessage and SendHeartbeat pass to SQS.
	ReceiptHandle string
	// Body is the raw message body.
	Body string
	// Attributes are the message attributes returned by SQS.
	Attributes map[string]*sqs.MessageAttributeValue
	// service is the Service that received the message and is used by
	// SendHeartbeat. It is nil for Messages not created by ReceiveMessage.
	service *Service
}

// SendHeartbeat extends the message's visibility timeout by calling SQS
// ChangeMessageVisibility with t as the new timeout.
//
// t is converted to whole seconds rounding up, so a fractional duration never
// shortens the timeout. In particular, a positive sub-second t is not truncated
// to 0, which would make the message visible again immediately.
//
// It returns an error if the Message has no associated Service (it was not
// obtained from ReceiveMessage), if t is negative, or if the SQS call fails.
func (m *Message) SendHeartbeat(t time.Duration) error {
	if m.service == nil {
		return errors.New("message is not associated with a service")
	}
	if t < 0 {
		return errors.New("visibility timeout must not be negative")
	}

	// Ceiling division without risking overflow on very large durations.
	timeout := int64(t / time.Second)
	if t%time.Second != 0 {
		timeout++
	}

	queueURL := m.service.ReadQueueURL()
	receiptHandle := m.ReceiptHandle
	_, err := m.service.queue.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{
		QueueUrl:          &queueURL,
		ReceiptHandle:     &receiptHandle,
		VisibilityTimeout: &timeout,
	})
	return err
}
package queue
import (
"fmt"
"os"
"testing"
"time"
uuid "github.com/satori/go.uuid"
)
// newUUID returns a new random (version 4) UUID for use as test IDs.
//
// It panics if uuid.NewV4 fails. Silently discarding the error, as before,
// would return whatever value NewV4 yields alongside the error (presumably the
// zero UUID). Test runs would then reuse the same deduplication ID, and SQS
// FIFO deduplication could drop the message, making the failure confusing.
func newUUID() uuid.UUID {
	id, err := uuid.NewV4()
	if err != nil {
		panic(fmt.Sprintf("generating UUID: %v", err))
	}
	return id
}
// TestFlow is an integration test against a real SQS queue. It sends a
// message, receives it, extends its visibility, and then deletes it.
//
// It reads AWS_SQS_REGION and AWS_SQS_QUEUE_URL from the environment and is
// skipped when AWS_SQS_QUEUE_URL is unset. The queue is presumably FIFO, since
// a group ID and a deduplication ID are sent; TODO confirm.
func TestFlow(t *testing.T) {
	t.Parallel()

	queueURL := os.Getenv("AWS_SQS_QUEUE_URL")
	if queueURL == "" {
		t.Skip("AWS_SQS_QUEUE_URL not set; skipping SQS integration test")
	}
	queue := New(&NewInput{
		Region:   os.Getenv("AWS_SQS_REGION"),
		QueueURL: queueURL,
	})

	id := newUUID().String()
	sendOutput, err := queue.SendMessage(&SendMessageInput{
		GroupID:         fmt.Sprintf("id/%s", id),
		DeduplicationID: id,
		Attributes: map[string]string{
			"id":  id,
			"foo": "bar",
		},
	})
	if err != nil {
		t.Fatal(err)
	}
	if sendOutput.MessageID == "" {
		t.Fatal("expected message ID")
	}
	t.Log("message ID:", sendOutput.MessageID)

	receiveOutput, err := queue.ReceiveMessage()
	if err != nil {
		t.Fatal(err)
	}
	if len(receiveOutput.Messages) == 0 {
		t.Fatal("expected messages")
	}
	// The queue may hold other messages. Operate only on the one sent above so
	// the test never extends or deletes someone else's message.
	var msg *Message
	for _, m := range receiveOutput.Messages {
		if m.MessageID == sendOutput.MessageID {
			msg = m
			break
		}
	}
	if msg == nil {
		t.Fatalf("sent message %s not among %d received messages",
			sendOutput.MessageID, len(receiveOutput.Messages))
	}
	t.Log("attrs:", msg.Attributes)
	t.Log("receipt handle:", msg.ReceiptHandle)

	if err := msg.SendHeartbeat(10 * time.Second); err != nil {
		t.Fatal(err)
	}

	deleteOutput, err := queue.DeleteMessage(&DeleteMessageInput{
		ReceiptHandle: msg.ReceiptHandle,
	})
	if err != nil {
		t.Fatal(err)
	}
	if deleteOutput.Response == "" {
		t.Fatal("expected response messages")
	}
	t.Log("delete response:", deleteOutput.Response)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment