Skip to content

Instantly share code, notes, and snippets.

@tpiperatgod
Last active May 3, 2021 02:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tpiperatgod/23733668d2f17a0dcd9e8e98c6606e93 to your computer and use it in GitHub Desktop.
Save tpiperatgod/23733668d2f17a0dcd9e8e98c6606e93 to your computer and use it in GitHub Desktop.
Change '{}' to '[]' to handle JSONB in lib/pg/array.go
package slices
import (
"bytes"
"database/sql/driver"
"encoding/csv"
"encoding/json"
"fmt"
"io"
"strings"
)
// String is a []string that can be written to and read from a database
// column (see Scan and Value) and decoded from JSON or comma-separated text.
type String []string

// Interface implements the nulls.nullable interface by handing back the
// underlying []string; a nil String yields a nil []string.
func (s String) Interface() interface{} {
	plain := []string(s)
	return plain
}
// Scan implements the sql.Scanner interface.
// A []byte or string source is handed to scanBytes for parsing; a nil source
// resets the slice to nil. Any other source type is rejected with an error.
func (s *String) Scan(src interface{}) error {
	if src == nil {
		*s = nil
		return nil
	}
	if raw, ok := src.([]byte); ok {
		return s.scanBytes(raw)
	}
	if text, ok := src.(string); ok {
		return s.scanBytes([]byte(text))
	}
	return fmt.Errorf("pq: cannot convert %T to StringArray", src)
}
// scanBytes parses src as a one-dimensional, comma-delimited array (via
// scanLinearArray) and stores its elements in *s. A nil (NULL) element is an
// error. On any error *s is left untouched.
func (s *String) scanBytes(src []byte) error {
	elems, err := scanLinearArray(src, []byte{','}, "StringArray")
	if err != nil {
		return err
	}
	// An empty array reuses the existing backing storage instead of allocating.
	if len(elems) == 0 && *s != nil {
		*s = (*s)[:0]
		return nil
	}
	out := make(String, 0, len(elems))
	for idx, elem := range elems {
		if elem == nil {
			return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", idx)
		}
		out = append(out, string(elem))
	}
	*s = out
	return nil
}
// Value implements the driver.Valuer interface.
// It renders the slice as a JSON array of strings, e.g. `["a","b"]`, so it
// can be stored in a JSON/JSONB column. A nil slice becomes SQL NULL and an
// empty (non-nil) slice becomes `[]`.
//
// encoding/json is used so that every element is a valid JSON string: control
// characters such as '\n' or '\t' are escaped (\n, \t, \u0001, ...), which
// JSON requires. Invalid UTF-8 is replaced with U+FFFD by encoding/json.
func (s String) Value() (driver.Value, error) {
	if s == nil {
		return nil, nil
	}
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	// Keep '<', '>' and '&' literal: they are valid unescaped in JSON.
	enc.SetEscapeHTML(false)
	if err := enc.Encode([]string(s)); err != nil {
		return nil, err
	}
	// Encode terminates its output with a newline that is not part of the value.
	return strings.TrimSuffix(buf.String(), "\n"), nil
}
// UnmarshalJSON implements json.Unmarshaler by decoding a JSON array of
// strings into the slice. JSON null yields a nil slice. If decoding fails the
// slice is left unchanged and the error is returned.
func (s *String) UnmarshalJSON(data []byte) error {
	decoded := []string(nil)
	err := json.Unmarshal(data, &decoded)
	if err == nil {
		*s = decoded
	}
	return err
}
// UnmarshalText implements encoding.TextUnmarshaler by reading text as CSV
// and flattening every field of every record (line) into the slice, in order.
// Quoted fields may therefore contain commas. Empty text yields a nil slice.
// A CSV syntax error is returned unchanged.
func (s *String) UnmarshalText(text []byte) error {
	r := csv.NewReader(bytes.NewReader(text))
	// Records are flattened, so lines need not share a field count; the csv
	// default (0) would reject e.g. "a,b\nc" with "wrong number of fields".
	r.FieldsPerRecord = -1
	var words []string
	for {
		record, err := r.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		words = append(words, record...)
	}
	*s = words
	return nil
}
// TagValue implements the tagValuer interface used by
// https://github.com/gobuffalo/tags; the elements are joined with commas.
func (s String) TagValue() string {
	const tagSeparator = ","
	return s.Format(tagSeparator)
}

// Format joins the elements of the slice into one string, placing sep between
// consecutive elements. A nil or empty slice yields "".
func (s String) Format(sep string) string {
	var b strings.Builder
	for i, word := range s {
		if i > 0 {
			b.WriteString(sep)
		}
		b.WriteString(word)
	}
	return b.String()
}
// parseArray extracts the dimensions and elements of an array written in JSON
// text form, e.g. `["a", "b"]` or `[[1, 2], [3, 4]]`. It is lib/pq's parser
// for PostgreSQL's '{...}' array syntax with the brackets changed to '[' and
// ']' so that it can read arrays stored in JSON/JSONB columns.
//
// Input rules:
//   - JSON whitespace (space, tab, LF, CR) is skipped before the opening
//     bracket and around brackets and delimiters (PostgreSQL prints jsonb
//     arrays with a space after each comma).
//   - A quoted element is decoded as a JSON string, so escapes such as \n,
//     \t, \/ and \uXXXX are honoured. An empty string yields a non-nil,
//     zero-length element.
//   - An unquoted element runs up to the next delimiter, ']' or whitespace and
//     is returned verbatim, except that NULL and the JSON literal null yield a
//     nil element.
//
// dims holds the length of every dimension (nil for an empty array) and elems
// holds all elements in order. An error is returned for malformed input and
// for multidimensional arrays whose sub-arrays differ in length.
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
	var depth, i int
	for i < len(src) && isJSONSpace(src[i]) {
		i++
	}
	if i >= len(src) || src[i] != '[' {
		return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '[', i)
	}

	// Consume the leading run of '[' to learn how many dimensions there are.
Open:
	for i < len(src) {
		switch c := src[i]; {
		case c == '[':
			depth++
			i++
		case c == ']':
			// An empty array such as "[]" or "[[]]".
			elems = make([][]byte, 0)
			goto Close
		case isJSONSpace(c):
			i++
		default:
			break Open
		}
	}
	dims = make([]int, depth)

	// Read one element, first descending into nested '[' when required.
Element:
	for i < len(src) {
		switch c := src[i]; {
		case isJSONSpace(c):
			i++
		case c == '[':
			if depth == len(dims) {
				break Element
			}
			depth++
			dims[depth-1] = 0
			i++
		case c == '"':
			start := i
			escaped := false
			for i++; i < len(src); i++ {
				switch {
				case escaped:
					escaped = false
				case src[i] == '\\':
					escaped = true
				case src[i] == '"':
					i++
					var s string
					if jerr := json.Unmarshal(src[start:i], &s); jerr != nil {
						return nil, nil, fmt.Errorf("pq: unable to parse array; invalid string at offset %d: %w", start, jerr)
					}
					// Copy into a non-nil slice so "" stays distinct from a nil (null) element.
					elems = append(elems, append([]byte{}, s...))
					break Element
				}
			}
		default:
			for start := i; i < len(src); i++ {
				if bytes.HasPrefix(src[i:], del) || src[i] == ']' || isJSONSpace(src[i]) {
					elem := src[start:i]
					if len(elem) == 0 {
						return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
					}
					if bytes.Equal(elem, []byte("NULL")) || bytes.Equal(elem, []byte("null")) {
						elem = nil
					}
					elems = append(elems, elem)
					break Element
				}
			}
		}
	}

	// After an element: a delimiter starts the next one, ']' closes a dimension.
	for i < len(src) {
		switch {
		case bytes.HasPrefix(src[i:], del) && depth > 0:
			dims[depth-1]++
			i += len(del)
			goto Element
		case src[i] == ']' && depth > 0:
			dims[depth-1]++
			depth--
			i++
		case isJSONSpace(src[i]):
			i++
		default:
			return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
		}
	}

Close:
	for i < len(src) {
		switch {
		case src[i] == ']' && depth > 0:
			depth--
			i++
		case isJSONSpace(src[i]):
			i++
		default:
			return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
		}
	}
	if depth > 0 {
		err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", ']', i)
	}
	if err == nil {
		for _, d := range dims {
			if (len(elems) % d) != 0 {
				err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions")
			}
		}
	}
	return
}

// isJSONSpace reports whether c is insignificant whitespace in JSON text.
func isJSONSpace(c byte) bool {
	return c == ' ' || c == '\t' || c == '\n' || c == '\r'
}
// appendArrayQuotedBytes appends v to b wrapped in double quotes, putting a
// backslash in front of every '"' and '\' found in v, and returns the
// extended slice.
func appendArrayQuotedBytes(b, v []byte) []byte {
	out := append(b, '"')
	for {
		cut := bytes.IndexAny(v, `"\`)
		if cut < 0 {
			// Nothing left to escape: copy the remainder and close the quote.
			return append(append(out, v...), '"')
		}
		out = append(out, v[:cut]...)
		out = append(out, '\\', v[cut])
		v = v[cut+1:]
	}
}
// scanLinearArray parses src with parseArray and returns its elements,
// rejecting arrays that have more than one dimension. typ names the target
// type and is only used in the error message.
func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
	var dims []int
	if dims, elems, err = parseArray(src, del); err != nil {
		return nil, err
	}
	if len(dims) <= 1 {
		return elems, nil
	}
	// Render dims like [2 3] as [2][3] for the error message.
	shape := strings.ReplaceAll(fmt.Sprint(dims), " ", "][")
	return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", shape, typ)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment