Skip to content

Instantly share code, notes, and snippets.

@Integralist
Last active March 27, 2024 12:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Integralist/927f91c34be67499a6a1a430ddaebe92 to your computer and use it in GitHub Desktop.
Save Integralist/927f91c34be67499a6a1a430ddaebe92 to your computer and use it in GitHub Desktop.
[Go basic middleware abstraction] #go #golang #middleware
// THIS IS THE REDESIGNED VERSION (see 3. middleware.go for my original approach)
package middleware
import (
"net/http"
)
// Decorator is a middleware function: it receives the next handler in the
// chain and returns a handler that wraps it.
type Decorator func(http.Handler) http.Handler

// Pipeline is a type for wrapping an http.Handler and adding optional logging,
// metrics and tracing.
type Pipeline struct {
	middleware []Decorator
}

// New returns a middleware Pipeline made up of the given decorators.
// The decorators are copied into the pipeline in the order given.
func New(middleware ...Decorator) *Pipeline {
	p := &Pipeline{}
	p.middleware = append(p.middleware, middleware...)
	return p
}

// Decorate wraps next in the pipeline's middleware and returns the resulting
// handler. The first decorator passed to New ends up outermost, i.e. it is the
// first to see each request.
//
// The chain is built exactly once, here. Building it inside the returned
// handler would re-wrap the already wrapped handler on every request (so each
// middleware would run once more per request served) and would mutate a
// variable shared by concurrent requests.
func (p *Pipeline) Decorate(next http.Handler) http.Handler {
	for i := len(p.middleware) - 1; i >= 0; i-- {
		next = p.middleware[i](next)
	}
	return next
}
// THIS IS THE REDESIGNED VERSION (see 4. main.go for my original approach)
// Package main is the entry point for the application.
package main
import (
"context"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/sync/errgroup"
"github.com/org/repo/internal/config"
h "github.com/org/repo/internal/http"
"github.com/org/repo/internal/logging"
"github.com/org/repo/internal/metrics"
"github.com/org/repo/internal/middleware"
"github.com/org/repo/internal/traces"
)
// Command-line flags. Each variable is a pointer that is populated when main
// calls flag.Parse. app skips the metrics listener when metricsPort is 0.
var (
cfgPath = flag.String("cfg", "", "Path to config")
metricsHost = flag.String("metrics-host", "127.0.0.1", "Host to use for metrics & pprof")
metricsPort = flag.Int("metrics-port", 8447, "Port to expose for metrics & pprof") //nolint:gomnd
)
// ImageTag is the container image version and is displayed in the healthcheck
// response body. It defaults to "dev".
// It is overridden at run/build time (see ../../Makefile).
var ImageTag = "dev"
// main registers and parses the command-line flags, then runs the
// application. A non-nil error from app results in exit status 1.
func main() {
	build.AddVersionFlag()
	flag.Parse()

	// NOTE: The real work lives in `app()` so that its deferred calls get a
	// chance to run before the process exits.
	err := app()
	if err == nil {
		return
	}
	os.Exit(1)
}
// app wires up configuration, observability and both HTTP servers, then
// blocks until the service stops.
//
// It returns an error when the configuration cannot be loaded or is invalid,
// when tracing cannot be initialised, when either server stops for a reason
// other than a graceful shutdown, or when a graceful shutdown fails. A clean
// shutdown triggered by SIGINT/SIGTERM returns nil.
func app() error {
	ctx := context.Background()
	logger := logging.NewLogger()

	// Start-up errors are fatal: carrying on would use a config (or a trace
	// shutdown func) that was never successfully created.
	cfg, err := config.Load(*cfgPath)
	if err != nil {
		err = fmt.Errorf("unable to load config: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "config_load", slog.Any("error", err))
		return err
	}
	if err := cfg.Validate(); err != nil {
		err = fmt.Errorf("invalid config: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "config_validate", slog.Any("error", err))
		return err
	}

	// validate secret is available
	// NOTE(review): this prints the secrets to stdout; confirm it is only
	// meant for local debugging before shipping.
	fmt.Printf("foo secret: %#v\n", cfg.Secrets)

	// setup observability
	registry := prometheus.NewRegistry()
	metric := metrics.New(registry)

	state := "enabled"
	if cfg.OtelDisabled {
		state = "disabled"
	}
	logger.LogAttrs(ctx, slog.LevelInfo, "trace exporting "+state)

	traceShutdown, tracer, err := traces.New(ctx, cfg)
	if err != nil {
		err = fmt.Errorf("unable to start tracing: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "trace create", slog.Any("error", err))
		return err
	}
	defer func() {
		// NOTE: We use a separate context for the shutdown.
		// This is so that the graceful shutdown of the main HTTP server doesn't
		// end up cancelling the traceShutdown context before it has a chance to
		// complete its shutdown process.
		if err := traceShutdown(context.WithoutCancel(ctx)); err != nil {
			logger.LogAttrs(ctx, slog.LevelError, "trace shutdown", slog.Any("error", err))
		}
	}()

	timeout := time.Duration(10) * time.Second //nolint:gomnd

	var metricsAddr string
	if *metricsPort != 0 {
		metricsAddr = fmt.Sprintf("%s:%d", *metricsHost, *metricsPort)
	}

	// Enrich the logger before any goroutine captures it; reassigning it
	// afterwards would race with the goroutines reading it. The logger now
	// carries metrics_addr, so the log calls below do not repeat it.
	logger = logger.With(
		slog.Int("gomaxprocs", runtime.GOMAXPROCS(0)),
		slog.String("build_info", build.Info.String()),
		slog.String("cfg", *cfgPath),
		slog.String("cfg_addr", cfg.Addr),
		slog.String("metrics_addr", metricsAddr),
	)

	// gctx is cancelled as soon as any goroutine in the group returns an
	// error, which lets the shutdown goroutine react to a failed server
	// instead of waiting for a signal that may never come.
	g, gctx := errgroup.WithContext(ctx)

	// Setup separate /metrics endpoint for Prometheus scraping.
	// Needs to be separate listener as metrics service runs on its own port.
	// The server value is created here, not inside the goroutine, so the
	// shutdown goroutine can read it without a data race.
	var metricServer *http.Server
	if *metricsPort != 0 {
		metricsMux := http.NewServeMux()
		metricsMux.Handle("/metrics", metric.Handler())
		metricServer = &http.Server{
			Addr:              metricsAddr,
			Handler:           metricsMux,
			ReadHeaderTimeout: timeout, // avoid slowloris https://app.deepsource.com/directory/analyzers/go/issues/GO-S2114
		}
		g.Go(func() error {
			// ErrServerClosed is what ListenAndServe returns after Shutdown;
			// it signals a graceful stop, not a failure.
			if err := metricServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
				err = fmt.Errorf("metrics server failed to listen and serve requests: %w", err)
				logger.LogAttrs(ctx, slog.LevelError, "error handling metric requests",
					slog.Any("error", err),
				)
				return err
			}
			return nil
		})
	}

	logger.LogAttrs(ctx, slog.LevelInfo, "service starting")

	// Routing
	mux := http.NewServeMux()

	// NOTE: We have separate pipelines.
	// One for the healthcheck and another for all other endpoints.
	// This is because we want to avoid tracing the healthcheck endpoint.
	// For some apps, tracing the healthcheck can result in poor performance.
	pipelineHealth := middleware.New(logging.WithRoutePattern(mux, logger), metric.WithHealthCheckMetrics, h.WithHeaders(cfg))
	healthHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprintf(w, `{"status": "success", "image_tag": "%s"}`, ImageTag)
	})
	mux.Handle("/healthcheck", pipelineHealth.Decorate(healthHandler))

	// NOTE: The order of the middleware is deliberate.
	// WithSpan updates the request context with a Span ID.
	// WithRoutePattern checks for the Span ID and attaches it to the logger.
	// WithMetrics, WithRouteTag and WithParamAttrs all check for the route pattern added by WithRoutePattern.
	pipeline := middleware.New(traces.WithSpan(tracer), logging.WithRoutePattern(mux, logger), metric.WithMetrics, traces.WithRouteTag, traces.WithParamAttrs, h.WithHeaders(cfg))
	exampleHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, `{"endpoint": "example"}`)
	})
	mux.Handle("/example/{id}", pipeline.Decorate(exampleHandler))

	// API Server
	apiServer := &http.Server{
		Addr:              cfg.Addr,
		Handler:           mux,
		ReadHeaderTimeout: timeout, // avoid slowloris https://app.deepsource.com/directory/analyzers/go/issues/GO-S2114
	}
	g.Go(func() error {
		if err := apiServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			err = fmt.Errorf("api server failed to listen and serve requests: %w", err)
			logger.LogAttrs(ctx, slog.LevelError, "error handling API requests",
				slog.Any("error", err),
			)
			return err
		}
		return nil
	})

	// Graceful shutdown
	g.Go(func() error {
		quit := make(chan os.Signal, 1)
		signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
		defer signal.Stop(quit)

		select {
		case signalType := <-quit:
			logger.LogAttrs(ctx, slog.LevelWarn, "http server graceful shutdown",
				slog.String("signal", signalType.String()),
			)
		case <-gctx.Done():
			// A server goroutine returned an error (already logged there).
			// Stop the remaining server so that g.Wait can return.
			logger.LogAttrs(ctx, slog.LevelWarn, "http server shutdown after server failure")
		}

		// Attempt to stop both servers even if the first shutdown fails, and
		// report the first failure.
		var shutdownErr error

		// NOTE: Shutdown is given a separate context from the API server.
		if err := apiServer.Shutdown(context.Background()); err != nil { //nolint:contextcheck
			err = fmt.Errorf("unable to gracefully stop API server: %w", err)
			logger.LogAttrs(ctx, slog.LevelError, "error shutting down API server",
				slog.Any("error", err),
			)
			shutdownErr = err
		}
		if metricServer != nil {
			if err := metricServer.Shutdown(context.Background()); err != nil { //nolint:contextcheck
				err = fmt.Errorf("unable to gracefully stop metrics server: %w", err)
				logger.LogAttrs(ctx, slog.LevelError, "error shutting down metrics server",
					slog.Any("error", err),
				)
				if shutdownErr == nil {
					shutdownErr = err
				}
			}
		}
		return shutdownErr
	})

	return g.Wait()
}
package middleware
import (
"net/http"
)
// Decorator is a middleware function: it takes the next handler and returns
// a handler that wraps it.
type Decorator func(http.Handler) http.Handler

// Pipeline registers handlers on a ServeMux, wrapping each one in a shared
// list of middleware (e.g. logging, metrics and tracing).
type Pipeline struct {
	middleware []Decorator
	mux        *http.ServeMux
}

// New returns a middleware Pipeline that registers its handlers on mux.
func New(mux *http.ServeMux) *Pipeline {
	p := new(Pipeline)
	p.mux = mux
	return p
}

// Use appends one or more middleware to the pipeline. They apply to handlers
// registered by subsequent calls to Handle.
func (p *Pipeline) Use(middleware ...Decorator) {
	for _, mw := range middleware {
		p.middleware = append(p.middleware, mw)
	}
}

// Handle registers handler on the pipeline's mux for the given path.
//
// The handler is wrapped first in the route-specific middleware passed here
// and then in the pipeline-wide middleware added via Use, so a request passes
// through the pipeline-wide middleware first, each list in the order given.
func (p *Pipeline) Handle(path string, handler http.Handler, middleware ...Decorator) {
	// Outermost first: pipeline-wide middleware, then route-specific ones.
	chain := make([]Decorator, 0, len(p.middleware)+len(middleware))
	chain = append(chain, p.middleware...)
	chain = append(chain, middleware...)

	// Wrap from the inside out so chain[0] is the first to see a request.
	wrapped := handler
	for idx := len(chain) - 1; idx >= 0; idx-- {
		wrapped = chain[idx](wrapped)
	}
	p.mux.Handle(path, wrapped)
}
// Package main is the entry point for the application.
package main
import (
"context"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/org/repo/internal/config"
"github.com/org/repo/internal/logging"
"github.com/org/repo/internal/metrics"
"github.com/org/repo/internal/middleware"
"github.com/org/repo/internal/traces"
)
// Command-line flags. Each variable is a pointer that is populated when main
// calls flag.Parse. app skips the metrics listener when metricsPort is 0.
var (
cfgPath = flag.String("cfg", "", "Path to config")
metricsHost = flag.String("metrics-host", "127.0.0.1", "Host to use for metrics & pprof")
metricsPort = flag.Int("metrics-port", 8447, "Port to expose for metrics & pprof") //nolint:gomnd
)
// ImageTag is the container image version and is displayed in the healthcheck
// response body. It defaults to "dev".
// It is overridden at run/build time (see ../../Makefile).
var ImageTag = "dev"
// main registers and parses the command-line flags, then runs the
// application. A non-nil error from app results in exit status 1.
func main() {
	build.AddVersionFlag()
	flag.Parse()

	// NOTE: The real work lives in `app()` so that its deferred calls get a
	// chance to run before the process exits.
	err := app()
	if err == nil {
		return
	}
	os.Exit(1)
}
// app wires up configuration, observability and both HTTP servers, then
// blocks until a shutdown signal arrives or one of the servers fails.
//
// It returns an error when the configuration cannot be loaded or is invalid,
// when tracing cannot be initialised, or when either server stops for a
// reason other than a graceful shutdown. Errors from the graceful shutdown
// itself are logged but not returned.
func app() error {
	ctx := context.Background()
	logger := logging.NewLogger()

	// Start-up errors are fatal: carrying on would use a config (or a trace
	// shutdown func) that was never successfully created.
	cfg, err := config.Load(*cfgPath)
	if err != nil {
		err = fmt.Errorf("unable to load config: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "config_load", slog.Any("error", err))
		return err
	}
	if err := cfg.Validate(); err != nil {
		err = fmt.Errorf("invalid config: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "config_validate", slog.Any("error", err))
		return err
	}

	// validate secret is available
	// NOTE(review): this prints the secrets to stdout; confirm it is only
	// meant for local debugging before shipping.
	fmt.Printf("foo secret: %#v\n", cfg.Secrets)

	// setup observability
	registry := prometheus.NewRegistry()
	metric := metrics.New(registry)

	state := "enabled"
	if cfg.OtelDisabled {
		state = "disabled"
	}
	logger.LogAttrs(ctx, slog.LevelInfo, "trace exporting "+state)

	traceShutdown, tracer, err := traces.New(ctx, cfg)
	if err != nil {
		err = fmt.Errorf("unable to start tracing: %w", err)
		logger.LogAttrs(ctx, slog.LevelError, "trace create", slog.Any("error", err))
		return err
	}
	defer func() {
		// NOTE: We use a separate context for the shutdown.
		// This is so that the graceful shutdown of the main HTTP server doesn't
		// end up cancelling the traceShutdown context before it has a chance to
		// complete its shutdown process.
		if err := traceShutdown(context.WithoutCancel(ctx)); err != nil {
			logger.LogAttrs(ctx, slog.LevelError, "trace shutdown", slog.Any("error", err))
		}
	}()

	// Buffered for both servers so that neither goroutine can block on send
	// when nobody is receiving any more.
	serverError := make(chan error, 2) //nolint:gomnd
	timeout := time.Duration(10) * time.Second //nolint:gomnd

	var metricsAddr string
	if *metricsPort != 0 {
		metricsAddr = fmt.Sprintf("%s:%d", *metricsHost, *metricsPort)
	}

	// Enrich the logger before any goroutine captures it; reassigning it
	// afterwards would race with the goroutines reading it. The logger now
	// carries metrics_addr, so the log calls below do not repeat it.
	logger = logger.With(
		slog.Int("gomaxprocs", runtime.GOMAXPROCS(0)),
		slog.String("build_info", build.Info.String()),
		slog.String("cfg", *cfgPath),
		slog.String("cfg_addr", cfg.Addr),
		slog.String("metrics_addr", metricsAddr),
	)

	// Setup separate /metrics endpoint for Prometheus scraping.
	// Needs to be separate listener as metrics service runs on its own port.
	var metricsServer *http.Server
	if *metricsPort != 0 {
		metricsMux := http.NewServeMux()
		metricsMux.Handle("/metrics", metric.Handler())
		metricsServer = &http.Server{
			Addr:              metricsAddr,
			Handler:           metricsMux,
			ReadHeaderTimeout: timeout, // avoid slowloris https://app.deepsource.com/directory/analyzers/go/issues/GO-S2114
		}
		go func() {
			// ErrServerClosed is what ListenAndServe returns after Shutdown;
			// it signals a graceful stop, not a failure.
			if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
				logger.LogAttrs(ctx, slog.LevelError, "error handling metric requests",
					slog.Any("error", err),
				)
				serverError <- err
			}
		}()

		// Give the listener a moment to fail fast (e.g. address in use).
		select {
		case err := <-serverError:
			return err
		case <-time.After(1 * time.Second):
			// no-op: allow logic to flow to setting up routing/http server
		}
	}

	logger.LogAttrs(ctx, slog.LevelInfo, "service starting")

	// Routing
	mux := http.NewServeMux()

	// NOTE: We have separate pipelines.
	// One for the healthcheck and another for all other endpoints.
	// This is because we want to avoid tracing the healthcheck endpoint.
	// For some apps, tracing the healthcheck can result in poor performance.
	pipelineHealth := middleware.New(mux)
	pipelineHealth.Use(logging.WithRoutePattern(mux, logger), metric.WithHealthCheckMetrics)
	pipelineHealth.Handle("/healthcheck", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(fmt.Sprintf(`{"status": "success", "image_tag": "%s"}`, ImageTag)))
	}))

	// NOTE: The order of the middleware is deliberate.
	// WithSpan updates the request context with a Span ID.
	// WithRoutePattern checks for the Span ID and attaches it to the logger.
	// WithMetrics, WithRouteTag and WithParamAttrs all check for the route pattern added by WithRoutePattern.
	pipeline := middleware.New(mux)
	pipeline.Use(traces.WithSpan(tracer), logging.WithRoutePattern(mux, logger), metric.WithMetrics, traces.WithRouteTag, traces.WithParamAttrs)
	pipeline.Handle("/example/{id}", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"endpoint": "example"}`))
	}))

	// HTTP Server
	server := &http.Server{
		Addr:              cfg.Addr,
		Handler:           mux,
		ReadHeaderTimeout: timeout, // avoid slowloris https://app.deepsource.com/directory/analyzers/go/issues/GO-S2114
	}
	go func() {
		if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			logger.LogAttrs(ctx, slog.LevelError, "error handling server requests",
				slog.Any("error", err),
			)
			serverError <- err
		}
	}()

	select {
	case err := <-serverError:
		return err
	case <-time.After(1 * time.Second):
		// no-op: allow logic to flow to setting up graceful shutdown
	}

	// Graceful shutdown
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(quit)

	// Wait for a signal, but also keep watching the servers: a server that
	// dies after start-up must stop the service rather than go unnoticed.
	var runErr error
	select {
	case signalType := <-quit:
		logger.LogAttrs(ctx, slog.LevelWarn, "http server graceful shutdown",
			slog.String("signal", signalType.String()),
		)
	case runErr = <-serverError:
		// Already logged by the goroutine that reported it.
	}

	// NOTE: Shutdown is given a separate context from the API server.
	if err := server.Shutdown(context.Background()); err != nil {
		logger.LogAttrs(ctx, slog.LevelError, "error shutting down server",
			slog.Any("error", fmt.Errorf("unable to gracefully stop server: %w", err)),
		)
	}
	if metricsServer != nil {
		if err := metricsServer.Shutdown(context.Background()); err != nil {
			logger.LogAttrs(ctx, slog.LevelError, "error shutting down metrics server",
				slog.Any("error", fmt.Errorf("unable to gracefully stop metrics server: %w", err)),
			)
		}
	}
	return runErr
}
package logging
import (
"context"
"io"
"log"
"log/slog"
"net/http"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
h "github.com/org/repo/internal/http"
"github.com/org/repo/internal/traces"
)
// Level allows dynamically changing the output level via .Set() method.
// Defaults to [slog.LevelInfo]. It is the level installed by defaultOptions
// and the one NewLogger passes to NewLoggerWithOutputLevel.
var Level = new(slog.LevelVar)
// routePattern matches the content inside curly brackets and captures it as
// submatch 1 (for example "id" in "/example/{id}").
var routePattern = regexp.MustCompile(`\{([^{}]+)\}`)
// Logger describes the set of features we want to expose from log/slog.
//
// NOTE: Don't confuse our custom With() signature with (*slog.Logger).With
// We return a Logger type where the standard library returns a *slog.Logger
type Logger interface {
// Enabled mirrors (*slog.Logger).Enabled; this package's implementation
// delegates to it.
Enabled(ctx context.Context, level slog.Level) bool
// LogAttrs mirrors (*slog.Logger).LogAttrs; this package's implementation
// delegates to it.
LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr)
// With returns a Logger carrying the given additional arguments.
With(args ...any) Logger
// _private is unexported so that no type outside this package can
// implement Logger.
_private()
}
// NewLogger returns a [Logger] that writes to stdout (os.Stdout) using the
// package-level [Level].
func NewLogger() Logger {
return NewLoggerWithOutputLevel(os.Stdout, Level)
}
// NewLoggerWithOutput returns a [Logger] that writes JSON records to w using
// the package's default handler options and default attributes.
func NewLoggerWithOutput(w io.Writer) Logger {
	handler := slog.NewJSONHandler(w, defaultOptions())
	base := slog.New(handler.WithAttrs(defaultAttrs()))
	return (*logger)(base)
}
// NewLoggerWithOutputLevel returns a [Logger] that writes JSON records to w.
// It uses the package's default options with the level replaced by l, plus
// the default attributes.
func NewLoggerWithOutputLevel(w io.Writer, l slog.Leveler) Logger {
	options := defaultOptions()
	options.Level = l

	handler := slog.NewJSONHandler(w, options).WithAttrs(defaultAttrs())
	return (*logger)(slog.New(handler))
}
// NewBareLoggerWithOutputLevel returns a [Logger] that writes JSON records to
// w at the given [slog.Leveler]. Unlike NewLoggerWithOutputLevel it does not
// attach the default attributes.
func NewBareLoggerWithOutputLevel(w io.Writer, l slog.Leveler) Logger {
	options := defaultOptions()
	options.Level = l

	handler := slog.NewJSONHandler(w, options)
	return (*logger)(slog.New(handler))
}
// nolint:revive
//
// private is an empty marker type. NOTE(review): nothing in the visible code
// references it (the Logger interface relies on the _private method instead),
// so it looks removable; confirm before deleting.
//
//lint:ignore U1000 Prevents any other package from implementing this interface
type private struct{} //nolint:unused
// IMPORTANT: logger is a defined type whose underlying type is slog.Logger
// (rather than a struct holding a *slog.Logger) to avoid a double-pointer
// dereference. Every method converts the *logger receiver to a *slog.Logger
// before delegating, and With() converts the result back so that it satisfies
// Logger.
type logger slog.Logger

// _private marks *logger as an implementation of Logger.
func (*logger) _private() {}

// Enabled delegates to (*slog.Logger).Enabled.
func (l *logger) Enabled(ctx context.Context, level slog.Level) bool {
	std := (*slog.Logger)(l)
	return std.Enabled(ctx, level)
}

// LogAttrs delegates to (*slog.Logger).LogAttrs.
func (l *logger) LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) {
	std := (*slog.Logger)(l)
	std.LogAttrs(ctx, level, msg, attrs...)
}

// With delegates to (*slog.Logger).With and returns the derived logger as a
// Logger.
func (l *logger) With(args ...any) Logger {
	derived := (*slog.Logger)(l).With(args...)
	return (*logger)(derived)
}
// Adapt returns a [log.Logger] for use with packages that are not yet
// compatible with [log/slog]. The returned logger writes through l's handler
// at the given level.
func Adapt(l Logger, level slog.Level) *log.Logger {
	// Logger can only be implemented inside this package (see _private), and
	// *logger is the implementation defined here, so the assertion is not
	// expected to panic for loggers created by this package.
	impl := l.(*logger) //nolint:revive,forcetypeassert
	handler := (*slog.Logger)(impl).Handler()
	return slog.NewLogLogger(handler, level)
}
func defaultOptions() *slog.HandlerOptions {
return &slog.HandlerOptions{
AddSource: true,
ReplaceAttr: slogReplaceAttr,
Level: Level,
}
}
// defaultAttrs returns the attributes attached to every non-bare logger: an
// "app" group holding the project name and version from build.Info.
func defaultAttrs() []slog.Attr {
	app := slog.Group("app",
		slog.String("name", build.Info.Project),
		slog.String("version", build.Info.Version),
	)
	return []slog.Attr{app}
}
// NullLogger returns a Logger that discards everything logged to it.
func NullLogger() Logger {
	// NOTE: The level is set above every level defined by log/slog so that
	// the logger can skip building a message that would only be discarded
	// (passing nil options would not achieve that).
	// A gap of 4 is used as it aligns with Go's original design.
	// https://github.com/golang/go/blob/1e95fc7/src/log/slog/level.go#L34-L42
	opts := &slog.HandlerOptions{Level: slog.LevelError + 4} //nolint:gomnd
	handler := slog.NewTextHandler(io.Discard, opts)
	return (*logger)(slog.New(handler))
}
// slogReplaceAttr adjusts the log output.
//
// For top-level keys only (not keys within groups):
//   - Changes the default time field value to the UTC time zone
//   - Replaces the msg key with event
//   - Omits the event field if empty
//   - Renames the source key to caller and truncates the source's filename
//     to the path below the example-app directory
//
// For keys at any depth:
//   - Omits the error field when it is nil
//   - Formats duration and delay values in microseconds as xxxxµs
//
// The time and duration rewrites only apply when the attribute really holds
// a time or duration; an attribute that merely shares one of the special keys
// (e.g. a string logged as "delay") is passed through unchanged instead of
// panicking inside slog's typed accessors.
//
// See https://pkg.go.dev/log/slog#HandlerOptions.ReplaceAttr
// N.B: TextHandler manages quoting attribute values as necessary.
func slogReplaceAttr(groups []string, a slog.Attr) slog.Attr {
	// Limit application of these rules only to top-level keys
	if len(groups) == 0 {
		switch a.Key {
		case slog.TimeKey:
			// Set time zone to UTC
			if a.Value.Kind() == slog.KindTime {
				a.Value = slog.TimeValue(a.Value.Time().UTC())
			}
			return a

		case slog.MessageKey:
			// Use event as the default MessageKey, remove if empty
			a.Key = "event"
			if a.Value.String() == "" {
				return slog.Attr{}
			}
			return a

		case slog.SourceKey:
			// Display a 'partial' path.
			// Avoids ambiguity when multiple files have the same name across packages.
			// e.g. billing.go appears under 'global', 'billing' and 'server' packages.
			if source, ok := a.Value.Any().(*slog.Source); ok {
				a.Key = "caller"
				if _, after, ok := strings.Cut(source.File, "example-app"+string(filepath.Separator)); ok {
					source.File = after
				}
			}
			return a
		}
	}

	// Remove error key=value when error is nil
	if a.Equal(slog.Any("error", error(nil))) {
		return slog.Attr{}
	}

	// Present durations and delays as xxxxµs
	switch a.Key {
	case "dur", "delay", "p95", "previous_p95", "remaining", "max_wait":
		if a.Value.Kind() == slog.KindDuration {
			a.Value = slog.StringValue(strconv.FormatInt(a.Value.Duration().Microseconds(), 10) + "µs")
		}
	}
	return a
}
// WithRoutePattern returns middleware that stores the matched route pattern
// in the request context, logs a "handler_called" event for the request and
// then calls next with the updated context.
//
// The log record carries a "request" group with the route pattern and one
// attribute per path parameter named in the pattern. When the request context
// holds a sampled trace, its trace ID is added as "trace_id".
func WithRoutePattern(mux *http.ServeMux, logger Logger) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := h.ContextWithRoutePattern(mux, r)
			route, _ := h.RoutePatternFromContext(ctx) //nolint:contextcheck

			attrs := []any{
				slog.String("pattern", route),
			}
			for _, seg := range routePattern.FindAllStringSubmatch(route, -1) {
				if len(seg) == 2 { //nolint:gomnd
					key := seg[1]
					attrs = append(attrs, slog.String(key, r.PathValue(key)))
				}
			}

			// Derive a request-scoped logger instead of reassigning the
			// captured one: the captured logger is shared by every request,
			// so reassigning it would race between concurrent requests and
			// pile up the trace IDs of earlier requests on later log records.
			reqLogger := logger
			if traceID := traces.GetTraceID(ctx); traceID != "" { //nolint:contextcheck
				reqLogger = logger.With(slog.String("trace_id", traceID))
			}

			reqLogger.LogAttrs(r.Context(), slog.LevelInfo, "handler_called", slog.Group("request", attrs...))
			next.ServeHTTP(w, r.WithContext(ctx)) //nolint:contextcheck
		})
	}
}
package http
import (
"context"
"net/http"
)
// routeContextKey is the unexported type of the key used to store a route
// pattern in a context.Context.
type routeContextKey struct{}

// RoutePatternContextKey is the key under which the route pattern is stored
// in a context.Context.
var RoutePatternContextKey = routeContextKey{}

// ContextWithRoutePattern looks up the pattern mux matches for r and returns
// a copy of r's context that carries it.
func ContextWithRoutePattern(mux *http.ServeMux, r *http.Request) context.Context {
	_, pattern := mux.Handler(r)
	parent := r.Context()
	return context.WithValue(parent, RoutePatternContextKey, pattern)
}

// RoutePatternFromContext returns the route pattern stored in ctx. The bool
// is false (and the string empty) when ctx holds no string value under
// RoutePatternContextKey.
func RoutePatternFromContext(ctx context.Context) (string, bool) {
	pattern, found := ctx.Value(RoutePatternContextKey).(string)
	return pattern, found
}
package metrics
import (
"context"
"net/http"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
h "github.com/org/repo/internal/http"
"github.com/org/repo/internal/traces"
)
// namespace and subsystem are the Prometheus name components used for the
// metrics this package creates.
const (
namespace = "fastly"
subsystem = "domainr_api"
)
// New creates the service's Prometheus metrics and registers them on
// registry: build info, uptime, the request metrics, and the process and Go
// runtime collectors.
func New(registry *prometheus.Registry) *Metrics {
	m := new(Metrics)
	m.registry = registry

	registerAndSetBuildInfo(registry)
	registerUptime(registry)
	m.registerRequestMetrics()

	processCollector := collectors.NewProcessCollector(collectors.ProcessCollectorOpts{Namespace: namespace})
	registry.MustRegister(processCollector)
	goCollector := collectors.NewGoCollector()
	registry.MustRegister(goCollector)

	return m
}
// Metrics contains the Prometheus registry and the request metrics tracked
// by the service. The metric fields are populated by registerRequestMetrics.
type Metrics struct {
registry *prometheus.Registry
healthChecksTotal *prometheus.CounterVec
requestsTotal *prometheus.CounterVec
requestDuration *prometheus.HistogramVec
requestSize *prometheus.SummaryVec
responseSize *prometheus.SummaryVec
}
// Handler returns the HTTP handler that exposes the metrics in m's registry,
// with OpenMetrics enabled.
func (m *Metrics) Handler() http.Handler {
	opts := promhttp.HandlerOpts{EnableOpenMetrics: true}
	return promhttp.HandlerFor(m.registry, opts)
}
// RegisterCounter creates a counter with the given name and help text in the
// package namespace and subsystem, registers it on m's registry and returns
// it. Registration goes through MustRegister.
func (m *Metrics) RegisterCounter(name, help string) prometheus.Counter {
	opts := prometheus.CounterOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      name,
		Help:      help,
	}
	counter := prometheus.NewCounter(opts)
	m.registry.MustRegister(counter)
	return counter
}
// WithMetrics is middleware that records the request count, duration,
// request size and response size for next, labelled with the route pattern
// found in the request context.
//
// The "route" label is always curried, using the empty string when the
// context carries no route pattern. Leaving it uncurried would hand promhttp
// a metric partitioned by a label it does not support, which makes the
// instrumentation helpers panic.
func (m *Metrics) WithMetrics(next http.Handler) http.Handler {
	exemplar := promhttp.WithExemplarFromContext(getTraceID)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		route, _ := h.RoutePatternFromContext(r.Context())
		label := prometheus.Labels{"route": route}

		// Wraps the provided http.Handler to observe the request result with the provided metrics.
		promhttp.InstrumentHandlerCounter(
			m.requestsTotal.MustCurryWith(label),
			promhttp.InstrumentHandlerDuration(
				m.requestDuration.MustCurryWith(label),
				promhttp.InstrumentHandlerRequestSize(
					m.requestSize.MustCurryWith(label),
					promhttp.InstrumentHandlerResponseSize(
						m.responseSize.MustCurryWith(label),
						next,
					),
				),
				exemplar,
			),
			exemplar,
		).ServeHTTP(w, r)
	})
}
// WithHealthCheckMetrics is middleware that counts health check requests
// handled by next, labelled with the route pattern found in the request
// context.
//
// The "route" label is always curried, using the empty string when the
// context carries no route pattern. Leaving it uncurried would hand promhttp
// a metric partitioned by a label it does not support, which makes the
// instrumentation helper panic.
func (m *Metrics) WithHealthCheckMetrics(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		route, _ := h.RoutePatternFromContext(r.Context())
		label := prometheus.Labels{"route": route}

		// Wraps the provided http.Handler to observe the request result with the provided metrics.
		promhttp.InstrumentHandlerCounter(
			m.healthChecksTotal.MustCurryWith(label),
			next,
		).ServeHTTP(w, r)
	})
}
// registerRequestMetrics creates the health check and request metrics on m's
// registry and stores them on m. Every metric is partitioned by the "code"
// and "route" labels.
func (m *Metrics) registerRequestMetrics() {
	factory := promauto.With(m.registry)

	healthOpts := prometheus.CounterOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "health_checks",
		Help:      "Number of health check requests received",
	}
	m.healthChecksTotal = factory.NewCounterVec(healthOpts, []string{"code", "route"})

	requestsOpts := prometheus.CounterOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "requests_total",
		Help:      "Tracks the number of HTTP requests.",
	}
	m.requestsTotal = factory.NewCounterVec(requestsOpts, []string{"code", "route"})

	durationOpts := prometheus.HistogramOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "request_duration_seconds",
		Help:      "Tracks the latencies for HTTP requests.",
		Buckets:   prometheus.DefBuckets,
	}
	m.requestDuration = factory.NewHistogramVec(durationOpts, []string{"code", "route"})

	requestSizeOpts := prometheus.SummaryOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "request_size_bytes",
		Help:      "Tracks the size of HTTP requests.",
	}
	m.requestSize = factory.NewSummaryVec(requestSizeOpts, []string{"code", "route"})

	responseSizeOpts := prometheus.SummaryOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "response_size_bytes",
		Help:      "Tracks the size of HTTP responses.",
	}
	m.responseSize = factory.NewSummaryVec(responseSizeOpts, []string{"code", "route"})
}
// getTraceID returns a "traceID" label for use as an exemplar when
// traces.GetTraceID yields a trace ID for ctx, and nil otherwise.
func getTraceID(ctx context.Context) prometheus.Labels {
	traceID := traces.GetTraceID(ctx)
	if traceID == "" {
		return nil
	}
	return prometheus.Labels{"traceID": traceID}
}
// registerAndSetBuildInfo registers a build_info gauge on registry and sets
// it to 1 for the label values taken from build.Info (version plus build
// number, git revision and Go version).
func registerAndSetBuildInfo(registry *prometheus.Registry) {
	info := build.Info

	opts := prometheus.GaugeOpts{
		Namespace: namespace,
		Name:      "build_info",
		Help:      "Build information for domainr-api.",
	}
	labelNames := []string{"version", "revision", "goversion"}
	gauge := prometheus.NewGaugeVec(opts, labelNames)
	registry.MustRegister(gauge)

	fullVersion := info.Version + "-" + info.BuildNumber
	gauge.WithLabelValues(fullVersion, info.GitRevision, info.GoVersion).Set(1)
}
// registerUptime registers a gauge on registry that reports the number of
// seconds elapsed since this function was called.
func registerUptime(registry *prometheus.Registry) {
	// Captured once; the gauge function measures against it on every read.
	startedAt := time.Now()
	elapsedSeconds := func() float64 {
		return time.Since(startedAt).Seconds()
	}

	opts := prometheus.GaugeOpts{
		Namespace: namespace,
		Name:      "uptime_duration_seconds",
		Help:      "Process uptime",
	}
	registry.MustRegister(prometheus.NewGaugeFunc(opts, elapsedSeconds))
}
package traces
import (
"context"
"fmt"
"net/http"
"regexp"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.24.0"
"go.opentelemetry.io/otel/trace"
"github.com/org/repo/internal/config"
h "github.com/org/repo/internal/http"
)
// serviceName is used as the service name of the trace resource and as the
// tracer name; version is passed as the tracer's instrumentation version.
const (
serviceName = "example-app"
version = "0.1.0"
)
// routePattern matches the content inside curly brackets and captures it as
// submatch 1 (for example "id" in "/example/{id}").
var routePattern = regexp.MustCompile(`\{([^{}]+)\}`)
// newResource returns the resource attached to exported traces: the service
// name plus a service version built from build.Info (version and build
// number).
func newResource() *resource.Resource {
	info := build.Info
	serviceVersion := info.Version + "-" + info.BuildNumber

	return resource.NewWithAttributes(
		semconv.SchemaURL,
		semconv.ServiceName(serviceName),
		semconv.ServiceVersionKey.String(serviceVersion),
	)
}
// New creates a tracer provider, installs it (and a TraceContext + Baggage
// propagator) as the OpenTelemetry globals and returns the provider's
// shutdown function together with a tracer named after the service.
//
// When cfg.OtelDisabled is set the provider never samples; otherwise spans
// are batched to an OTLP/HTTP exporter. An error is returned (with nil for
// the other results) when the exporter cannot be created.
func New(ctx context.Context, cfg *config.Config) (func(context.Context) error, trace.Tracer, error) {
	var providerOpts []sdktrace.TracerProviderOption
	if cfg.OtelDisabled {
		providerOpts = []sdktrace.TracerProviderOption{
			sdktrace.WithResource(newResource()),
			sdktrace.WithSampler(sdktrace.NeverSample()),
		}
	} else {
		exporter, err := otlptracehttp.New(ctx)
		if err != nil {
			return nil, nil, fmt.Errorf("creating OTLP trace exporter: %w", err)
		}
		providerOpts = []sdktrace.TracerProviderOption{
			sdktrace.WithBatcher(exporter),
			sdktrace.WithResource(newResource()),
		}
	}

	provider := sdktrace.NewTracerProvider(providerOpts...)
	otel.SetTracerProvider(provider)
	tracer := provider.Tracer(serviceName, trace.WithInstrumentationVersion(version))

	propagator := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})
	otel.SetTextMapPropagator(propagator)

	return provider.Shutdown, tracer, nil
}
// NewHandler wraps handler with otelhttp instrumentation under the operation
// name "server", with read/write message events enabled and requests
// filtered through otelReqFilter.
func NewHandler(handler http.Handler) http.Handler {
	opts := []otelhttp.Option{
		otelhttp.WithMessageEvents(otelhttp.ReadEvents, otelhttp.WriteEvents),
		otelhttp.WithFilter(otelReqFilter),
	}
	return otelhttp.NewHandler(handler, "server", opts...)
}
// GetTraceID returns the trace ID of the span in ctx when that span is
// sampled and carries a valid trace ID. It returns "" otherwise.
//
// Validity is checked with HasTraceID rather than by comparing the rendered
// string with "": TraceID.String always yields a non-empty hex string, even
// for the all-zero (invalid) trace ID.
func GetTraceID(ctx context.Context) string {
	sc := trace.SpanFromContext(ctx).SpanContext()
	if sc.IsSampled() && sc.HasTraceID() {
		return sc.TraceID().String()
	}
	return ""
}
// Tracer returns a tracer named after the service from the globally
// registered tracer provider.
func Tracer() trace.Tracer {
	provider := otel.GetTracerProvider()
	return provider.Tracer(serviceName)
}
// Start creates a span of kind internal with the given name, using Tracer,
// and returns it together with a context.Context that contains it.
func Start(ctx context.Context, name string) (context.Context, trace.Span) {
	kind := trace.WithSpanKind(trace.SpanKindInternal)
	return Tracer().Start(ctx, name, kind)
}
// otelReqFilter reports whether req should pass the otelhttp filter: every
// request does, except those whose path is exactly "/healthcheck".
func otelReqFilter(req *http.Request) bool {
	isHealthcheck := req.URL.Path == "/healthcheck"
	return !isHealthcheck
}
// WithSpan returns middleware that starts a span named "handler_called" with
// tracer for every request, passes the span's context on to next and ends the
// span once next returns.
func WithSpan(tracer trace.Tracer) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			spanCtx, span := tracer.Start(r.Context(), "handler_called")
			defer span.End()

			next.ServeHTTP(w, r.WithContext(spanCtx))
		}
		return http.HandlerFunc(serve)
	}
}
// WithRouteTag is middleware that reads the route pattern from the request
// context and hands the request to otelhttp.WithRouteTag so the route is
// recorded as the HTTP route attribute. Before doing so it runs the global
// text map propagator's Extract over the request headers.
//
// NOTE(review): in the pipelines visible in main this runs after WithSpan has
// already started the request span; confirm that extracting the propagated
// context at this point (rather than before the span is started) is intended.
func WithRouteTag(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		reqCtx := r.Context()
		route, _ := h.RoutePatternFromContext(reqCtx) //nolint:contextcheck

		carrier := propagation.HeaderCarrier(r.Header)
		extracted := otel.GetTextMapPropagator().Extract(reqCtx, carrier)

		otelhttp.WithRouteTag(route, next).ServeHTTP(w, r.WithContext(extracted))
	}
	return http.HandlerFunc(serve)
}
// WithParamAttrs is middleware that sets one attribute on the span found in
// the request context for each parameter named in the route pattern, using
// the parameter's name as key and its path value as value.
func WithParamAttrs(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		reqCtx := r.Context()
		route, _ := h.RoutePatternFromContext(reqCtx) //nolint:contextcheck
		span := trace.SpanFromContext(reqCtx)

		for _, match := range routePattern.FindAllStringSubmatch(route, -1) {
			// match[0] is the full "{name}" text, match[1] the name itself.
			if len(match) != 2 { //nolint:gomnd
				continue
			}
			name := match[1]
			span.SetAttributes(attribute.String(name, r.PathValue(name)))
		}

		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(serve)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment