Skip to content

Instantly share code, notes, and snippets.

@miguelff
Created November 9, 2023 15:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save miguelff/0e49db62c80780f0a22d77f86f7d70fc to your computer and use it in GitHub Desktop.
Save miguelff/0e49db62c80780f0a22d77f86f7d70fc to your computer and use it in GitHub Desktop.
Channel-based linearization
diff --git a/main.go b/main.go
index fdb5ccf..8a68f58 100644
--- a/main.go
+++ b/main.go
@@ -34,9 +34,57 @@ type mysqlConnKey struct {
username, pass, session string
}
+// request is one unit of work handed to a connection's worker goroutine
+// over timedConn.reqs.
+type request struct {
+ // query is the SQL text passed to ExecuteFetch.
+ query string
+ // session is copied unchanged into the Session field of the response.
+ session string
+}
+
+// timedConn is a pooled MySQL connection together with the channels used to
+// funnel queries through the single goroutine started by newTimedConn.
type timedConn struct {
*mysql.Conn
+// lastUsed is refreshed when the connection is created, when getConn hands
+// it out from the pool, and when the worker picks up a request; the worker
+// compares it against the idle timeout.
lastUsed time.Time
+ // reqs carries queries from Execute to the worker goroutine (unbuffered).
+ reqs chan (*request)
+ // res carries the worker's response back to Execute (unbuffered).
+ res chan (*psdbv1alpha1.ExecuteResponse)
+}
+
+// close closes the request and response channels and then the underlying
+// MySQL connection, in that order.
+//
+// NOTE(review): reqs is closed here from the receiving side. A goroutine
+// that still sends on reqs afterwards (Execute does `conn.reqs <- ...`)
+// would panic, and a receive on the closed res yields nil — confirm that no
+// caller can still hold this connection when close runs.
+func (c *timedConn) close() {
+ close(c.reqs)
+ close(c.res)
+ c.Conn.Close()
+}
+
+// newTimedConn wraps my in a timedConn and starts the worker goroutine that
+// serialises all queries on the connection.
+//
+// The worker receives requests from reqs one at a time, runs them with
+// ExecuteFetch (capped at *flagMySQLMaxRows rows) and sends the result on
+// res. If no request arrives for *flagMySQLIdleTimeout and lastUsed is older
+// than that timeout, the worker closes the connection and exits. The worker
+// also exits when reqs is closed.
+func newTimedConn(my *mysql.Conn) *timedConn {
+	conn := timedConn{
+		Conn:     my,
+		lastUsed: time.Now(),
+		reqs:     make(chan (*request)),
+		res:      make(chan (*psdbv1alpha1.ExecuteResponse)),
+	}
+
+	go func() {
+		for {
+			select {
+			case request, ok := <-conn.reqs:
+				if !ok {
+					// reqs was closed: nothing more to serve.
+					return
+				}
+				conn.lastUsed = time.Now()
+				qr, err := conn.ExecuteFetch(request.query, int(*flagMySQLMaxRows), true)
+				conn.res <- &psdbv1alpha1.ExecuteResponse{
+					Session: request.session,
+					Result:  sqltypes.ResultToProto3(qr),
+					Error:   vterrors.ToVTRPC(err),
+				}
+			case <-time.After(*flagMySQLIdleTimeout):
+				expiration := time.Now().Add(-*flagMySQLIdleTimeout)
+				if conn.lastUsed.Before(expiration) {
+					conn.close()
+					return
+				}
+				// lastUsed was refreshed while we were waiting (getConn
+				// updates it when handing the connection out), so a request
+				// is likely on its way. Keep serving: exiting here without
+				// closing would leave the next send on reqs blocked forever.
+			}
+		}
+	}()
+
+	return &conn
}
var (
@@ -57,7 +105,7 @@ var (
// since this isn't meant to truly represent reality, it's possible you
// can do things with connections locally by munging session ids or auth
// that aren't allowed on PlanetScale. This is meant to just mimic the public API.
-func getConn(ctx context.Context, uname, pass, session string) (*mysql.Conn, error) {
+func getConn(ctx context.Context, uname, pass, session string) (*timedConn, error) {
key := mysqlConnKey{uname, pass, session}
// check first if there's already a connection
@@ -65,7 +113,7 @@ func getConn(ctx context.Context, uname, pass, session string) (*mysql.Conn, err
if conn, ok := connPool[key]; ok {
connMu.RUnlock()
conn.lastUsed = time.Now()
- return conn.Conn, nil
+ return conn, nil
}
connMu.RUnlock()
@@ -79,15 +127,12 @@ func getConn(ctx context.Context, uname, pass, session string) (*mysql.Conn, err
// lock to write to map
connMu.Lock()
- connPool[key] = &timedConn{rawConn, time.Now()}
+ connPool[key] = newTimedConn(rawConn)
connMu.Unlock()
// since it was parallel, the last one would have won and been written
// so re-read back so we use the conn that was actually stored in the pool
- connMu.RLock()
- conn := connPool[key]
- connMu.RUnlock()
- return conn.Conn, nil
+ return getConn(ctx, uname, pass, session)
}
// dial connects to the underlying MySQL server, and switches to the underlying
@@ -187,13 +232,9 @@ func (s *server) Execute(
return nil, err
}
- // This is a gross simplificiation, but is likely sufficient
- qr, err := conn.ExecuteFetch(query, int(*flagMySQLMaxRows), true)
- return connect.NewResponse(&psdbv1alpha1.ExecuteResponse{
- Session: session,
- Result: sqltypes.ResultToProto3(qr),
- Error: vterrors.ToVTRPC(err),
- }), nil
+ conn.reqs <- &request{query, session}
+ res := <-conn.res
+ return connect.NewResponse(res), nil
}
func initConnPool() {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment