Snowflake profiling script

package main

import (
	"bufio"
	"context"
	"database/sql"
	"database/sql/driver"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"time"
	"unsafe"

	"github.com/snowflakedb/gosnowflake"
	"golang.org/x/exp/rand"
	"golang.org/x/sync/semaphore"
	"gonum.org/v1/gonum/stat/distuv"
)

var (
	// snowflakeToken = flag.String("token", "", "snowflake oauth token")
	snowflakeConn  = flag.String("conn", "", "snowflake connection string")
	noCache        = flag.Bool("nocache", false, "attempt to disable file-level caching")
	sleepInterval  = flag.Duration("sleep", 2*time.Minute, "sleep between queries")
	upload         = flag.String("upload", "", "upload given table file")
	uploadDir      = flag.String("upload-dir", "", "source upload directory")
	poisson        = flag.Float64("poisson", 0, "enable poisson-distributed benchmark")
	poissonThreads = flag.Int("threads", 1, "number of threads for benchmark")
	poissonSeed    = flag.Uint64("seed", 42, "seed for poisson rng")
)

const (
	warehouseBaseURL    = "https://vg00000.west-europe.azure.snowflakecomputing.com"
	snowflakeProfileURL = warehouseBaseURL + "/monitoring/query-plan-data/%s"
	snowflakeQueryURL   = warehouseBaseURL + "/monitoring/queries/%s"
	snowflakeTokenURL   = warehouseBaseURL + "/session/token-request"
)
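
// simpleTokenAccessor implements gosnowflake.TokenAccessor. It lets the
// script observe the session token issued by the driver so the same token
// can be reused for raw HTTP requests against the monitoring endpoints.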
type simpleTokenAccessor struct {
	token        string
	masterToken  string
	sessionID    int64
	accessorLock sync.Mutex   // Used to implement accessor's Lock and Unlock
	tokenLock    sync.RWMutex // Used to synchronize SetTokens and GetTokens
}

func (sta *simpleTokenAccessor) Lock() error {
	sta.accessorLock.Lock()
	return nil
}

func (sta *simpleTokenAccessor) Unlock() {
	sta.accessorLock.Unlock()
}

func (sta *simpleTokenAccessor) GetTokens() (token string, masterToken string, sessionID int64) {
	sta.tokenLock.RLock()
	defer sta.tokenLock.RUnlock()
	return sta.token, sta.masterToken, sta.sessionID
}

func (sta *simpleTokenAccessor) SetTokens(token string, masterToken string, sessionID int64) {
	sta.tokenLock.Lock()
	defer sta.tokenLock.Unlock()
	sta.token = token
	sta.masterToken = masterToken
	sta.sessionID = sessionID
	log.Println("Updated token for session", sessionID)
}

// getUnexportedField is required because the driver keeps the query ID on an
// unexported field; without this hack you won't be able to access Snowflake's
// QueryID field.
func getUnexportedField(field reflect.Value) interface{} {
	return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
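
// execQuery reads the SQL script at sqlPath, resumes the warehouse, runs the
// query, and then fetches the query info and query profile JSON from the
// monitoring endpoints (authenticated with the captured session token),
// writing them to queryInfoDest and queryProfileDest respectively. With
// noCache set, the warehouse is suspended afterwards to drop its local cache.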
func execQuery(db *sql.Conn, sqlPath, queryProfileDest, queryInfoDest string, noCache bool, tokenAccessor gosnowflake.TokenAccessor) error {
	queryString, err := os.ReadFile(sqlPath)
	if err != nil {
		return err
	}
	log.Println("Read query file", sqlPath)
	// Force startup
	if _, err := db.ExecContext(context.Background(), "ALTER WAREHOUSE RESUME IF SUSPENDED"); err != nil {
		return err
	}
	log.Println("Resumed warehouse")
	rows, err := db.QueryContext(context.Background(), string(queryString))
	if err != nil {
		return err
	}
	count := 0
	for rows.Next() {
		count++
	}
	if err := rows.Err(); err != nil {
		return err
	}
	var queryID string
	val := reflect.ValueOf(rows).Elem().FieldByName("rowsi")
	field := getUnexportedField(val)
	if driverRows, ok := field.(driver.Rows); ok {
		queryID = driverRows.(gosnowflake.SnowflakeResult).GetQueryID()
	}
	log.Printf("Query '%s' returned %d rows", queryID, count)
	sessionToken, _, _ := tokenAccessor.GetTokens()
	if queryInfoDest != "" {
		// Spill out query info
		queryRequest, err := http.NewRequest(http.MethodGet, fmt.Sprintf(snowflakeQueryURL, queryID), nil)
		if err != nil {
			return err
		}
		queryRequest.Header.Set("Authorization", fmt.Sprintf("Snowflake Token=\"%s\"", sessionToken))
		queryResponse, err := http.DefaultClient.Do(queryRequest)
		if err != nil {
			return err
		}
		defer queryResponse.Body.Close()
		// Open up query info dump
		queryFile, err := os.Create(queryInfoDest)
		if err != nil {
			return err
		}
		defer queryFile.Close()
		if _, err := io.Copy(queryFile, queryResponse.Body); err != nil {
			return err
		}
		log.Println("Dumped query info")
	}
	if queryProfileDest != "" {
		// Spill out profile
		profileRequest, err := http.NewRequest(http.MethodGet, fmt.Sprintf(snowflakeProfileURL, queryID), nil)
		if err != nil {
			return err
		}
		profileRequest.Header.Set("Authorization", fmt.Sprintf("Snowflake Token=\"%s\"", sessionToken))
		profileResponse, err := http.DefaultClient.Do(profileRequest)
		if err != nil {
			return err
		}
		defer profileResponse.Body.Close()
		profileFile, err := os.Create(queryProfileDest)
		if err != nil {
			return err
		}
		defer profileFile.Close()
		if _, err := io.Copy(profileFile, profileResponse.Body); err != nil {
			return err
		}
		log.Println("Dumped query profile")
	}
	if noCache {
		if _, err := db.ExecContext(context.Background(), "ALTER WAREHOUSE SUSPEND"); err != nil {
			return err
		}
		log.Println("Suspended warehouse to drop its file-level cache, waiting until fully suspended")
		time.Sleep(*sleepInterval)
	}
	return nil
}
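
// runWithPoisson reads tasks (three whitespace-separated paths per stdin
// line), shuffles them with a seeded RNG, and feeds them to a pool of worker
// goroutines. The delay between scheduled tasks is drawn from a Poisson
// distribution with the given lambda, interpreted in seconds.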
func runWithPoisson(threads int, lambda float64, rngseed uint64) error {
	// Connect to Snowflake
	cfg, err := gosnowflake.ParseDSN(*snowflakeConn)
	if err != nil {
		return err
	}
	//cfg.KeepSessionAlive = true
	cfg.TokenAccessor = &simpleTokenAccessor{sessionID: -1}
	db := sql.OpenDB(gosnowflake.NewConnector(&gosnowflake.SnowflakeDriver{}, *cfg))
	defer db.Close()
	log.Println("Opened connection to Snowflake")
	// Result caching does not need to be disabled here, each worker disables it per session
	// Go through tasks
	scanner := bufio.NewScanner(os.Stdin)
	tasks := [][3]string{}
	for scanner.Scan() {
		command := scanner.Text()
		args := strings.Fields(command)
		if len(args) != 3 {
			return fmt.Errorf("expected 3 paths, got %q", command)
		}
		tasks = append(tasks, [...]string{args[0], args[1], args[2]})
	}
	if err := scanner.Err(); err != nil {
		return err
	}
	// Initialize sleep time source
	rndsrc := rand.NewSource(rngseed)
	rndperm := rand.New(rndsrc).Perm(len(tasks))
	poisson := distuv.Poisson{Lambda: lambda, Src: rndsrc}
	// Start up workers
	ch := make(chan [3]string, threads)
	wg := sync.WaitGroup{}
	wg.Add(threads)
	for i := 0; i < threads; i++ {
		go func() {
			var (
				conn      *sql.Conn
				err       error
				reinit    bool
				connected time.Time
			)
			for t := range ch {
				// Recycle the connection every 10 minutes to avoid stale sessions
				if time.Since(connected) > 10*time.Minute {
					if conn != nil {
						conn.Close()
					}
					conn, err = db.Conn(context.Background())
					if err != nil {
						log.Fatal("create connection:", err)
					}
					reinit = true
					log.Println("created new connection")
					connected = time.Now()
				}
				if reinit {
					if _, err := conn.ExecContext(context.Background(), "ALTER SESSION SET CLIENT_SESSION_KEEP_ALIVE = TRUE"); err != nil {
						log.Fatal(err)
					}
					if _, err := conn.ExecContext(context.Background(), "ALTER SESSION SET USE_CACHED_RESULT = FALSE"); err != nil {
						log.Fatal(err)
					}
					reinit = false
					log.Println("setup session")
				}
				if err := execQuery(conn, t[0], t[1], t[2], *noCache, cfg.TokenAccessor); err != nil {
					log.Println(t[0], err)
				}
				log.Println("Finished", t[0])
			}
			wg.Done()
		}()
	}
	// Schedule tasks in shuffled order
	for i := range tasks {
		t := tasks[rndperm[i]]
		ch <- t
		log.Println("Scheduled", t)
		// Sleep for a Poisson-distributed number of seconds
		time.Sleep(time.Duration(float64(time.Second) * poisson.Rand()))
	}
	// Wait for the workers to finish
	close(ch)
	wg.Wait()
	return nil
}
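
// run executes the stdin task list sequentially on a single connection with
// result caching disabled. Each input line names the SQL script, the profile
// destination, and the query info destination.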
func run() error {
	// Connect to Snowflake
	cfg, err := gosnowflake.ParseDSN(*snowflakeConn)
	if err != nil {
		return err
	}
	cfg.TokenAccessor = &simpleTokenAccessor{sessionID: -1}
	db := sql.OpenDB(gosnowflake.NewConnector(&gosnowflake.SnowflakeDriver{}, *cfg))
	defer db.Close()
	log.Println("Opened connection to Snowflake")
	conn, err := db.Conn(context.Background())
	if err != nil {
		return err
	}
	defer conn.Close()
	// Disable result caching
	if _, err := conn.ExecContext(context.Background(), "ALTER SESSION SET USE_CACHED_RESULT = FALSE"); err != nil {
		return err
	}
	log.Println("Disabled result caching")
	// Go through tasks
	scanner := bufio.NewScanner(os.Stdin)
	for scanner.Scan() {
		command := scanner.Text()
		args := strings.Fields(command)
		if len(args) != 3 {
			return fmt.Errorf("expected 3 paths, got %q", command)
		}
		// First arg is SQL script, second profile dest, third info dest
		sqlPath, queryProfileDest, queryInfoDest := args[0], args[1], args[2]
		if err := execQuery(conn, sqlPath, queryProfileDest, queryInfoDest, *noCache, cfg.TokenAccessor); err != nil {
			return fmt.Errorf("exec query: %w", err)
		}
	}
	return scanner.Err()
}
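
// uploadToSchema wipes the target table, splits the local <table>.dat file
// into roughly 100 MB newline-aligned segments, uploads the segments to the
// user stage in parallel (bounded by GOMAXPROCS), and copies each one into
// the table.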
func uploadToSchema() error {
	db, err := sql.Open("snowflake", *snowflakeConn)
	if err != nil {
		return err
	}
	defer db.Close()
	log.Println("Opened connection to Snowflake, wiping table")
	// Wipe table
	if _, err := db.Exec(fmt.Sprintf("DELETE FROM %s", *upload)); err != nil {
		return err
	}
	log.Println("Wiped table, uploading file")
	// Clean the user stage beforehand
	if _, err := db.Exec("REMOVE @~"); err != nil {
		return err
	}
	var segments []string
	source, err := os.Open(filepath.Join(*uploadDir, *upload+".dat"))
	if err != nil {
		return err
	}
	defer source.Close()
	scan := true
	newline := []byte{'\n'}
	scanner := bufio.NewScanner(source)
	for scan {
		intermediate, err := os.CreateTemp("", "snowprof")
		if err != nil {
			return err
		}
		log.Println("Created new segment", intermediate.Name())
		written := 0
		for {
			scan = scanner.Scan()
			if !scan {
				break
			}
			n, err := intermediate.Write(scanner.Bytes())
			if err != nil {
				return err
			}
			intermediate.Write(newline)
			written += n + 1
			if written > 100_000_000 {
				break
			}
		}
		intermediate.Close()
		segments = append(segments, intermediate.Name())
	}
	if err := scanner.Err(); err != nil {
		return err
	}
	maxprocs := int64(runtime.GOMAXPROCS(0))
	sempool := semaphore.NewWeighted(maxprocs)
	wg := sync.WaitGroup{}
	wg.Add(len(segments))
	for i, seg := range segments {
		go func(i int, seg string) {
			sempool.Acquire(context.Background(), 1)
			path := fmt.Sprintf("@~/segment%d.dat", i)
			log.Printf("Uploading segment %d (%s)", i, seg)
			// Upload file
			if _, err := db.Exec(fmt.Sprintf("PUT file://%s %s OVERWRITE=TRUE", seg, path)); err != nil {
				log.Fatal(err)
			}
			log.Printf("Copying segment %d (%s) into table", i, seg)
			// Copy data to table
			if _, err := db.Exec(fmt.Sprintf("COPY INTO %s FROM %s FILE_FORMAT=( TYPE = CSV, FIELD_DELIMITER = '|')", *upload, path)); err != nil {
				log.Fatal(err)
			}
			sempool.Release(1)
			wg.Done()
		}(i, seg)
	}
	// Wait for uploads to finish
	wg.Wait()
	return nil
}
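
// main dispatches on flags: -upload loads a table file, -poisson runs the
// Poisson-distributed benchmark, and the default is a sequential run.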
func main() {
	flag.Parse()
	if *upload != "" {
		if err := uploadToSchema(); err != nil {
			log.Fatal(err)
		}
		return
	}
	if *poisson != 0 {
		if err := runWithPoisson(*poissonThreads, *poisson, *poissonSeed); err != nil {
			log.Fatal(err)
		}
		return
	}
	if err := run(); err != nil {
		log.Fatal(err)
	}
}