Skip to content

Instantly share code, notes, and snippets.

@yifanes
Forked from s0j0hn/rabbitmq.go
Created November 3, 2020 06:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yifanes/2f5ba54c3fd45fc6bb2c682b54961976 to your computer and use it in GitHub Desktop.
Save yifanes/2f5ba54c3fd45fc6bb2c682b54961976 to your computer and use it in GitHub Desktop.
package rabbitmq
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/rs/zerolog"
"github.com/streadway/amqp"
"os"
"runtime"
"sync"
"time"
)
// errDisconnected is the sentinel error reported when the client has no
// usable RabbitMQ connection (or a consumer's delivery channel was closed).
var errDisconnected = errors.New("disconnected from rabbitmq, trying to reconnect")
// reconnectDelay is the base pause between two connection attempts after a
// failed attempt to reach the server.
const reconnectDelay = 5 * time.Second
// AMQPClient bundles everything needed to talk to RabbitMQ: the queue names,
// the current connection/channel pair with its notification channels, and the
// bookkeeping for the consumer goroutines.
type AMQPClient struct {
	// Queue names.
	listenQueue string // queue consumed by Stream
	pushQueue   string // routing key used when publishing

	logger zerolog.Logger

	// Connection state, installed as a unit by changeConnection.
	connection    *amqp.Connection
	channel       *amqp.Channel
	notifyClose   chan *amqp.Error
	notifyConfirm chan amqp.Confirmation

	// Lifecycle flags.
	isConnected bool // set when connect succeeds, cleared by Close
	alive       bool // true from construction until Close is called

	// Consumer bookkeeping.
	threads int             // number of consumers started by Stream
	wg      *sync.WaitGroup // pre-loaded with threads; each consumer goroutine calls Done
}
// NewAMQPClient builds an AMQPClient for the given listen and push queues and
// blocks until a value is received on goChan.
//
// The consumer count is the larger of runtime.GOMAXPROCS(0) and
// runtime.NumCPU(), and the client's WaitGroup is pre-loaded with that count.
// Connecting to addr happens in a background goroutine (handleReconnect) that
// is handed the same goChan.
func NewAMQPClient(listenQueue, pushQueue, addr string, l zerolog.Logger, goChan chan os.Signal) *AMQPClient {
	workers := runtime.NumCPU()
	if procs := runtime.GOMAXPROCS(0); procs > workers {
		workers = procs
	}

	client := &AMQPClient{
		pushQueue:   pushQueue,
		listenQueue: listenQueue,
		logger:      l,
		alive:       true,
		threads:     workers,
		wg:          new(sync.WaitGroup),
	}
	client.wg.Add(workers)

	go client.handleReconnect(addr, goChan)

	// Wait for the first value on goChan before handing the client out;
	// handleReconnect sends one as soon as the connection is established.
	<-goChan

	return client
}
// handleReconnect keeps dialling addr until a connection is established, the
// client is no longer alive, or a value arrives on clientChannel. After each
// failed attempt it waits reconnectDelay plus one extra second per previous
// failure before trying again.
//
// On success it sends os.Kill on clientChannel (the value NewAMQPClient is
// blocked waiting for) and returns.
//
// NOTE: despite its name, the function does not watch notifyClose; a
// connection that is lost after the first successful connect is not
// re-established by this goroutine.
func (c *AMQPClient) handleReconnect(addr string, clientChannel chan os.Signal) {
	if !c.alive {
		return
	}

	fmt.Printf("Attempting to connect to rabbitMQ: %s\n", addr)
	c.isConnected = false
	start := time.Now()

	for retryCount := 0; !c.connect(addr); retryCount++ {
		if !c.alive {
			return
		}
		select {
		case <-clientChannel:
			c.logger.Printf("Recieved something into clinet channel")
			return
		case <-time.After(reconnectDelay + time.Duration(retryCount)*time.Second):
			// Terminate the line so consecutive failures do not run together.
			fmt.Printf("disconnected from rabbitMQ and failed to connect\n")
		}
	}

	c.logger.Printf("Connected to rabbitMQ in: %vms", time.Since(start).Milliseconds())

	// Tell the constructor that the connection is ready.
	clientChannel <- os.Kill
}
// connect makes a single attempt to connect to RabbitMQ at addr. It dials the
// server, opens a channel, puts it into confirm mode and declares the listen
// and push queues as durable queues. On success the new connection is
// installed on the client and true is returned; on any failure the error is
// logged, the dialled connection is released and false is returned.
func (c *AMQPClient) connect(addr string) bool {
	conn, err := amqp.Dial(addr)
	if err != nil {
		c.logger.Printf("failed to dial rabbitMQ server: %v", err)
		return false
	}

	// Every failure past this point must release the freshly dialled
	// connection, otherwise each retry would leak one.
	installed := false
	defer func() {
		if !installed {
			// Best effort: the setup error has already been logged and the
			// connection may already be closed by the failure itself.
			_ = conn.Close()
		}
	}()

	ch, err := conn.Channel()
	if err != nil {
		c.logger.Printf("failed connecting to channel: %v", err)
		return false
	}

	if err = ch.Confirm(false); err != nil {
		c.logger.Printf("failed to confirm channel: %v", err)
		return false
	}

	queues := []struct {
		role string
		name string
	}{
		{role: "listen", name: c.listenQueue},
		{role: "push", name: c.pushQueue},
	}
	for _, q := range queues {
		_, err = ch.QueueDeclare(
			q.name,
			true,  // Durable
			false, // Delete when unused
			false, // Exclusive
			false, // No-wait
			nil,   // Arguments
		)
		if err != nil {
			c.logger.Printf("failed to declare %s queue: %v", q.role, err)
			return false
		}
	}

	installed = true
	c.changeConnection(conn, ch)
	c.isConnected = true
	return true
}
// changeConnection installs a new connection/channel pair on the client and
// registers fresh close and publish-confirmation listeners on the channel.
//
// Both listener channels get a buffer of one. Nothing in this client reads
// notifyClose, and Push reads at most one confirmation per publish, so an
// unbuffered channel would leave the amqp library blocked on its send (see
// the NotifyClose / NotifyPublish documentation on listener capacity).
func (c *AMQPClient) changeConnection(connection *amqp.Connection, channel *amqp.Channel) {
	c.connection = connection
	c.channel = channel

	c.notifyClose = make(chan *amqp.Error, 1)
	c.notifyConfirm = make(chan amqp.Confirmation, 1)

	c.channel.NotifyClose(c.notifyClose)
	c.channel.NotifyPublish(c.notifyConfirm)
}
// Push publishes data to the push queue and blocks until the server
// acknowledges it. If no confirmation arrives within one second, or the
// confirmation is a negative one, the message is published again.
//
// It returns errDisconnected when the client is not connected on entry or has
// been closed while retrying, and any other publish error as-is.
func (c *AMQPClient) Push(data []byte) error {
	// resendDelay is both the confirmation timeout and the pause before a
	// retry while the client is disconnected.
	const resendDelay = 1 * time.Second

	if !c.isConnected {
		return errDisconnected
	}

	for {
		if err := c.UnsafePush(data); err != nil {
			if !errors.Is(err, errDisconnected) {
				return err
			}
			// A closed client never becomes connected again; stop retrying.
			if !c.alive {
				return errDisconnected
			}
			// Back off instead of spinning while the connection is down.
			time.Sleep(resendDelay)
			continue
		}

		select {
		case confirm := <-c.notifyConfirm:
			if confirm.Ack {
				return nil
			}
		case <-time.After(resendDelay):
		}
	}
}
// UnsafePush publishes data to the push queue without waiting for a server
// confirmation, so delivery is not guaranteed. It returns errDisconnected
// when the client is not connected, otherwise the result of the publish call.
func (c *AMQPClient) UnsafePush(data []byte) error {
	const (
		defaultExchange = ""
		mandatory       = false
		immediate       = false
	)

	if c.isConnected {
		message := amqp.Publishing{
			Body:         data,
			ContentType:  "text/plain",
			DeliveryMode: amqp.Persistent,
		}
		// With the default exchange the routing key is the queue name.
		return c.channel.Publish(defaultExchange, c.pushQueue, mandatory, immediate, message)
	}
	return errDisconnected
}
// Stream consumes the listen queue with c.threads consumers and hands every
// delivery to parseEvent. It blocks until all consumer goroutines have
// stopped, which happens when cancelCtx is cancelled or when their delivery
// channels are closed.
//
// It returns cancelCtx.Err() if the context is cancelled before the client is
// connected, the error from Qos or Consume if setup fails, errDisconnected if
// any delivery channel was closed, and nil otherwise.
func (c *AMQPClient) Stream(cancelCtx context.Context) error {
	// Wait for the background connect to finish, polling once per second,
	// but give up as soon as the caller cancels.
	for !c.isConnected {
		select {
		case <-cancelCtx.Done():
			return cancelCtx.Err()
		case <-time.After(1 * time.Second):
		}
	}

	if err := c.channel.Qos(1, 0, false); err != nil {
		return err
	}

	// One slot per consumer so the goroutines never write to the same
	// variable; the slots are only read after wg.Wait.
	dropped := make([]bool, c.threads)

	for i := 1; i <= c.threads; i++ {
		messages, err := c.channel.Consume(
			c.listenQueue,
			consumerName(i), // Consumer
			false,           // Auto-Ack
			false,           // Exclusive
			false,           // No-local
			false,           // No-Wait
			nil,             // Args
		)
		if err != nil {
			return err
		}

		go func(slot int, deliveries <-chan amqp.Delivery) {
			defer c.wg.Done()
			for {
				select {
				case <-cancelCtx.Done():
					return
				case message, ok := <-deliveries:
					if !ok {
						dropped[slot] = true
						return
					}
					c.parseEvent(message)
				}
			}
		}(i-1, messages)
	}

	c.wg.Wait()

	for _, d := range dropped {
		if d {
			return errDisconnected
		}
	}
	return nil
}
// parseEvent decodes one delivery into a Task and settles it:
//   - a body that is not valid JSON, or a Task with an empty Status, is logged
//     and nacked via logAndNack;
//   - a Status of "running" or "failed" is logged as succeeded and acked;
//   - any other Status is rejected without requeue.
//
// A failing Ack or Reject is only logged. The delivery is not nacked
// afterwards, because a second acknowledgement on the same delivery would
// fail as well and logAndNack would escalate that failure.
func (c *AMQPClient) parseEvent(msg amqp.Delivery) {
	startTime := time.Now()

	// A zerolog event must not be touched again once Msg/Msgf has sent it, so
	// every log line below starts from a fresh event.
	logFailure := func(err error) {
		c.logger.Log().Timestamp().
			Int64("took-ms", time.Since(startTime).Milliseconds()).
			Str("level", "error").
			Msg(err.Error())
	}

	var evt Task
	if err := json.Unmarshal(msg.Body, &evt); err != nil {
		logAndNack(msg, c.logger.Log().Timestamp(), startTime, "unmarshalling body: %s - %s", string(msg.Body), err.Error())
		return
	}

	switch evt.Status {
	case "":
		logAndNack(msg, c.logger.Log().Timestamp(), startTime, "received event without data")
		return
	case "running":
		// Call an actual function
	case "failed":
		// Call in case of fail
	default:
		if err := msg.Reject(false); err != nil {
			logFailure(err)
		}
		return
	}

	c.logger.Log().Timestamp().
		Str("level", "info").
		Int64("took-ms", time.Since(startTime).Milliseconds()).
		Msgf("%s succeeded", evt.Status)

	if err := msg.Ack(false); err != nil {
		logFailure(err)
	}
}
// logAndNack negatively acknowledges msg (single delivery, no requeue) and
// then emits an error-level entry on l with the elapsed time since t.
//
// errorMessage is treated as a format string only when args are supplied, so
// callers can pass arbitrary text such as err.Error() without any '%' in it
// being mangled. A failing Nack no longer panics: the error is attached to
// the log entry under "nack-error" so one broken delivery cannot take the
// whole process down.
func logAndNack(msg amqp.Delivery, l *zerolog.Event, t time.Time, errorMessage string, args ...interface{}) {
	if len(args) > 0 {
		errorMessage = fmt.Sprintf(errorMessage, args...)
	}

	if err := msg.Nack(false, false); err != nil {
		l = l.AnErr("nack-error", err)
	}

	l.Int64("took-ms", time.Since(t).Milliseconds()).Str("level", "error").Msg(errorMessage)
}
// Close shuts the client down. It marks the client as no longer alive (which
// stops the connect loop), and, if a connection exists, cancels every
// consumer, then closes the channel and the connection.
//
// Both the channel and the connection are always closed; the first error
// encountered while doing so is returned. Failures to cancel an individual
// consumer are only logged.
func (c *AMQPClient) Close() error {
	// Clear alive before the early return so a client that never managed to
	// connect stops retrying as well.
	c.alive = false
	if !c.isConnected {
		return nil
	}

	c.logger.Printf("Waiting for current messages to be processed...")

	// Cancel the consumers before closing the channel they live on.
	for i := 1; i <= c.threads; i++ {
		c.logger.Printf("Closing consumer: %d", i)
		if err := c.channel.Cancel(consumerName(i), false); err != nil {
			c.logger.Printf("error canceling consumer %s: %v", consumerName(i), err)
		}
	}

	c.isConnected = false

	// Close the connection even when closing the channel fails, so the
	// underlying connection is never left open.
	channelErr := c.channel.Close()
	connectionErr := c.connection.Close()
	if channelErr != nil {
		return channelErr
	}
	if connectionErr != nil {
		return connectionErr
	}

	c.logger.Printf("gracefully stopped rabbitMQ connection")
	return nil
}
// consumerName returns the consumer tag used for the i-th consumer, e.g.
// "go-consumer-3" for i == 3.
func consumerName(i int) string {
	const prefix = "go-consumer-"
	// fmt.Sprint adds no space here because one operand is a string.
	return fmt.Sprint(prefix, i)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment