Skip to content

Instantly share code, notes, and snippets.

@ericlagergren
Last active April 12, 2016 17:42
Show Gist options
  • Save ericlagergren/a7f05b9ad12b627cf9994fdbca001769 to your computer and use it in GitHub Desktop.
Save ericlagergren/a7f05b9ad12b627cf9994fdbca001769 to your computer and use it in GitHub Desktop.
Parse PostgreSQL arrays
package main
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
)
// main feeds a set of sample PostgreSQL array literals to the Scanner
// implementations in this file and prints, for each one, the input, the
// number of parsed elements and the parsed value. A Scan failure is printed
// instead of being silently discarded.
func main() {
	// A slice rather than a map keeps the output order deterministic.
	samples := []struct {
		in  string
		dst sql.Scanner
	}{
		{`{google.com,https://foo.bar.com,www.example.com?abc=123#foo}`, &StringArray{}},
		{`{}`, &StringArray{}},
		{`{NULL,NULL,NULL}`, &NullStringArray{}},
		{`{"something coo\"l"}`, &StringArray{}},
		{`{foo\,bar,baz}`, &StringArray{}},
		{`{100,200,300,500,10000,-1}`, &IntArray{}},
		{`{NULL,NULL,NULL,NULL}`, &NullIntArray{}},
		{`{1234,0}`, &IntArray{}},
		{`{ }`, &IntArray{}},
	}

	for _, s := range samples {
		if err := s.dst.Scan([]byte(s.in)); err != nil {
			fmt.Printf("%s: scan error: %v\n", s.in, err)
			continue
		}

		var length int
		switch t := s.dst.(type) {
		case *IntArray:
			length = len(*t)
		case *NullIntArray:
			length = len(*t)
		case *StringArray:
			length = len(*t)
		case *NullStringArray:
			length = len(*t)
		}
		fmt.Printf("%s (%d): %v\n", s.in, length, s.dst)
	}
}
// ParseError reports a byte that was rejected by the predicate given to
// parse.
type ParseError struct {
	c byte // the rejected byte
}

// Error implements the error interface.
func (p *ParseError) Error() string {
	return fmt.Sprintf("%c is an invalid character", p.c)
}

// parse splits the textual form of a one-dimensional array into elements and
// hands each raw element to get.
//
// data must be at least two bytes long. If it starts with '{' and ends with
// '}' the braces are stripped, and "{}" yields no elements at all. Elements
// are separated by commas; a comma immediately preceded by a backslash is
// kept as part of the element (the backslash is not removed here). Double
// quotes and nested braces receive no special handling.
//
// Every byte before the final one that is not a comma must satisfy pred,
// otherwise a *ParseError carrying that byte is returned. The final byte
// always terminates the last element and is not checked against pred, so get
// is expected to validate the element it receives. The first non-nil error
// returned by get stops parsing and is returned unchanged.
func parse(data []byte, pred func(byte) bool, get func([]byte) error) error {
	if len(data) < 2 {
		return errors.New("len(data) < 2")
	}
	if data[0] == '{' && data[len(data)-1] == '}' {
		if len(data) == 2 {
			return nil
		}
		data = data[1 : len(data)-1]
	}

	var mark int // start of the element currently being read
	last := len(data) - 1
	for i, c := range data {
		switch {
		case i == last:
			// End of input: whatever remains is the final element. It is
			// emitted unconditionally; there is no separator here that a
			// trailing backslash could be escaping.
			return get(data[mark:])
		case c == ',':
			// The i > 0 guard keeps a leading comma from indexing data[-1].
			if i > 0 && data[i-1] == '\\' {
				continue // escaped comma, part of the element
			}
			if err := get(data[mark:i]); err != nil {
				return err
			}
			mark = i + 1
		case !pred(c):
			return &ParseError{c: c}
		}
	}
	return nil
}
// getData extracts the raw bytes from a value handed to a Scanner.
//
// A nil val yields (nil, true). A []byte is returned as is and a string is
// converted to bytes, both with ok == true. Any other type yields
// (nil, false).
func getData(val interface{}) ([]byte, bool) {
	switch v := val.(type) {
	case nil:
		return nil, true
	case []byte:
		return v, true
	case string:
		// Drivers may deliver text values as string rather than []byte.
		return []byte(v), true
	default:
		return nil, false
	}
}
// IntArray is a slice of int64 that can be filled from the textual form of
// an integer array through Scan.
type IntArray []int64

// Scan implements sql.Scanner. It obtains the raw bytes of val with getData,
// splits them with parse (restricting bytes with isnum) and converts every
// element with strconv.ParseInt (base 10, 64 bits).
//
// It returns an "invalid type" error when getData rejects val, and otherwise
// any error produced by parse or strconv.ParseInt. On success *i is replaced
// by the parsed elements; on failure *i is left untouched.
func (i *IntArray) Scan(val interface{}) error {
	data, ok := getData(val)
	if !ok {
		return errors.New("IntArray.Scan: invalid type")
	}

	// Collect into a fresh slice so that scanning repeatedly into the same
	// destination (for example one variable reused across rows) does not
	// accumulate elements from earlier scans.
	var out IntArray
	err := parse(data, isnum, func(elem []byte) error {
		v, err := strconv.ParseInt(string(elem), 10, 64)
		if err != nil {
			return err
		}
		out = append(out, v)
		return nil
	})
	if err != nil {
		return err
	}
	*i = out
	return nil
}
// isnum reports whether c is an ASCII decimal digit or a minus sign.
func isnum(c byte) bool {
	if c == '-' {
		return true
	}
	return '0' <= c && c <= '9'
}
// NullIntArray is a slice of sql.NullInt64 that can be filled from the
// textual form of an integer array whose elements may be NULL.
type NullIntArray []sql.NullInt64

// Scan implements sql.Scanner. It obtains the raw bytes of val with getData
// and splits them with parse (restricting bytes with isNullNum). An element
// that is exactly "NULL" becomes the zero sql.NullInt64 (Valid == false);
// every other element is converted with strconv.ParseInt (base 10, 64 bits).
//
// It returns an "invalid type" error when getData rejects val, and otherwise
// any error produced by parse or strconv.ParseInt. On success *i is replaced
// by the parsed elements; on failure *i is left untouched.
func (i *NullIntArray) Scan(val interface{}) error {
	data, ok := getData(val)
	if !ok {
		return errors.New("NullIntArray.Scan: invalid type")
	}

	// Collect into a fresh slice so repeated scans into the same destination
	// do not accumulate elements.
	var out NullIntArray
	err := parse(data, isNullNum, func(elem []byte) error {
		// A fresh value per element: a NULL entry must not inherit the
		// Int64 of the preceding non-NULL entry.
		var n sql.NullInt64
		// Comparing string(elem) against a constant does not allocate.
		if string(elem) != "NULL" {
			v, err := strconv.ParseInt(string(elem), 10, 64)
			if err != nil {
				return err
			}
			n.Int64, n.Valid = v, true
		}
		out = append(out, n)
		return nil
	})
	if err != nil {
		return err
	}
	*i = out
	return nil
}
// isNullNum reports whether c is an ASCII decimal digit, a minus sign, or
// one of the upper-case letters 'N', 'U' and 'L'.
func isNullNum(c byte) bool {
	switch c {
	case '-', 'N', 'U', 'L':
		return true
	}
	return '0' <= c && c <= '9'
}
// repl unescapes a raw array element: an escaped backslash (`\\`) becomes a
// single backslash, and every other backslash is removed.
var repl = strings.NewReplacer(`\\`, `\`, `\`, ``)
// StringArray is a slice of strings that can be filled from the textual form
// of a text array through Scan.
type StringArray []string

// Scan implements sql.Scanner. It obtains the raw bytes of val with getData,
// splits them with parse (accepting every byte) and stores each element
// after passing it through repl.
//
// It returns an "invalid type" error when getData rejects val, and otherwise
// any error produced by parse. On success *s is replaced by the parsed
// elements; on failure *s is left untouched.
func (s *StringArray) Scan(val interface{}) error {
	data, ok := getData(val)
	if !ok {
		return errors.New("StringArray.Scan: invalid type")
	}

	// Collect into a fresh slice so repeated scans into the same destination
	// do not accumulate elements.
	var out StringArray
	err := parse(data, func(byte) bool { return true }, func(elem []byte) error {
		out = append(out, repl.Replace(string(elem)))
		return nil
	})
	if err != nil {
		return err
	}
	*s = out
	return nil
}
// NullStringArray is a slice of sql.NullString that can be filled from the
// textual form of a text array whose elements may be NULL.
type NullStringArray []sql.NullString

// Scan implements sql.Scanner. It obtains the raw bytes of val with getData
// and splits them with parse (accepting every byte). An element that is
// exactly "NULL" becomes the zero sql.NullString (Valid == false); every
// other element is stored after passing it through repl.
//
// It returns an "invalid type" error when getData rejects val, and otherwise
// any error produced by parse. On success *n is replaced by the parsed
// elements; on failure *n is left untouched.
func (n *NullStringArray) Scan(val interface{}) error {
	data, ok := getData(val)
	if !ok {
		return errors.New("NullStringArray.Scan: invalid type")
	}

	// Collect into a fresh slice so repeated scans into the same destination
	// do not accumulate elements.
	var out NullStringArray
	err := parse(data, func(byte) bool { return true }, func(elem []byte) error {
		// A fresh value per element: a NULL entry must not inherit the
		// String of the preceding non-NULL entry.
		var s sql.NullString
		// Comparing string(elem) against a constant does not allocate.
		if string(elem) != "NULL" {
			s.String, s.Valid = repl.Replace(string(elem)), true
		}
		out = append(out, s)
		return nil
	})
	if err != nil {
		return err
	}
	*n = out
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment