Created
January 21, 2016 18:00
-
-
Save jharlap/839a85f317a9a9ed8947 to your computer and use it in GitHub Desktop.
SQS Producer writes each line of a text file as a message to an SQS queue.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"log" | |
"os" | |
"fmt" | |
"bufio" | |
"flag" | |
"github.com/aws/aws-sdk-go/aws" | |
"github.com/aws/aws-sdk-go/aws/awserr" | |
"github.com/aws/aws-sdk-go/service/sqs" | |
"sync" | |
"time" | |
"sync/atomic" | |
) | |
// getQueue gets the URL of the SQS queue | |
func getQueue(svc *sqs.SQS, name string) (*string, error) { | |
// if the queue already exists just get the url | |
resp, err := svc.GetQueueUrl(&sqs.GetQueueUrlInput{ | |
QueueName: aws.String(name), | |
}) | |
if err != nil { | |
if awsErr, ok := err.(awserr.Error); ok { | |
// Generic AWS error with Code, Message, and original error (if any) | |
log.Println("AWS Error:", awsErr.Code(), awsErr.Message(), awsErr.OrigErr()) | |
if reqErr, ok := err.(awserr.RequestFailure); ok { | |
// A service error occurred | |
log.Println("AWS Service Error:", reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID()) | |
} | |
} | |
return nil, err | |
} | |
return resp.QueueUrl, nil | |
} | |
func startSenderWorkerPool(n int, wg *sync.WaitGroup, svc *sqs.SQS, queueUrl *string) chan<- []string { | |
q := make(chan []string) | |
for i:=0; i<n; i++ { | |
go func() { | |
for b := range q { | |
wg.Add(1) | |
err := sendMessageBatch(svc, queueUrl, b) | |
if err != nil { | |
log.Printf("Error sending message batch: %s", err) | |
} | |
wg.Done() | |
} | |
}() | |
} | |
return q | |
} | |
var progress int64 | |
func progressAdd(n int64) { | |
atomic.AddInt64(&progress, n) | |
} | |
func startProgressLogger() { | |
go func() { | |
var lastProgress int64 | |
lastTime := time.Now() | |
for { | |
<-time.After(time.Second) | |
curProgress := atomic.LoadInt64(&progress) | |
curTime := time.Now() | |
rate := (float64)(curProgress - lastProgress) / curTime.Sub(lastTime).Seconds() | |
lastProgress = curProgress | |
lastTime = curTime | |
log.Printf("Sent %d messages (%f msgs/sec)", curProgress, rate) | |
} | |
}() | |
} | |
func sendMessageBatch(svc *sqs.SQS, queueUrl *string, messages []string) error { | |
p := &sqs.SendMessageBatchInput{QueueUrl: queueUrl} | |
for i, m := range messages { | |
p.Entries = append(p.Entries, &sqs.SendMessageBatchRequestEntry{ | |
Id: aws.String(fmt.Sprintf("%d", i)), | |
MessageBody: aws.String(m), | |
}) | |
} | |
resp, err := svc.SendMessageBatch(p) | |
if err != nil { | |
return err | |
} | |
progressAdd((int64)(len(messages) - len(resp.Failed))) | |
if len(resp.Failed) > 0 { | |
var failedMessages string | |
for _, f := range resp.Failed { | |
failedMessages += aws.StringValue(f.Message) + "\n" | |
} | |
return fmt.Errorf("Failed to send %d messages:\n%s", len(resp.Failed), failedMessages) | |
} | |
return nil | |
} | |
func main() { | |
var queueName, messagesFilename string | |
var workers int | |
flag.StringVar(&queueName, "queue", "", "Queue name") | |
flag.StringVar(&messagesFilename, "messages", "", "File containing one message body per line") | |
flag.IntVar(&workers, "workers", 10, "Number of concurrent sender workers") | |
flag.Parse() | |
if queueName == "" || messagesFilename == "" { | |
fmt.Println("You must specify both the queue and messages file") | |
flag.Usage() | |
os.Exit(1) | |
} | |
svc := sqs.New(&aws.Config{ | |
Region: aws.String("us-east-1"), | |
}) | |
queueUrl, err := getQueue(svc, queueName) | |
if err != nil { | |
log.Println("Could not get queue:", err) | |
os.Exit(1) | |
} | |
log.Println("Using queueUrl:", *queueUrl) | |
messagesFile, err := os.Open(messagesFilename) | |
if err != nil { | |
log.Fatalf("Error reading messages file: %s", err) | |
} | |
start := time.Now() | |
wg := &sync.WaitGroup{} | |
batchQueue := startSenderWorkerPool(workers, wg, svc, queueUrl) | |
go startProgressLogger() | |
var count int | |
scanner := bufio.NewScanner(messagesFile) | |
var messages []string | |
for scanner.Scan() { | |
count++ | |
messages = append(messages, scanner.Text()) | |
if len(messages) == 10 { | |
batchQueue <- messages | |
messages = nil | |
} | |
} | |
if len(messages) > 0 { | |
batchQueue <- messages | |
messages = nil | |
} | |
wg.Wait() | |
duration := time.Since(start).Seconds() | |
log.Printf("Sent %d messages in %f secs (%f msgs/sec)", count, duration, (float64)(count)/duration) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment