Created
January 29, 2019 15:06
-
-
Save nevivurn/d8ad11964611b4f3f819597324bd50d7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"context" | |
"crypto/md5" | |
"database/sql" | |
"errors" | |
"flag" | |
"fmt" | |
"log" | |
"net/http" | |
"strings" | |
"sync" | |
"time" | |
_ "github.com/lib/pq" | |
"github.com/NYTimes/gziphandler" | |
) | |
// Command-line flags.
var (
	// gid is the group ID written into every generated passwd entry.
	gid = flag.Int("gid", 0, "gid to set all users as")

	// laddr is the address the HTTP server listens on.
	laddr = flag.String("laddr", "127.0.0.1:8080", "listening address")

	// dbconn is the connection string passed to sql.Open.
	dbconn = flag.String("db", "user=id password=secret dbname=id sslmode=disable", "database connection string")
)

// errUnauthorizedHost is the sentinel returned by checkHost when the
// requesting host has no row in the hosts table.
var errUnauthorizedHost = errors.New("unauthorized host")
// cache holds the most recently generated passwd payload together with the
// metadata served alongside it.
//
// update stores new values while holding the embedded write lock; the HTTP
// handler reads passwd and modified while holding the read lock.
type cache struct {
	sync.RWMutex

	// passwd is the generated passwd-format payload, one line per user.
	passwd []byte

	// check is the MD5 sum of passwd, compared by update to detect changes.
	check [md5.Size]byte

	// modified is the formatted time at which passwd last changed; it is
	// sent as Last-Modified and compared against If-Modified-Since.
	modified string
}
func (c *cache) update(db *sql.DB) error { | |
start := time.Now() | |
row, err := db.Query("SELECT username, uid, name, shell FROM users ORDER BY uid") | |
if err != nil { | |
return err | |
} | |
defer row.Close() | |
type user struct { | |
username string | |
uid int | |
name string | |
home string | |
shell string | |
} | |
var users []user | |
for row.Next() { | |
var u user | |
if err := row.Scan(&u.username, &u.uid, &u.name, &u.shell); err != nil { | |
return err | |
} | |
u.home = "/csehome/" + u.username | |
users = append(users, u) | |
} | |
if err := row.Err(); err != nil { | |
return err | |
} | |
var buf bytes.Buffer | |
for _, user := range users { | |
fmt.Fprintf(&buf, "%s:x:%d:%d:%s:%s:%s\n", | |
user.username, user.uid, *gid, user.name, user.home, user.shell) | |
} | |
passwd := buf.Bytes() | |
check := md5.Sum(passwd) | |
if check == c.check { | |
return nil | |
} | |
c.Lock() | |
c.passwd = passwd | |
c.check = check | |
c.modified = time.Now().UTC().Format(time.RFC1123) | |
c.Unlock() | |
log.Printf("change detected: %d entries generated in %v", len(users), time.Since(start)) | |
return nil | |
} | |
func checkHost(ctx context.Context, db *sql.DB, host string) error { | |
row := db.QueryRowContext(ctx, "SELECT EXISTS (SELECT 1 FROM hosts WHERE host = $1)", host) | |
var ok bool | |
if err := row.Scan(&ok); err != nil { | |
return err | |
} | |
if !ok { | |
return errUnauthorizedHost | |
} | |
return nil | |
} | |
func main() { | |
flag.Parse() | |
db, err := sql.Open("postgres", *dbconn) | |
if err != nil { | |
log.Fatalln("db:", err) | |
} | |
defer db.Close() | |
var ch cache | |
if err := ch.update(db); err != nil { | |
log.Fatalln("update:", err) | |
} | |
go func() { | |
tick := time.Tick(15 * time.Minute) | |
for { | |
<-tick | |
if err := ch.update(db); err != nil { | |
log.Println("update:", err) | |
} | |
} | |
}() | |
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
if r.Method != http.MethodGet { | |
w.WriteHeader(http.StatusMethodNotAllowed) | |
return | |
} | |
host := r.RemoteAddr[:strings.LastIndexByte(r.RemoteAddr, ':')] | |
if proxy := r.Header.Get("X-Forwarded-For"); (host == "127.0.0.1" || host == "::1") && proxy != "" { | |
if ind := strings.IndexByte(proxy, ','); ind != -1 { | |
proxy = proxy[:ind] | |
} | |
host = proxy | |
} | |
if err := checkHost(r.Context(), db, host); err != nil { | |
if err == errUnauthorizedHost { | |
w.WriteHeader(401) | |
log.Println("unauthorized:", host) | |
} else { | |
w.WriteHeader(500) | |
log.Println("host:", err) | |
} | |
return | |
} | |
ch.RLock() | |
defer ch.RUnlock() | |
// TODO: It may be possible cache to between restarts | |
// - Store last-modified somewhere | |
// - Get last-modified from the DB, somehow | |
// - encode checksum as time, send as last-modified - horrible hack | |
// - modify nsscache to use ETag | |
if r.Header.Get("If-Modified-Since") == ch.modified { | |
w.WriteHeader(http.StatusNotModified) | |
return | |
} | |
w.Header().Add("Last-Modified", ch.modified) | |
w.Write(ch.passwd) | |
}) | |
http.Handle("/api/get-passwd", gziphandler.GzipHandler(handler)) | |
// TODO: /api/get-groups | |
log.Println("listen:", http.ListenAndServe(*laddr, nil)) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment