Skip to content

Instantly share code, notes, and snippets.

@nevivurn
Created January 29, 2019 15:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nevivurn/d8ad11964611b4f3f819597324bd50d7 to your computer and use it in GitHub Desktop.
Save nevivurn/d8ad11964611b4f3f819597324bd50d7 to your computer and use it in GitHub Desktop.
package main
import (
	"bytes"
	"context"
	"crypto/md5"
	"database/sql"
	"errors"
	"flag"
	"fmt"
	"log"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/NYTimes/gziphandler"
	_ "github.com/lib/pq"
)
// Command-line flags; they are populated by flag.Parse in main.
var (
	// laddr is the address the HTTP server listens on.
	laddr = flag.String("laddr", "127.0.0.1:8080", "listening address")

	// dbconn is the connection string passed to sql.Open("postgres", ...).
	dbconn = flag.String("db", "user=id password=secret dbname=id sslmode=disable", "database connection string")

	// gid is written into the GID field of every generated passwd entry.
	gid = flag.Int("gid", 0, "gid to set all users as")
)

// errUnauthorizedHost is the sentinel checkHost returns when the requesting
// host has no row in the hosts table.
var errUnauthorizedHost = errors.New("unauthorized host")
// cache holds the most recently generated passwd payload. update assigns
// the fields under the embedded RWMutex's write lock, and the HTTP handler
// reads passwd and modified under its read lock.
type cache struct {
	sync.RWMutex

	// modified is the value sent as the Last-Modified response header and
	// compared verbatim against a client's If-Modified-Since header.
	modified string

	// passwd is the passwd(5)-formatted response body.
	passwd []byte

	// check is the MD5 digest of passwd; update uses it to detect changes.
	check [md5.Size]byte
}
// update rebuilds the passwd(5) payload from the users table. If the content
// differs from what c currently holds, it publishes the new payload together
// with a fresh Last-Modified timestamp.
//
// Each row becomes one line:
//
//	username:x:uid:gid:name:/csehome/username:shell
//
// where gid is the value of the -gid flag, shared by every user. Some rows
// have a username, name or shell containing ':' or a line break. Written
// verbatim, such a row would shift the colon-separated fields or inject
// additional passwd entries (for example one with uid 0). These rows are
// logged and skipped instead.
//
// The MD5 checksum is used only for change detection. modified therefore
// changes only when the payload does, so a client's If-Modified-Since stays
// valid across refreshes that change nothing. The checksum comparison and
// the field swap both happen under c's write lock.
//
// modified is formatted with http.TimeFormat, because HTTP requires
// Last-Modified to be an IMF-fixdate ending in "GMT". Formatting with
// time.RFC1123 would end in "UTC" instead.
//
// Any error from querying or scanning the users table is returned, and c is
// left unchanged in that case.
func (c *cache) update(db *sql.DB) error {
	start := time.Now()

	rows, err := db.Query("SELECT username, uid, name, shell FROM users ORDER BY uid")
	if err != nil {
		return err
	}
	defer rows.Close()

	var buf bytes.Buffer
	entries := 0
	for rows.Next() {
		var username, name, shell string
		var uid int
		if err := rows.Scan(&username, &uid, &name, &shell); err != nil {
			return err
		}
		// ':' separates passwd fields and '\n' separates entries; either one
		// inside a value would corrupt the file (or forge extra users).
		if strings.ContainsAny(username+name+shell, ":\r\n") {
			log.Printf("skipping user %q: field contains ':' or a line break", username)
			continue
		}
		home := "/csehome/" + username
		fmt.Fprintf(&buf, "%s:x:%d:%d:%s:%s:%s\n",
			username, uid, *gid, name, home, shell)
		entries++
	}
	if err := rows.Err(); err != nil {
		return err
	}

	// buf is local, so the handler may keep using the old slice safely after
	// it is replaced here.
	passwd := buf.Bytes()
	check := md5.Sum(passwd)

	c.Lock()
	changed := check != c.check
	if changed {
		c.passwd = passwd
		c.check = check
		c.modified = time.Now().UTC().Format(http.TimeFormat)
	}
	c.Unlock()

	if changed {
		log.Printf("change detected: %d entries generated in %v", entries, time.Since(start))
	}
	return nil
}
// checkHost decides whether host may fetch the passwd database. It returns
// nil when the hosts table has a row whose host column equals host, and
// errUnauthorizedHost when it has none. If the lookup itself fails, it
// returns the database error. The query is issued with QueryRowContext, so
// it observes ctx's cancellation.
func checkHost(ctx context.Context, db *sql.DB, host string) error {
	const query = "SELECT EXISTS (SELECT 1 FROM hosts WHERE host = $1)"

	var allowed bool
	err := db.QueryRowContext(ctx, query, host).Scan(&allowed)
	switch {
	case err != nil:
		return err
	case !allowed:
		return errUnauthorizedHost
	default:
		return nil
	}
}
func main() {
flag.Parse()
db, err := sql.Open("postgres", *dbconn)
if err != nil {
log.Fatalln("db:", err)
}
defer db.Close()
var ch cache
if err := ch.update(db); err != nil {
log.Fatalln("update:", err)
}
go func() {
tick := time.Tick(15 * time.Minute)
for {
<-tick
if err := ch.update(db); err != nil {
log.Println("update:", err)
}
}
}()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
host := r.RemoteAddr[:strings.LastIndexByte(r.RemoteAddr, ':')]
if proxy := r.Header.Get("X-Forwarded-For"); (host == "127.0.0.1" || host == "::1") && proxy != "" {
if ind := strings.IndexByte(proxy, ','); ind != -1 {
proxy = proxy[:ind]
}
host = proxy
}
if err := checkHost(r.Context(), db, host); err != nil {
if err == errUnauthorizedHost {
w.WriteHeader(401)
log.Println("unauthorized:", host)
} else {
w.WriteHeader(500)
log.Println("host:", err)
}
return
}
ch.RLock()
defer ch.RUnlock()
// TODO: It may be possible cache to between restarts
// - Store last-modified somewhere
// - Get last-modified from the DB, somehow
// - encode checksum as time, send as last-modified - horrible hack
// - modify nsscache to use ETag
if r.Header.Get("If-Modified-Since") == ch.modified {
w.WriteHeader(http.StatusNotModified)
return
}
w.Header().Add("Last-Modified", ch.modified)
w.Write(ch.passwd)
})
http.Handle("/api/get-passwd", gziphandler.GzipHandler(handler))
// TODO: /api/get-groups
log.Println("listen:", http.ListenAndServe(*laddr, nil))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment