Skip to content

Instantly share code, notes, and snippets.

@lnsp
Last active May 30, 2022 19:31
Show Gist options
  • Save lnsp/65c6d3a43a682438c10fad68ea96079b to your computer and use it in GitHub Desktop.
Save lnsp/65c6d3a43a682438c10fad68ea96079b to your computer and use it in GitHub Desktop.
Snowflake profiling script
package main
import (
"bufio"
"context"
"database/sql"
"database/sql/driver"
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
"time"
"unsafe"
"golang.org/x/exp/rand"
"github.com/snowflakedb/gosnowflake"
"golang.org/x/sync/semaphore"
"gonum.org/v1/gonum/stat/distuv"
)
// Command-line flags. They select between the three modes of the tool
// (sequential benchmark run, Poisson-scheduled benchmark, table upload)
// and configure their parameters.
var (
	// snowflakeToken = flag.String("token", "", "snowflake oauth token")
	snowflakeConn  = flag.String("conn", "", "snowflake connection string")
	noCache        = flag.Bool("nocache", false, "attempt to disable file-level caching")
	sleepInterval  = flag.Duration("sleep", 2*time.Minute, "sleep between queries")
	upload         = flag.String("upload", "", "upload given table file")
	uploadDir      = flag.String("upload-dir", "", "source upload directory")
	poisson        = flag.Float64("poisson", 0, "enable poisson-distributed benchmark")
	poissonThreads = flag.Int("threads", 1, "number of threads for benchmark")
	poissonSeed    = flag.Uint64("seed", 42, "seed for poisson rng")
)
// REST endpoints of the target Snowflake deployment used for monitoring.
// The profile and query URLs take a query ID via the %s placeholder.
const (
	warehouseBaseURL    = "https://vg00000.west-europe.azure.snowflakecomputing.com"
	snowflakeProfileURL = warehouseBaseURL + "/monitoring/query-plan-data/%s"
	snowflakeQueryURL   = warehouseBaseURL + "/monitoring/queries/%s"
	snowflakeTokenURL   = warehouseBaseURL + "/session/token-request"
)
// simpleTokenAccessor is a gosnowflake.TokenAccessor implementation that
// caches the session token so it can be reused to authenticate direct calls
// against the Snowflake monitoring HTTP endpoints (see execQuery).
type simpleTokenAccessor struct {
	token       string // current session token
	masterToken string // master token used by the driver for renewal
	sessionID   int64  // Snowflake session ID, -1 before first login

	accessorLock sync.Mutex   // Used to implement accessor's Lock and Unlock
	tokenLock    sync.RWMutex // Used to synchronize SetTokens and GetTokens
}
// Lock acquires the accessor-wide mutex; it always reports success since
// acquiring a sync.Mutex cannot fail.
func (sta *simpleTokenAccessor) Lock() error {
	sta.accessorLock.Lock()
	return nil
}
// Unlock releases the accessor-wide mutex acquired by Lock.
func (sta *simpleTokenAccessor) Unlock() {
	sta.accessorLock.Unlock()
}
// GetTokens returns a consistent snapshot of the cached session token,
// master token and session ID under the read lock.
func (sta *simpleTokenAccessor) GetTokens() (token string, masterToken string, sessionID int64) {
	sta.tokenLock.RLock()
	defer sta.tokenLock.RUnlock()
	token, masterToken, sessionID = sta.token, sta.masterToken, sta.sessionID
	return token, masterToken, sessionID
}
// SetTokens atomically replaces the cached session token, master token and
// session ID; the driver calls this after every (re-)authentication.
func (sta *simpleTokenAccessor) SetTokens(token string, masterToken string, sessionID int64) {
	sta.tokenLock.Lock()
	defer sta.tokenLock.Unlock()
	sta.token, sta.masterToken, sta.sessionID = token, masterToken, sessionID
	log.Println("Updated token for session", sessionID)
}
// This is somehow required because otherwise you wont be able to access Snowflake's QueryID field.
func getUnexportedField(field reflect.Value) interface{} {
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
// execQuery runs the SQL script at sqlPath on the given connection and
// optionally dumps Snowflake's monitoring output for the query: the
// query-plan profile JSON to queryProfileDest and the query metadata JSON to
// queryInfoDest (either may be empty to skip that dump). When noCache is
// true the warehouse is suspended afterwards and the caller sleeps for
// -sleep so the next query starts against a cold warehouse. tokenAccessor
// supplies the session token used to authenticate the monitoring requests.
func execQuery(db *sql.Conn, sqlPath, queryProfileDest, queryInfoDest string, noCache bool, tokenAccessor gosnowflake.TokenAccessor) error {
	queryString, err := os.ReadFile(sqlPath)
	if err != nil {
		return fmt.Errorf("read query file %s: %w", sqlPath, err)
	}
	log.Println("Read query file", sqlPath)
	// Force startup so the measured query does not include warehouse spin-up.
	if _, err := db.ExecContext(context.Background(), "ALTER WAREHOUSE RESUME IF SUSPENDED"); err != nil {
		return fmt.Errorf("resume warehouse: %w", err)
	}
	log.Println("Resumed warehouse")
	rows, err := db.QueryContext(context.Background(), string(queryString))
	if err != nil {
		return fmt.Errorf("execute query: %w", err)
	}
	// BUGFIX: release the result set even on the early error returns below.
	defer rows.Close()
	// Drain the result set; only the row count is of interest.
	count := 0
	for rows.Next() {
		count++
	}
	if err := rows.Err(); err != nil {
		return err
	}
	// The query ID is only reachable through the unexported "rowsi" field of
	// sql.Rows, which holds the driver-level rows implementation.
	var queryID string
	val := reflect.ValueOf(rows).Elem().FieldByName("rowsi")
	field := getUnexportedField(val)
	if driverRows, ok := field.(driver.Rows); ok {
		queryID = driverRows.(gosnowflake.SnowflakeResult).GetQueryID()
	}
	log.Printf("Query '%s' returned %d rows", queryID, count)
	// GetTokens returns (token, masterToken, sessionID); only the session
	// token is needed for the monitoring endpoints.
	sessionToken, _, _ := tokenAccessor.GetTokens()
	if queryInfoDest != "" {
		if err := dumpMonitoringData(fmt.Sprintf(snowflakeQueryURL, queryID), sessionToken, queryInfoDest); err != nil {
			return fmt.Errorf("dump query info: %w", err)
		}
		log.Println("Dumped query info")
	}
	if queryProfileDest != "" {
		if err := dumpMonitoringData(fmt.Sprintf(snowflakeProfileURL, queryID), sessionToken, queryProfileDest); err != nil {
			return fmt.Errorf("dump query profile: %w", err)
		}
		log.Println("Dumped query profile")
	}
	if noCache {
		// Suspending the warehouse drops its local file cache; the sleep
		// gives it time to fully spin down before the next query runs.
		if _, err := db.ExecContext(context.Background(), "ALTER WAREHOUSE SUSPEND"); err != nil {
			return err
		}
		log.Println("Suspended warehouse to prevent result caching, wait until fully suspended")
		time.Sleep(*sleepInterval)
	}
	return nil
}

// dumpMonitoringData fetches url authenticated with the given Snowflake
// session token and writes the response body to the file at dest.
func dumpMonitoringData(url, sessionToken, dest string) error {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", fmt.Sprintf("Snowflake Token=\"%s\"", sessionToken))
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// BUGFIX: a non-200 response (e.g. expired token) was previously dumped
	// silently as if it were valid monitoring data.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status %s from %s", resp.Status, url)
	}
	out, err := os.Create(dest)
	if err != nil {
		return err
	}
	defer out.Close()
	_, err = io.Copy(out, resp.Body)
	return err
}
// runWithPoisson reads benchmark tasks from stdin (one per line, three
// whitespace-separated paths: SQL script, profile dump destination, info
// dump destination), shuffles them with a seeded RNG and executes them on
// `threads` worker goroutines, sleeping a Poisson-distributed number of
// seconds (parameter lambda) between scheduling consecutive tasks.
func runWithPoisson(threads int, lambda float64, rngseed uint64) error {
	// Connect to Snowflake with a custom token accessor so the session token
	// can later be read for the monitoring endpoints.
	cfg, err := gosnowflake.ParseDSN(*snowflakeConn)
	if err != nil {
		return err
	}
	cfg.TokenAccessor = &simpleTokenAccessor{sessionID: -1}
	db := sql.OpenDB(gosnowflake.NewConnector(&gosnowflake.SnowflakeDriver{}, *cfg))
	defer db.Close()
	log.Println("Opened connection to Snowflake")
	// Read all tasks up front; the full list is needed to permute it.
	scanner := bufio.NewScanner(os.Stdin)
	tasks := [][3]string{}
	for scanner.Scan() {
		args := strings.Fields(scanner.Text())
		if len(args) != 3 {
			return fmt.Errorf("expected 3 paths")
		}
		tasks = append(tasks, [...]string{args[0], args[1], args[2]})
	}
	// BUGFIX: a failed stdin read previously looked like a short task list.
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("read tasks: %w", err)
	}
	// Initialize the sleep time source and the task permutation.
	rndsrc := rand.NewSource(rngseed)
	rndperm := rand.New(rndsrc).Perm(len(tasks))
	poisson := distuv.Poisson{Lambda: lambda, Src: rndsrc}
	// Start up the workers.
	ch := make(chan [3]string, threads)
	wg := sync.WaitGroup{}
	wg.Add(threads)
	for i := 0; i < threads; i++ {
		go func() {
			defer wg.Done()
			var (
				conn      *sql.Conn
				err       error
				reinit    bool
				connected time.Time
			)
			// BUGFIX: the final connection of each worker used to leak.
			defer func() {
				if conn != nil {
					conn.Close()
				}
			}()
			for t := range ch {
				// Recycle the connection every 10 minutes to keep the
				// session token fresh for the monitoring endpoints.
				if time.Since(connected) > 10*time.Minute {
					if conn != nil {
						conn.Close()
					}
					conn, err = db.Conn(context.Background())
					if err != nil {
						log.Fatal("create connection:", err)
					}
					reinit = true
					log.Println("created new connection")
					connected = time.Now()
				}
				// Freshly created connections need their session configured.
				if reinit {
					if _, err := conn.ExecContext(context.Background(), "ALTER SESSION SET CLIENT_SESSION_KEEP_ALIVE = TRUE"); err != nil {
						log.Fatal(err)
					}
					if _, err := conn.ExecContext(context.Background(), "ALTER SESSION SET USE_CACHED_RESULT = FALSE"); err != nil {
						log.Fatal(err)
					}
					reinit = false
					log.Println("setup session")
				}
				if err := execQuery(conn, t[0], t[1], t[2], *noCache, cfg.TokenAccessor); err != nil {
					log.Println(t[0], err)
				}
				log.Println("Finished", t[0])
			}
		}()
	}
	// Schedule tasks in permuted order with Poisson-distributed gaps.
	for i := range tasks {
		t := tasks[rndperm[i]]
		ch <- t
		log.Println("Scheduled", t)
		time.Sleep(time.Duration(float64(time.Second) * poisson.Rand()))
	}
	// Signal completion and wait for the workers to drain the queue.
	close(ch)
	wg.Wait()
	return nil
}
// run reads benchmark tasks from stdin (one per line, three
// whitespace-separated paths: SQL script, profile dump destination, info
// dump destination) and executes them sequentially on a single connection
// with result caching disabled.
func run() error {
	// Connect to Snowflake with a custom token accessor so the session token
	// can later be read for the monitoring endpoints.
	cfg, err := gosnowflake.ParseDSN(*snowflakeConn)
	if err != nil {
		return err
	}
	cfg.TokenAccessor = &simpleTokenAccessor{sessionID: -1}
	db := sql.OpenDB(gosnowflake.NewConnector(&gosnowflake.SnowflakeDriver{}, *cfg))
	defer db.Close()
	log.Println("Opened connection to Snowflake")
	conn, err := db.Conn(context.Background())
	if err != nil {
		// BUGFIX: return the error instead of log.Fatal, consistent with the
		// signature (and so the deferred db.Close still runs).
		return fmt.Errorf("create connection: %w", err)
	}
	// BUGFIX: the connection used to leak.
	defer conn.Close()
	// Disable result caching so repeated queries are actually executed.
	if _, err := conn.ExecContext(context.Background(), "ALTER SESSION SET USE_CACHED_RESULT = FALSE"); err != nil {
		return fmt.Errorf("disable result caching: %w", err)
	}
	log.Println("Disabled result caching")
	// Go through tasks line by line.
	scanner := bufio.NewScanner(os.Stdin)
	for scanner.Scan() {
		args := strings.Fields(scanner.Text())
		if len(args) != 3 {
			return fmt.Errorf("expected 3 paths")
		}
		// First arg is SQL script, second profile dest, third info dest.
		if err := execQuery(conn, args[0], args[1], args[2], *noCache, cfg.TokenAccessor); err != nil {
			return fmt.Errorf("exec query: %w", err)
		}
	}
	// BUGFIX: surface stdin read errors instead of silently stopping.
	return scanner.Err()
}
// uploadToSchema wipes the table named by -upload, splits the data file
// <upload-dir>/<upload>.dat into segments of roughly 100 MB and uploads them
// to the user stage and copies them into the table. Segments are uploaded
// concurrently, bounded by GOMAXPROCS.
func uploadToSchema() error {
	db, err := sql.Open("snowflake", *snowflakeConn)
	if err != nil {
		return fmt.Errorf("open connection: %w", err)
	}
	defer db.Close()
	log.Println("Opened connection to Snowflake, wiping table")
	// Wipe the table so the upload starts from a clean slate.
	if _, err := db.Exec(fmt.Sprintf("DELETE FROM %s", *upload)); err != nil {
		return fmt.Errorf("wipe table: %w", err)
	}
	log.Println("Opened connection to Snowflake, uploading file")
	// Clean the user stage beforehand.
	if _, err := db.Exec("REMOVE @~"); err != nil {
		return fmt.Errorf("clean stage: %w", err)
	}
	source, err := os.Open(filepath.Join(*uploadDir, *upload+".dat"))
	if err != nil {
		return err
	}
	// BUGFIX: the source file used to stay open for the process lifetime.
	defer source.Close()
	segments, err := splitIntoSegments(source)
	if err != nil {
		return err
	}
	// Upload segments concurrently, at most GOMAXPROCS at a time.
	maxprocs := int64(runtime.GOMAXPROCS(0))
	sempool := semaphore.NewWeighted(maxprocs)
	wg := sync.WaitGroup{}
	wg.Add(len(segments))
	for i, seg := range segments {
		go func(i int, seg string) {
			defer wg.Done()
			// BUGFIX: the Acquire error was silently ignored; with a
			// background context it can only fail on programmer error.
			if err := sempool.Acquire(context.Background(), 1); err != nil {
				log.Fatal(err)
			}
			defer sempool.Release(1)
			// BUGFIX: delete the temporary segment file once uploaded.
			defer os.Remove(seg)
			path := fmt.Sprintf("@~/segment%d.dat", i)
			log.Printf("Uploading segment %d (%s)", i, seg)
			// Upload file to the user stage.
			if _, err := db.Exec(fmt.Sprintf("PUT file://%s %s OVERWRITE=TRUE", seg, path)); err != nil {
				log.Fatal(err)
			}
			log.Printf("Copying segment %d (%s) into table", i, seg)
			// Copy staged data into the target table.
			if _, err := db.Exec(fmt.Sprintf("COPY INTO %s FROM %s FILE_FORMAT=( TYPE = CSV, FIELD_DELIMITER = '|')", *upload, path)); err != nil {
				log.Fatal(err)
			}
		}(i, seg)
	}
	// Wait for uploads to finish.
	wg.Wait()
	return nil
}

// splitIntoSegments copies r line-by-line into temporary files of roughly
// 100 MB each and returns their paths. The caller is responsible for
// removing the returned files.
func splitIntoSegments(r io.Reader) ([]string, error) {
	var segments []string
	newline := []byte{'\n'}
	scanner := bufio.NewScanner(r)
	scan := true
	for scan {
		intermediate, err := os.CreateTemp("", "snowprof")
		if err != nil {
			return nil, err
		}
		log.Println("Created new segment", intermediate.Name())
		written := 0
		for {
			scan = scanner.Scan()
			if !scan {
				break
			}
			n, err := intermediate.Write(scanner.Bytes())
			if err != nil {
				intermediate.Close()
				return nil, err
			}
			// The scanner strips the trailing newline; add it back.
			// BUGFIX: this write error was previously ignored.
			if _, err := intermediate.Write(newline); err != nil {
				intermediate.Close()
				return nil, err
			}
			written += n + 1
			// Cut over to a new segment once ~100 MB have been written.
			if written > 100_000_000 {
				break
			}
		}
		intermediate.Close()
		segments = append(segments, intermediate.Name())
	}
	// BUGFIX: read failures (e.g. a line exceeding the scanner's buffer)
	// previously truncated the upload silently.
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return segments, nil
}
// main parses the flags and dispatches to one of the three modes: table
// upload (-upload), Poisson-scheduled benchmark (-poisson), or the default
// sequential benchmark run.
func main() {
	flag.Parse()
	var err error
	switch {
	case *upload != "":
		err = uploadToSchema()
	case *poisson != 0:
		err = runWithPoisson(*poissonThreads, *poisson, *poissonSeed)
	default:
		err = run()
	}
	if err != nil {
		log.Fatal(err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment