Skip to content

Instantly share code, notes, and snippets.

@mavimo
Created July 24, 2018 07:03
Show Gist options
  • Save mavimo/749d793d0ad1f568ac2e9186cc2791ef to your computer and use it in GitHub Desktop.
Save mavimo/749d793d0ad1f568ac2e9186cc2791ef to your computer and use it in GitHub Desktop.
docker-scaler
package dockerscaler
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net/http"
"sync/atomic"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/swarm"
docker "github.com/docker/docker/client"
)
// svcNotFound is returned by retrieveService when no service matches the
// requested name. It must be a var, not a const: errors.New is a function
// call, and Go constants can only hold compile-time constant values.
var svcNotFound = errors.New("service not found")
// DockerScaler is the service used to scale a service in a docker env.
// It wraps a Docker Engine API client and should be obtained through Create;
// the zero value carries no client and is not usable.
type DockerScaler struct {
	// dockerClient performs the service list/update calls against the daemon.
	dockerClient *docker.Client
}
// Create is the factory method for DockerScaler service.
//
// socket is the daemon address handed to the Docker client and version the
// API version to request. When key, ca and cert are all empty a plain HTTP
// client is used; otherwise they name the TLS key, CA and certificate files.
func Create(socket, version, key, ca, cert string) (*DockerScaler, error) {
	hc, err := createClient(key, ca, cert)
	if err != nil {
		return nil, fmt.Errorf("unable to create the http client: %v", err)
	}

	// No extra HTTP headers are sent, but the client expects a non-nil map.
	headers := make(map[string]string)
	cli, err := docker.NewClient(socket, version, hc, headers)
	if err != nil {
		return nil, fmt.Errorf("unable to create the docker client: %v", err)
	}

	scaler := &DockerScaler{dockerClient: cli}
	return scaler, nil
}
// ScaleService adds i replicas to the current replica count of the named
// replicated service and returns the warnings reported by the daemon.
//
// Services that are not in replicated mode (e.g. global services) cannot be
// scaled this way and produce an error instead of a nil-pointer panic.
func (ds DockerScaler) ScaleService(name string, i uint64) ([]string, error) {
	svc, err := ds.retrieveService(name)
	if err != nil {
		return nil, fmt.Errorf("unable to retrieve service: %v", err)
	}

	replicated := svc.Spec.Mode.Replicated
	if replicated == nil {
		return nil, fmt.Errorf("service %q is not in replicated mode", name)
	}
	if replicated.Replicas == nil {
		return nil, fmt.Errorf("service %q has no replica count", name)
	}

	// Compute the new count into a fresh variable instead of mutating the
	// fetched value in place; svc is a local copy, so no atomics are needed.
	newReplicas := *replicated.Replicas + i
	replicated.Replicas = &newReplicas

	// Sending the version we read makes the daemon reject the update if the
	// service changed in between (optimistic concurrency).
	res, err := ds.dockerClient.ServiceUpdate(
		context.Background(),
		svc.ID,
		swarm.Version{
			Index: svc.Version.Index,
		},
		svc.Spec,
		types.ServiceUpdateOptions{},
	)
	if err != nil {
		return nil, fmt.Errorf("unable to update service: %v", err)
	}
	return res.Warnings, nil
}
// createClient returns the HTTP client used to talk to the Docker daemon.
//
// When none of key, ca and cert is set, a plain HTTP client is returned.
// When all three are set, a TLS client is built from those files. A partial
// TLS configuration is rejected up front, so the error names the missing
// settings instead of surfacing later as an unrelated file-read failure.
func createClient(key, ca, cert string) (*http.Client, error) {
	if ca == "" && key == "" && cert == "" {
		return initHTTPClient()
	}

	var missing []string
	if key == "" {
		missing = append(missing, "key")
	}
	if cert == "" {
		missing = append(missing, "cert")
	}
	if ca == "" {
		missing = append(missing, "ca")
	}
	if len(missing) > 0 {
		return nil, fmt.Errorf("incomplete TLS configuration, missing: %v", missing)
	}

	return initHTTPSClient(key, cert, ca)
}
// initHTTPClient returns the plain (non-TLS) HTTP client used for the Docker
// daemon, with a 1 second request timeout. It never returns an error; the
// error result keeps its signature parallel to initHTTPSClient.
//
// The transport is set explicitly rather than left nil. NOTE(review): the
// docker client's NewClient type-asserts the supplied client's Transport and,
// in the versions this was written against, rejects a nil one ("invalid
// transport"). Confirm against the vendored client version.
func initHTTPClient() (*http.Client, error) {
	return &http.Client{
		Transport: &http.Transport{},
		Timeout:   1 * time.Second,
	}, nil
}
// initHTTPSClient builds an HTTP client for the Docker daemon with a 1 second
// request timeout. The client presents the certificate/key pair from cert
// and key, and trusts only the PEM-encoded CA certificate(s) in the ca file.
//
// An error is returned in three cases:
//   - the key pair cannot be loaded;
//   - the CA file cannot be read;
//   - the CA file contains no usable PEM certificate. An empty trust pool
//     would otherwise make every handshake fail later with a far less
//     obvious "unknown authority" error.
func initHTTPSClient(key, cert, ca string) (*http.Client, error) {
	c, err := tls.LoadX509KeyPair(cert, key)
	if err != nil {
		return nil, fmt.Errorf("unable to load X509 key pair: %v", err)
	}

	caCert, err := ioutil.ReadFile(ca)
	if err != nil {
		return nil, fmt.Errorf("unable to load file %s CA file: %v", ca, err)
	}
	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM(caCert) {
		return nil, fmt.Errorf("no valid PEM certificate found in CA file %s", ca)
	}

	// BuildNameToCertificate is deliberately not called: it only affects
	// server-side certificate selection and is deprecated.
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{c},
		RootCAs:      caCertPool,
	}
	return &http.Client{
		Transport: &http.Transport{TLSClientConfig: tlsConfig},
		Timeout:   1 * time.Second,
	}, nil
}
// retrieveService returns the swarm service whose name is exactly svc, or
// svcNotFound when there is none.
//
// The daemon's "name" filter also matches services whose name merely starts
// with svc (e.g. "web" matches "webapp"). The results are therefore scanned
// for an exact name match rather than trusting the first entry.
func (ds DockerScaler) retrieveService(svc string) (*swarm.Service, error) {
	args := filters.NewArgs()
	args.Add("name", svc)

	serviceList, err := ds.dockerClient.ServiceList(
		context.Background(),
		types.ServiceListOptions{Filters: args},
	)
	if err != nil {
		return nil, fmt.Errorf("unable to get service list: %v", err)
	}

	// Index into the slice so the returned pointer refers to the element
	// itself rather than to a loop-variable copy.
	for i := range serviceList {
		if serviceList[i].Spec.Name == svc {
			return &serviceList[i], nil
		}
	}
	return nil, svcNotFound
}
package main
import (
"flag"
"log"
"os"
"github.com/mavimo/749d793d0ad1f568ac2e9186cc2791ef/dockerscaler"
)
// Command-line flags. Each flag defaults to the value of the environment
// variable with the same name, so the tool can be configured either way.
var (
	cert   = flag.String("cert", os.Getenv("cert"), "Certificate file (if Docker is using TLS)")
	key    = flag.String("key", os.Getenv("key"), "Private key (if Docker is using TLS)")
	ca     = flag.String("ca", os.Getenv("ca"), "Certificate Authority (if Docker is using TLS)")
	svc    = flag.String("svc", os.Getenv("svc"), "Service name to scale")
	socket = flag.String("socket", os.Getenv("socket"), "Docker Daemon Socket (UNIX or TCP)")
)
func main() {
flag.Parse()
ds, err := dockerscaler.Create(*socket, "", *key, *ca, *cert)
if err != nil {
panic(err)
}
warnings, err := ds.ScaleService(*svc, 1)
if err != nil {
panic(err)
}
for _, warning := range warnings {
log.Println("warning scaling %s service to %d instances: %s", svc, 1, warning)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment