Skip to content

Instantly share code, notes, and snippets.

@rabbitt
Last active September 25, 2015 04:36
Show Gist options
  • Save rabbitt/b27fc3209be922b7e808 to your computer and use it in GitHub Desktop.
Save rabbitt/b27fc3209be922b7e808 to your computer and use it in GitHub Desktop.
initial (extremely hacky) patch for pgbouncer that caches the last search_path set on the client socket.
diff --git a/include/varcache.h b/include/varcache.h
index 4984b01..d363fb0 100644
--- a/include/varcache.h
+++ b/include/varcache.h
@@ -5,7 +5,8 @@ enum VarCacheIdx {
VTimeZone,
VStdStr,
VAppName,
- NumVars
+ VSearchPath,
+ NumVars,
};
typedef struct VarCache VarCache;
diff --git a/src/admin.c b/src/admin.c
index d23d24b..fc7c927 100644
--- a/src/admin.c
+++ b/src/admin.c
@@ -242,6 +242,7 @@ static bool send_one_fd(PgSocket *admin,
const char *addr, int port,
uint64_t ckey, int link,
const char *client_enc,
+ const char *search_path,
const char *std_strings,
const char *datestyle,
const char *timezone,
@@ -255,9 +256,9 @@ static bool send_one_fd(PgSocket *admin,
struct PktBuf *pkt = pktbuf_temp();
- pktbuf_write_DataRow(pkt, "issssiqisssss",
+ pktbuf_write_DataRow(pkt, "issssiqissssss",
fd, task, user, db, addr, port, ckey, link,
- client_enc, std_strings, datestyle, timezone,
+ client_enc, search_path, std_strings, datestyle, timezone,
password);
if (pkt->failed)
return false;
@@ -308,6 +309,7 @@ static bool show_one_fd(PgSocket *admin, PgSocket *sk)
VarCache *v = &sk->vars;
uint64_t ckey;
const struct PStr *client_encoding = v->var_list[VClientEncoding];
+ const struct PStr *search_path = v->var_list[VSearchPath];
const struct PStr *std_strings = v->var_list[VStdStr];
const struct PStr *datestyle = v->var_list[VDateStyle];
const struct PStr *timezone = v->var_list[VTimeZone];
@@ -334,6 +336,7 @@ static bool show_one_fd(PgSocket *admin, PgSocket *sk)
ckey,
sk->link ? sbuf_socket(&sk->link->sbuf) : 0,
client_encoding ? client_encoding->str : NULL,
+ search_path ? search_path->str : NULL,
std_strings ? std_strings->str : NULL,
datestyle ? datestyle->str : NULL,
timezone ? timezone->str : NULL,
@@ -346,7 +349,7 @@ static bool show_pooler_cb(void *arg, int fd, const PgAddr *a)
return send_one_fd(arg, fd, "pooler", NULL, NULL,
pga_ntop(a, buf, sizeof(buf)), pga_port(a), 0, 0,
- NULL, NULL, NULL, NULL, NULL);
+ NULL, NULL, NULL, NULL, NULL, NULL);
}
/* send a row with sendmsg, optionally attaching a fd */
@@ -412,12 +415,12 @@ static bool admin_show_fds(PgSocket *admin, const char *arg)
/*
* send resultset
*/
- SEND_RowDescription(res, admin, "issssiqisssss",
+ SEND_RowDescription(res, admin, "issssiqissssss",
"fd", "task",
"user", "database",
"addr", "port",
"cancel", "link",
- "client_encoding", "std_strings",
+ "client_encoding", "search_path", "std_strings",
"datestyle", "timezone", "password");
if (res)
res = show_pooler_fds(admin);
diff --git a/src/client.c b/src/client.c
index 2b01ff2..e5adebe 100644
--- a/src/client.c
+++ b/src/client.c
@@ -23,6 +23,14 @@
#include "bouncer.h"
#include <usual/pgutil.h>
+#include <usual/string.h>
+
+/*
+ * Return a pointer to the first non-whitespace character of the
+ * NUL-terminated string p (or to its terminating NUL if there is none).
+ */
+static inline const char *str_skip_ws(const char *p)
+{
+	/*
+	 * <ctype.h> classifiers take an unsigned char value (or EOF); passing
+	 * a plain char that happens to be negative (bytes >= 0x80 in client
+	 * supplied query text) is undefined behaviour, hence the cast.
+	 */
+	while (*p != '\0' && isspace((unsigned char)*p))
+		p++;
+	return p;
+}
static const char *hdr2hex(const struct MBuf *data, char *buf, unsigned buflen)
{
@@ -592,6 +600,85 @@ static bool handle_client_startup(PgSocket *client, PktHdr *pkt)
return true;
}
+/*
+ * Detect a leading "SET search_path {= | TO} <value>" statement in the
+ * payload of a Query ('Q') or Parse ('P') packet.
+ *
+ * buf  - packet payload; scanning starts at buf->read_pos and the read
+ *        position is not modified.
+ * type - packet type byte; for 'P' the prepared statement name that
+ *        precedes the query text is skipped first.
+ *
+ * Returns a pointer into the packet buffer at the first non-whitespace
+ * character following '=' / TO, or NULL when the string is not
+ * NUL-terminated inside the readable data or does not start with such a
+ * statement.  Only this plain spelling is matched; other forms (for
+ * example "SET SESSION search_path ..." or a statement preceded by a
+ * comment) are not recognized.
+ */
+static char *scan_for_set_search_path(struct MBuf *buf, const char type)
+{
+	char *spath = (char *)buf->data + buf->read_pos;
+	size_t avail = mbuf_avail_for_read(buf);
+	const char *nul = memchr(spath, 0, avail);
+
+	if (!nul) {
+		log_debug("scan_for_set_search_path: couldn't find null terminator.");
+		return NULL;
+	}
+
+	/* Parse packet data is composed of: <string:pstmt_name>\0<string:query>\0
+	   so, to get to the query, we need to skip past the "prepared statement name" */
+	if (type == 'P') {
+		/* name length including its terminator; never exceeds avail */
+		size_t pstmt_name_len = (size_t)(nul - spath) + 1;
+
+		spath += pstmt_name_len;
+
+		/* make sure we have a null terminated query string */
+		nul = memchr(spath, 0, avail - pstmt_name_len);
+		if (!nul) {
+			log_debug("scan_for_set_search_path:"
+				  " couldn't find null terminator for query component of Parse packet.");
+			return NULL;
+		}
+	}
+
+	/*
+	 * From here on every read stops at the terminator verified above:
+	 * str_skip_ws() and strncasecmp() both stop at NUL.
+	 */
+	spath = (char *) str_skip_ws(spath);
+
+	/* check for SET; it must be a whole word ("setsearch_path" is not SET) */
+	if (strncasecmp("set", spath, strlen("set")) != 0)
+		return NULL;
+	spath += strlen("set");
+	if (!isspace((unsigned char)*spath))
+		return NULL;
+
+	spath = (char *) str_skip_ws(spath);
+
+	/* check for search_path */
+	if (strncasecmp("search_path", spath, strlen("search_path")) != 0)
+		return NULL;
+	spath += strlen("search_path");
+
+	spath = (char *) str_skip_ws(spath);
+
+	/* check for '=' or 'TO' */
+	if (*spath == '=') {
+		spath++;
+	} else {
+		if (strncasecmp("to", spath, strlen("to")) != 0)
+			return NULL;
+		spath += strlen("to");
+
+		/* "TO" must end here; reject identifiers such as "tomato" */
+		if (isalnum((unsigned char)*spath) || *spath == '_' || *spath == '$')
+			return NULL;
+	}
+
+	/* skip the last bit of whitespace, leading up to the actual search_path */
+	return (char *) str_skip_ws(spath);
+}
+
+/*
+ * If the packet starts with a "SET search_path ..." statement, remember
+ * the requested value in the client's variable cache.
+ *
+ * The value is stored as written by the client (quoting included) with
+ * trailing whitespace and semicolons removed.  Values that are empty after
+ * trimming, or that do not fit the local buffer, are not cached.
+ */
+static void client_cache_search_path(PgSocket *client, PktHdr *pkt)
+{
+	char spath[128];
+	const char *value;
+	size_t len;
+
+	value = scan_for_set_search_path(&pkt->data, pkt->type);
+	if (value == NULL)
+		return;
+
+	/*
+	 * remove whitespace and semicolons from the end of the data; counting
+	 * down a length (instead of walking a pointer) also trims a value that
+	 * consists only of such characters and never steps before the buffer
+	 */
+	len = strlen(value);
+	while (len > 0 && (isspace((unsigned char)value[len - 1]) || value[len - 1] == ';'))
+		len--;
+
+	if (len == 0) {
+		slog_debug(client, "cache_search_path: empty search_path; skipping");
+		return;
+	}
+
+	if (len >= sizeof(spath)) {
+		slog_debug(client, "cache_search_path: search_path too big; skipping");
+		return;
+	}
+
+	/* bounded copy with explicit termination */
+	memcpy(spath, value, len);
+	spath[len] = '\0';
+
+	varcache_set(&client->vars, "search_path", spath);
+	slog_noise(client, "cached search_path -> %s", spath);
+}
+
/* decide on packets of logged-in client */
static bool handle_client_work(PgSocket *client, PktHdr *pkt)
{
@@ -606,6 +693,8 @@ static bool handle_client_work(PgSocket *client, PktHdr *pkt)
disconnect_client(client, true, "PQexec disallowed");
return false;
}
+ client_cache_search_path(client, pkt);
+
case 'F': /* FunctionCall */
/* request immediate response from server */
@@ -625,6 +714,7 @@ static bool handle_client_work(PgSocket *client, PktHdr *pkt)
* to buffer packets until sync or flush is sent by client
*/
case 'P': /* Parse */
+ client_cache_search_path(client, pkt);
case 'E': /* Execute */
case 'C': /* Close */
case 'B': /* Bind */
diff --git a/src/varcache.c b/src/varcache.c
index 6321dc5..33db06f 100644
--- a/src/varcache.c
+++ b/src/varcache.c
@@ -35,6 +35,7 @@ static const struct var_lookup lookup [] = {
{"TimeZone", VTimeZone },
{"standard_conforming_strings", VStdStr },
{"application_name", VAppName },
+ {"search_path", VSearchPath },
{NULL},
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment