// Set up the following environment variables and then run with the path to the 'deploy' dir as the first and only argument.
//
// export IBEXFARM_DATABASE_URL=postgresql://postgres:password@localhost:6432/nmibex_dev
// export IBEXFARM_QUOTA_ID=1
// export IBEXFARM_S3_HOST=http://localhost:9000
// export IBEXFARM_S3_BUCKET=nmibex-dev
// export IBEXFARM_NO_S3_SSL=true
// export IBEXFARM_USE_MINIO=true
// export AWS_ACCESS_KEY_ID=key
// export AWS_SECRET_ACCESS_KEY=password
// export AWS_REGION=eu-west-2
// export IBEXFARM_HTPASSWD_FILE=./example_deploy_dir/htpasswd
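//
// Example invocation (assuming the program is built as 'migrate'):
//
//   ./migrate ./example_deploy_dir
//
// The program walks each user directory under the deploy dir, uploads
// experiment files and results to S3, and inserts the corresponding rows
// into the new Ibex Farm Postgres schema.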
package main
import (
"bufio"
"bytes"
"context"
"crypto/sha1"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/golang/snappy"
"github.com/jackc/pgx/v4/pgxpool"
)
const legacyIbexFilesS3Prefix = "ibex_legacy_files/"
const rawResultsS3Prefix = "raw-results/"
const experimentFilesPrefix = "experimentfiles/"
const ibexClassicVersion = "classic"
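// s3Config holds the S3 (or MinIO) connection settings read from the environment.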
type s3Config struct {
key string
secret string
host string
bucket string
region string
disableSSL bool
forcePath bool
}
func main() {
if len(os.Args) != 2 {
log.Fatal("usage: pass the path to the 'deploy' dir as the first and only argument")
}
deployDir := os.Args[1]
databaseUrl := os.Getenv("IBEXFARM_DATABASE_URL")
htpasswdFile := os.Getenv("IBEXFARM_HTPASSWD_FILE")
fileQuotaID, err := strconv.ParseInt(os.Getenv("IBEXFARM_QUOTA_ID"), 10, 64)
if err != nil {
log.Fatal(err)
}
s3Config := &s3Config{
key: os.Getenv("AWS_ACCESS_KEY_ID"),
secret: os.Getenv("AWS_SECRET_ACCESS_KEY"),
host: os.Getenv("IBEXFARM_S3_HOST"),
bucket: os.Getenv("IBEXFARM_S3_BUCKET"),
region: os.Getenv("AWS_REGION"),
disableSSL: os.Getenv("IBEXFARM_NO_S3_SSL") == "true",
forcePath: os.Getenv("IBEXFARM_USE_MINIO") == "true",
}
// hashToFileID caches files-table row ids by content hash; filePathToHash
// records the hex SHA-1 hash of every file uploaded to S3, keyed by path.
hashToFileID := make(map[string]int)
filePathToHash := make(map[string]string)
if err := uploadFilesToS3(filePathToHash, deployDir, s3Config); err != nil {
log.Fatalf("Error uploading files to S3: %v", err)
}
users, err := ioutil.ReadDir(deployDir)
if err != nil {
fmt.Errorf("Error reading deployment dir.\n")
log.Fatal(err)
}
passwords, err := parseHtpasswd(htpasswdFile)
if err != nil {
fmt.Errorf("Error opening htpasswd: %v\n", htpasswdFile)
log.Fatal(err)
}
for _, userInfo := range users {
if !userInfo.IsDir() {
continue
}
log.Printf("Inserting data for user %v\n", userInfo.Name())
userBlobPath := path.Join(deployDir, userInfo.Name(), "USER")
userBlob, err := ioutil.ReadFile(userBlobPath)
if err != nil {
log.Printf("Error opening USER blob %v", userBlobPath)
continue
}
var userRecord ibexUser
if err := json.Unmarshal(userBlob, &userRecord); err != nil {
log.Printf("Error parsing USER blob for %v: %v\n", userInfo.Name(), err)
continue
}
if userRecord.Password == "" || userRecord.Username == "" {
fmt.Errorf("Incomplete user record %+v\n", userRecord)
continue
}
userBlobInfo, err := os.Stat(userBlobPath)
if err != nil {
fmt.Errorf("Error stating USER file %v\n", err)
log.Fatal(err)
}
log.Printf("Inserting user %v %+v...\n", userInfo.Name(), userRecord)
userID, err := insertUser(databaseUrl, int(fileQuotaID), userRecord, userBlobInfo.ModTime(), userInfo.ModTime())
if err != nil {
fmt.Errorf("Error inserting user %v\n", userRecord.Username)
log.Fatal(err)
}
log.Printf("...inserted with id %v\n", userID)
expsPath := path.Join(deployDir, userInfo.Name())
experiments, err := ioutil.ReadDir(expsPath)
if err != nil {
fmt.Errorf("Error reading experiment dir %v\n", expsPath)
log.Fatal(err)
}
// Scratch buffer reused by importRawResults for snappy encoding; 1 GiB should
// be big enough for most sets of results, avoiding repeated allocations
// within the nested loop below.
compressionBuffer := make([]byte, 1024*1024*1024)
for _, experimentInfo := range experiments {
if !experimentInfo.IsDir() {
continue
}
expPath := path.Join(expsPath, experimentInfo.Name())
passwordHash := passwords[userRecord.Username+"/"+experimentInfo.Name()]
log.Printf("Inserting experiment %v\n", expPath)
experimentID, err := insertExperiment(databaseUrl, userID, expPath, s3Config, filePathToHash, hashToFileID, passwordHash, experimentInfo.ModTime())
if err != nil {
log.Fatalf("Error inserting experiment %v: %v", experimentInfo.Name(), err)
}
log.Printf("...experiment inserted with id %v\n", experimentID)
rawResultsPath := path.Join(expPath, "ibex-deploy", "results", "raw_results")
_, err = os.Stat(rawResultsPath)
if os.IsNotExist(err) {
continue
}
if err != nil {
fmt.Errorf("Error opening raw results %v\n", rawResultsPath)
log.Fatal(err)
}
blobs, err := parseRawResults(userInfo.Name(), experimentInfo.Name(), rawResultsPath)
if err != nil {
fmt.Errorf("Error parsing raw results\n")
log.Fatal(err)
}
log.Printf("Importing %v sets of results...\n", len(blobs))
err = importRawResults(databaseUrl, experimentID, s3Config, blobs, compressionBuffer)
if err != nil {
fmt.Errorf("Error importing raw results\n")
log.Fatal(err)
}
log.Printf("...imported\n")
log.Printf("Importing legacy results file...\n")
resultsPath := path.Join(expPath, "ibex-deploy", "results", "results")
_, err = os.Stat(resultsPath)
if os.IsNotExist(err) {
continue
}
if err != nil {
fmt.Errorf("Error opening results %v\n", resultsPath)
log.Fatal(err)
}
err = importResults(databaseUrl, experimentID, s3Config, resultsPath)
if err != nil {
fmt.Errorf("Error importing results %v\n", resultsPath)
log.Fatal(err)
}
}
}
// The new schema stores untrusted_sha1_hash base64-encoded, but the legacy
// rows inserted above carry 40-character hex digests, so convert those.
conn, err := pgxpool.Connect(context.Background(), databaseUrl)
if err != nil {
log.Fatal(err)
}
_, err = conn.Exec(
context.Background(),
`UPDATE files
SET untrusted_sha1_hash =
CASE
WHEN octet_length(untrusted_sha1_hash) = 40 THEN
encode(decode(untrusted_sha1_hash, 'hex'), 'base64')
ELSE
untrusted_sha1_hash
END
WHERE s3_key LIKE 'ibex_legacy_files/%'
`,
)
if err != nil {
log.Fatal(err)
}
}
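// newS3Session builds an AWS session from the migration's S3 config,
// optionally targeting a custom endpoint (e.g. a local MinIO server).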
func newS3Session(config *s3Config) *session.Session {
awsConfig :=
&aws.Config{
Credentials: credentials.NewStaticCredentials(config.key, config.secret, ""),
Region: aws.String(config.region),
DisableSSL: aws.Bool(config.disableSSL),
S3ForcePathStyle: aws.Bool(config.forcePath),
}
if config.host != "" {
awsConfig.Endpoint = aws.String(config.host)
}
return session.Must(session.NewSession(awsConfig))
}
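// uploadFilesToS3 walks the deploy tree and uploads every file found in a
// js_includes, chunk_includes, data_includes, css_includes, other_includes
// or www directory to S3 under its hex SHA-1 hash, skipping hashes that have
// already been uploaded. table is filled with a mapping from file path to hash.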
func uploadFilesToS3(table map[string]string, root string, s3Config *s3Config) error {
sess := newS3Session(s3Config)
existingHashes := make(map[string]bool)
return filepath.Walk(root, func(filePath string, info os.FileInfo, err error) error {
// Check the walk error before touching info, which may be nil on error.
if err != nil {
return err
}
if info.IsDir() {
return nil
}
dir := path.Dir(filePath)
if !(strings.HasSuffix(dir, "/js_includes") || strings.HasSuffix(dir, "/chunk_includes") || strings.HasSuffix(dir, "/data_includes") || strings.HasSuffix(dir, "/css_includes") || strings.HasSuffix(dir, "/other_includes") || strings.HasSuffix(dir, "/www")) {
return nil
}
contents, err := ioutil.ReadFile(filePath)
if err != nil {
return err
}
// Name the object after the hex SHA-1 of its contents so that identical
// files are stored only once.
sum := sha1.Sum(contents)
hash := hex.EncodeToString(sum[:])
table[filePath] = hash
if _, ok := existingHashes[hash]; !ok {
log.Printf("Uploading %v...\n", filePath)
uploader := s3manager.NewUploader(sess)
tagging := "StorageClass=legacy-forever-expfile"
_, err := uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(s3Config.bucket),
Key: aws.String(legacyIbexFilesS3Prefix + hash),
Body: bytes.NewReader(contents),
Tagging: &tagging,
})
if err != nil {
return err
}
existingHashes[hash] = true
log.Printf("...uploaded\n")
}
return nil
})
}
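// legacyResultsTimeFormat is the Go reference-time layout matching the
// timestamp lines in legacy raw_results files.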
const legacyResultsTimeFormat = "Monday January 02 2006 15:04:05 MST"
type ibexUser struct {
Password string `json:"password"`
Email string `json:"email_address"`
Username string `json:"username"`
}
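// resultsBlob is one set of results parsed from a raw_results file.
// counterPlusOne is the design counter plus one, so zero means that no
// counter line was seen for this blob.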
type resultsBlob struct {
time time.Time
userAgent string
counterPlusOne int
json string
}
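// insertUser inserts a legacy user into the users table, storing a NULL
// email when the legacy record has none, and returns the new row's id.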
func insertUser(databaseUrl string, fileQuotaID int, user ibexUser, insertedAt time.Time, updatedAt time.Time) (int, error) {
conn, err := pgxpool.Connect(context.Background(), databaseUrl)
if err != nil {
return -1, err
}
defer conn.Close()
var id int
err = conn.QueryRow(
context.Background(),
`INSERT INTO users (email, username, password_hash, inserted_at, updated_at, file_quota_id, file_count, byte_count)
VALUES (CASE WHEN $1 = '' THEN NULL ELSE $1 END, $2, $3, $4, $5, $6, $7, $8)
RETURNING id`,
user.Email,
user.Username,
user.Password,
insertedAt,
updatedAt,
fileQuotaID,
0,
0,
).Scan(&id)
if err != nil {
return -1, err
}
return id, nil
}
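// insertExperiment inserts one legacy experiment and its files. Files are
// deduplicated by content hash: an existing files row is reused when one
// matches the hash-derived S3 key, and hashToFileID caches ids across calls.
// It returns the new experiment's id.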
func insertExperiment(databaseUrl string, userID int, experimentPath string, s3Config *s3Config, filePathToHash map[string]string, hashToFileID map[string]int, passwordHash string, modtime time.Time) (int, error) {
config, err := pgxpool.ParseConfig(databaseUrl)
if err != nil {
return -1, err
}
config.MaxConns = 2
conn, err := pgxpool.ConnectConfig(context.Background(), config)
if err != nil {
return -1, err
}
defer conn.Close()
name := path.Base(experimentPath)
versionPath := path.Join(experimentPath, "ibex-deploy", "VERSION")
counterPath := path.Join(experimentPath, "ibex-deploy", "server_state", "counter")
versionBytes, verr := ioutil.ReadFile(versionPath)
counterBytes, cerr := ioutil.ReadFile(counterPath)
version := "unknown"
counter := 0
if verr == nil {
version = strings.TrimSpace(string(versionBytes))
} else {
fmt.Errorf("Error opening %v\n", versionPath)
}
if cerr == nil {
c, err := strconv.ParseInt(strings.TrimSpace(string(counterBytes)), 10, 64)
if err == nil {
counter = int(c)
}
} else {
log.Printf("Error opening %v: %v", counterPath, cerr)
}
var dbPasswordHash interface{}
if passwordHash != "" {
dbPasswordHash = passwordHash
}
var experimentID int
err = conn.QueryRow(
context.Background(),
`INSERT INTO experiments_including_soft_deleted (user_id, name, ibex_version, legacy_ibex_version, public, counter, inserted_at, updated_at, password_hash)
VALUES ($1, $2, $3, $4, FALSE, $5, $6, $7, $8)
RETURNING id`,
userID, // user_id
name, // name
ibexClassicVersion, // ibex_version
version, // legacy_ibex_version
counter, // counter
modtime, // inserted_at
modtime, // updated_at
dbPasswordHash, // password_hash
).Scan(&experimentID)
if err != nil {
return -1, err
}
for _, dir := range []string{"js_includes", "chunk_includes", "data_includes", "css_includes", "other_includes", "www"} {
dpath := path.Join(experimentPath, "ibex-deploy", dir)
files, err := ioutil.ReadDir(dpath)
if err != nil {
fmt.Errorf("Non-fatal error opening %v\n", dpath)
continue
}
for _, fileInfo := range files {
if fileInfo.IsDir() {
continue
}
filePath := path.Join(dpath, fileInfo.Name())
hash, ok := filePathToHash[filePath]
if !ok {
log.Fatalf("Unexpectedly failed to find hash for %v\n", filePath)
}
log.Printf("Inserting file %v\n", filePath)
fileID, ok := hashToFileID[hash]
if true && !ok {
err := conn.QueryRow(context.Background(), "SELECT COALESCE(MAX(id), 0) FROM files WHERE s3_bucket = $1 AND s3_key = $2 AND NOT writeable", s3Config.bucket, legacyIbexFilesS3Prefix+hash).Scan(&fileID)
if err != nil {
return -1, err
}
if fileID == 0 {
err = conn.QueryRow(
context.Background(),
`INSERT INTO files (mime_type, s3_bucket, s3_key, size, writeable, untrusted_sha1_hash, inserted_at, updated_at)
VALUES (
$1, $2, $3, $4, FALSE, $5, $6, $7
) RETURNING id`,
mimeTypeOf(fileInfo.Name()),
s3Config.bucket,
legacyIbexFilesS3Prefix+hash,
fileInfo.Size(),
hash,
fileInfo.ModTime(),
fileInfo.ModTime(),
).Scan(&fileID)
if err != nil {
return -1, err
}
log.Printf("...inserted with id %v\n", fileID)
} else {
log.Printf("...already existed with id %v\n", fileID)
}
hashToFileID[hash] = fileID
} else {
log.Printf("...already exists with id %v\n", fileID)
}
_, err := conn.Exec(
context.Background(),
`INSERT INTO experiment_files(experiment_id, directory, name, file_id, inserted_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6)`,
experimentID,
dir,
fileInfo.Name(),
fileID,
fileInfo.ModTime(),
fileInfo.ModTime(),
)
if err != nil {
return -1, err
}
}
}
return experimentID, nil
}
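// importResults snappy-compresses the legacy concatenated results file,
// uploads it to S3, and registers it in the files and experiment_files tables.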
func importResults(databaseUrl string, experimentID int, s3Config *s3Config, resultsFileName string) error {
conn, err := pgxpool.Connect(context.Background(), databaseUrl)
if err != nil {
return err
}
defer conn.Close()
resultsFileInfo, err := os.Stat(resultsFileName)
if err != nil {
return err
}
resultsBytes, err := ioutil.ReadFile(resultsFileName)
if err != nil {
return err
}
encoded := snappy.Encode(nil, resultsBytes)
sess := newS3Session(s3Config)
uploader := s3manager.NewUploader(sess)
key := fmt.Sprintf("%v%v/%v", experimentFilesPrefix, experimentID, "legacy_results")
tagging := "StorageClass=results"
_, err = uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(s3Config.bucket),
Key: aws.String(key),
Body: bytes.NewReader(encoded),
Tagging: &tagging,
})
if err != nil {
return err
}
var fileID int
err = conn.QueryRow(
context.Background(),
`INSERT INTO files (mime_type, s3_bucket, s3_key, size, writeable, untrusted_sha1_hash, inserted_at, updated_at)
VALUES (
'application/octet-stream', $1, $2, $3, FALSE, NULL, $4, $5
) RETURNING id`,
s3Config.bucket,
key,
len(encoded),
resultsFileInfo.ModTime(),
resultsFileInfo.ModTime(),
).Scan(&fileID)
if err != nil {
return err
}
_, err = conn.Exec(
context.Background(),
`INSERT INTO experiment_files(experiment_id, directory, name, file_id, inserted_at, updated_at)
VALUES ($1, 'results', 'results', $2, $3, $4)`,
experimentID,
fileID,
resultsFileInfo.ModTime(),
resultsFileInfo.ModTime(),
)
if err != nil {
return err
}
return nil
}
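// importRawResults uploads each parsed results blob to S3 as a snappy-
// compressed object and inserts a matching experiment_results row.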
func importRawResults(databaseUrl string, experimentID int, s3Config *s3Config, blobs []resultsBlob, compressionBuffer []byte) error {
conn, err := pgxpool.Connect(context.Background(), databaseUrl)
if err != nil {
return err
}
defer conn.Close()
sess := newS3Session(s3Config)
uploader := s3manager.NewUploader(sess)
for i, b := range blobs {
// Using a negative number so that it won't conflict with the positive numbers coming from the sequence.
key := fmt.Sprintf("%v%v/%v", rawResultsS3Prefix, experimentID, -i)
// Snappy-compress the JSON (reusing compressionBuffer across iterations),
// then shift the payload right one byte and prepend an 's' marker byte,
// presumably identifying the encoding to the consuming system.
encoded := snappy.Encode(compressionBuffer, []byte(b.json))
encoded = append(encoded, 0)
copy(encoded[1:], encoded)
encoded[0] = 's'
tagging := "StorageClass=results"
_, err := uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(s3Config.bucket),
Key: aws.String(key),
Body: bytes.NewReader(encoded),
Tagging: &tagging,
})
if err != nil {
return err
}
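// A counterPlusOne of zero means no counter line was parsed; store NULL.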
var counter interface{} = b.counterPlusOne - 1
if b.counterPlusOne == 0 {
counter = nil
}
_, err = conn.Exec(
context.Background(),
`INSERT INTO experiment_results (experiment_id, s3_bucket, s3_key, notified_at, received_at, legacy, counter, user_agent, inserted_at, updated_at, old_ibex)
VALUES ($1, $2, $3, $4, $5, TRUE, $6, $7, $8, $9, TRUE)`,
experimentID,
s3Config.bucket,
key,
b.time,
b.time,
counter,
b.userAgent,
b.time,
b.time,
)
if err != nil {
return err
}
}
return nil
}
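// parseHtpasswd reads an Apache htpasswd file whose entries take the form
// "username/experiment:$apr1$...", returning a map from "username/experiment"
// to the apr1 password hash. Malformed lines are skipped with a warning.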
func parseHtpasswd(htpasswdPath string) (map[string]string, error) {
entries := make(map[string]string)
file, err := os.Open(htpasswdPath)
if err != nil {
return entries, err
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || line[0] == '#' {
continue
}
components := strings.Split(line, ":")
if len(components) != 2 {
fmt.Errorf("Error parsing htpasswd line [1]: %v\n", line)
continue
}
if !strings.HasPrefix(components[1], "$apr1$") {
fmt.Errorf("Error parsing htpasswd line [2]: %v\n", line)
continue
}
unamexpname := strings.Split(components[0], "/")
if len(unamexpname) != 2 {
fmt.Errorf("Error parsing htpasswd line [3]: %v\n", line)
continue
}
entries[components[0]] = components[1]
}
return entries, nil
}
const (
resultsOnPrefix = "# Results on "
userAgentPrefix = "# USER AGENT: "
nonRandomCounterPrefix = "# Design number was non-random = "
randomCounterPrefix = "# Design number was random = "
)
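// parseRawResults parses a legacy raw_results file. Each set of results is a
// run of '#' comment lines (timestamp, user agent, design counter) followed
// by a single line of JSON; blobs without a parseable timestamp are dropped.
// The user and experiment arguments are currently unused.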
func parseRawResults(user, experiment, rawResultsPath string) ([]resultsBlob, error) {
var results []resultsBlob
file, err := os.Open(rawResultsPath)
if err != nil {
return results, err
}
defer file.Close()
scanner := bufio.NewScanner(file)
var currentBlob resultsBlob
for scanner.Scan() {
line := scanner.Text()
trimmed := strings.TrimSpace(line)
if trimmed == "" || trimmed == "#" {
continue
}
if line[0] != '#' {
if !currentBlob.time.IsZero() {
currentBlob.json = line
results = append(results, currentBlob)
}
currentBlob = resultsBlob{}
} else {
if strings.HasPrefix(line, resultsOnPrefix) {
timeString := strings.Replace(line, resultsOnPrefix, "", 1)
timeString = strings.TrimSuffix(timeString, ".") // remove the trailing '.'
if t, err := time.Parse(legacyResultsTimeFormat, timeString); err == nil {
currentBlob.time = t
} else {
log.Printf("Error parsing time %v: %v", timeString, err)
}
} else if strings.HasPrefix(line, userAgentPrefix) {
currentBlob.userAgent = strings.Replace(line, userAgentPrefix, "", 1)
} else if strings.HasPrefix(line, nonRandomCounterPrefix) || strings.HasPrefix(line, randomCounterPrefix) {
// The random and non-random counter lines are handled identically; at
// most one prefix can match, so the double TrimPrefix is safe.
stringCounterValue := strings.TrimSpace(strings.TrimPrefix(strings.TrimPrefix(line, nonRandomCounterPrefix), randomCounterPrefix))
i, err := strconv.ParseInt(stringCounterValue, 10, 64)
if err == nil {
currentBlob.counterPlusOne = int(i + 1)
} else {
log.Printf("Error parsing counter value %v: %v", stringCounterValue, err)
}
}
}
}
return results, nil
}
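// mimeTypeOf guesses a MIME type from the file extension, falling back to
// application/octet-stream.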
func mimeTypeOf(fileName string) string {
switch path.Ext(fileName) {
case ".js":
return "application/javascript"
case ".html":
return "text/html"
case ".css":
return "text/css"
case ".swf":
return "application/x-shockwave-flash"
case ".txt":
return "text/plain"
case ".mp3":
return "audio/mpeg"
case ".wav":
return "audio/wav"
default:
return "application/octet-stream"
}
}