Skip to content

Instantly share code, notes, and snippets.

@mpenick
Last active May 24, 2022 13:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mpenick/8b95bd6326d375de46e4fb6981dad066 to your computer and use it in GitHub Desktop.
Save mpenick/8b95bd6326d375de46e4fb6981dad066 to your computer and use it in GitHub Desktop.
package sugar
import (
"errors"
"github.com/stargate/stargate-grpc-go-client/stargate/pkg/client"
pb "github.com/stargate/stargate-grpc-go-client/stargate/pkg/proto"
"google.golang.org/grpc"
)
// defaultCodecs maps each built-in type identifier to the codec that handles
// it. NewClient installs this map as-is on every client it creates.
var defaultCodecs = map[int]Codec{
	StringType: new(StringCodec),
	Int64Type:  new(IntCodec),
}
// Type identifiers used as keys into a client's codec table.
const (
	// Int64Type selects the codec used for integer values.
	Int64Type int = 0
	// StringType selects the codec used for string values.
	StringType int = 1
)
// Client wraps a StargateClient and converts Go values to and from protobuf
// values using a table of codecs keyed by type identifier.
type Client struct {
	wrapped *client.StargateClient
	codecs  map[int]Codec
}

// Response pairs a raw query response with the client that produced it, so
// scanners created from it can reuse the client's codecs.
// Field order matters: Response values are built with positional literals.
type Response struct {
	wrapped *pb.Response
	client  *Client
}

// Scanner iterates over the rows of a result set, decoding column values
// into caller-supplied destinations.
type Scanner struct {
	rowIdx int             // index of the next row to scan
	codecs map[int]Codec   // codecs used to decode column values
	rs     *pb.ResultSet   // result set being iterated
	err    error           // first error encountered while scanning
}

// Codec converts between a Go value and its protobuf representation.
type Codec interface {
	// Encode converts v to a protobuf value.
	Encode(v interface{}) *pb.Value
	// Decode converts a protobuf value back to a Go value.
	Decode(value *pb.Value) (interface{}, error)
}
// IntCodec encodes and decodes integers carried in the pb.Value_Int variant
// of a protobuf value.
type IntCodec struct{}

// Decode extracts the int64 held by value. It returns an error when value is
// nil or does not hold an integer.
func (i *IntCodec) Decode(value *pb.Value) (interface{}, error) {
	if value == nil {
		return nil, errors.New("not an integer")
	}
	v, ok := value.Inner.(*pb.Value_Int)
	if !ok {
		return nil, errors.New("not an integer")
	}
	return v.Int, nil
}

// Encode wraps an integer in a protobuf value. Signed integers of any width
// and unsigned integers up to 32 bits are widened to int64. Other types
// panic, because the Codec interface offers no way to report an error.
func (i *IntCodec) Encode(v interface{}) *pb.Value {
	var n int64
	switch x := v.(type) {
	case int:
		n = int64(x)
	case int8:
		n = int64(x)
	case int16:
		n = int64(x)
	case int32:
		n = int64(x)
	case int64:
		n = x
	case uint8:
		n = int64(x)
	case uint16:
		n = int64(x)
	case uint32:
		n = int64(x)
	default:
		// Unsupported type: this assertion panics with a runtime error that
		// names the offending type, matching the original behavior.
		n = v.(int64)
	}
	return &pb.Value{
		Inner: &pb.Value_Int{Int: n},
	}
}
// StringCodec encodes and decodes strings carried in the pb.Value_String_
// variant of a protobuf value.
type StringCodec struct{}

// Decode extracts the string held by value. It returns an error when value is
// nil or does not hold a string.
func (s *StringCodec) Decode(value *pb.Value) (interface{}, error) {
	if value == nil {
		return nil, errors.New("not a string")
	}
	v, ok := value.Inner.(*pb.Value_String_)
	if !ok {
		return nil, errors.New("not a string")
	}
	return v.String_, nil
}

// Encode wraps v, which must be a string, in a protobuf value. Non-string
// input panics, because the Codec interface offers no way to report an error.
func (s *StringCodec) Encode(v interface{}) *pb.Value {
	return &pb.Value{
		Inner: &pb.Value_String_{String_: v.(string)},
	}
}
// NewClient wraps conn in a Stargate client that uses the package's default
// codecs. Any error from creating the underlying client is returned unchanged.
func NewClient(conn grpc.ClientConnInterface) (*Client, error) {
	stargate, err := client.NewStargateClientWithConn(conn)
	if err != nil {
		return nil, err
	}
	c := &Client{codecs: defaultCodecs}
	c.wrapped = stargate
	return c, nil
}
func NewClientWithCodecs(conn grpc.ClientConnInterface, codecFunc) (*Client, error) {
cl, err := client.NewStargateClientWithConn(conn)
if err != nil {
return nil, err
}
return &Client{
wrapped: cl,
codecs: codecFunc(defaultCodecs),
}, nil
}
// ExecuteQuery runs cql with args bound, in order, as its query values, and
// returns the server response.
//
// Each arg is encoded with the client's codecs. An argument that cannot be
// encoded aborts the call before the query is sent. On any error the
// returned Response is the zero value.
func (c *Client) ExecuteQuery(cql string, args ...interface{}) (Response, error) {
	// args must be expanded with "...": passing the slice itself would make
	// buildValues see one []interface{} argument instead of the real values.
	values, err := c.buildValues(args...)
	if err != nil {
		return Response{}, err
	}
	res, err := c.wrapped.ExecuteQuery(&pb.Query{
		Cql:    cql,
		Values: values,
	})
	if err != nil {
		return Response{}, err
	}
	return Response{wrapped: res, client: c}, nil
}
// buildValues encodes every argument, preserving order, into a pb.Values
// message. It stops and returns the error of the first argument that cannot
// be converted.
func (c *Client) buildValues(args ...interface{}) (*pb.Values, error) {
	encoded := make([]*pb.Value, 0, len(args))
	for _, arg := range args {
		v, err := c.convertValue(arg)
		if err != nil {
			return nil, err
		}
		encoded = append(encoded, v)
	}
	return &pb.Values{Values: encoded}, nil
}
// convertValue encodes one query argument using the client's codec for its
// type.
//
// Integer arguments (signed of any width, unsigned up to 32 bits) are
// normalized to int64 before encoding, because the integer codec works with
// int64 values. It returns an error for unsupported argument types and when
// no codec is registered for the argument's type.
func (c *Client) convertValue(v interface{}) (*pb.Value, error) {
	var (
		codecType  int
		normalized interface{}
	)
	switch x := v.(type) {
	case int:
		codecType, normalized = Int64Type, int64(x)
	case int8:
		codecType, normalized = Int64Type, int64(x)
	case int16:
		codecType, normalized = Int64Type, int64(x)
	case int32:
		codecType, normalized = Int64Type, int64(x)
	case int64:
		codecType, normalized = Int64Type, x
	case uint8:
		codecType, normalized = Int64Type, int64(x)
	case uint16:
		codecType, normalized = Int64Type, int64(x)
	case uint32:
		codecType, normalized = Int64Type, int64(x)
	case string:
		codecType, normalized = StringType, x
	default:
		return nil, errors.New("unhandled value type")
	}
	// A custom codec table may omit a type; calling Encode on a nil
	// interface would panic.
	codec, ok := c.codecs[codecType]
	if !ok || codec == nil {
		return nil, errors.New("no codec registered for value type")
	}
	return codec.Encode(normalized), nil
}
// Scanner returns a Scanner over the response's result set, decoding values
// with the codecs of the client that executed the query.
//
// It is safe to call on the zero Response: the scanner then has no codecs,
// and its result set is whatever GetResultSet returns for a nil response.
func (r *Response) Scanner() *Scanner {
	s := &Scanner{rs: r.wrapped.GetResultSet()}
	if r.client != nil {
		s.codecs = r.client.codecs
	}
	return s
}
// Scan decodes the leading columns of the current row into values and
// advances to the next row.
//
// Each element of values must be a *int, *int64 or *string. Fewer values
// than columns is allowed; more is an error. Scan returns false when there
// are no more rows, when there is no result set, or when an error occurs;
// call Err to tell those cases apart. After an error, Scan keeps returning
// false.
func (s *Scanner) Scan(values ...interface{}) bool {
	if s.err != nil || s.rs == nil || s.rowIdx >= len(s.rs.Rows) {
		return false
	}
	// Exactly as many values as columns is valid; only more is an error.
	if len(values) > len(s.rs.Columns) {
		s.err = errors.New("too many values for the number of columns")
		return false
	}
	row := s.rs.Rows[s.rowIdx]
	if len(values) > len(row.Values) {
		s.err = errors.New("too many values for the number of row values")
		return false
	}
	for i, dest := range values {
		if err := s.scanValue(row.Values[i], dest); err != nil {
			s.err = err
			return false
		}
	}
	// Advance so the next call reads the following row; without this a
	// scan loop never terminates.
	s.rowIdx++
	return true
}

// scanValue decodes a single protobuf value into the destination pointer.
func (s *Scanner) scanValue(value *pb.Value, dest interface{}) error {
	switch d := dest.(type) {
	case *int:
		n, err := s.decodeInt64(value)
		if err != nil {
			return err
		}
		*d = int(n)
	case *int64:
		n, err := s.decodeInt64(value)
		if err != nil {
			return err
		}
		*d = n
	case *string:
		decoded, err := s.decode(StringType, value)
		if err != nil {
			return err
		}
		str, ok := decoded.(string)
		if !ok {
			return errors.New("string codec did not return a string")
		}
		*d = str
	default:
		return errors.New("unsupported scan destination type")
	}
	return nil
}

// decodeInt64 decodes value with the integer codec and checks the result type.
func (s *Scanner) decodeInt64(value *pb.Value) (int64, error) {
	decoded, err := s.decode(Int64Type, value)
	if err != nil {
		return 0, err
	}
	n, ok := decoded.(int64)
	if !ok {
		return 0, errors.New("integer codec did not return an int64")
	}
	return n, nil
}

// decode looks up the codec for codecType and decodes value with it.
func (s *Scanner) decode(codecType int, value *pb.Value) (interface{}, error) {
	codec, ok := s.codecs[codecType]
	if !ok || codec == nil {
		return nil, errors.New("no codec registered for column type")
	}
	return codec.Decode(value)
}
// Err returns the error recorded by Scan, or nil if none was recorded.
func (s *Scanner) Err() error {
	recorded := s.err
	return recorded
}
package sugar
import (
"fmt"
"testing"
)
// Test is a smoke test: it runs a single-parameter query and prints every
// returned value, failing on any client, query or scan error.
//
// NOTE(review): NewClient is given a nil connection, so this presumably needs
// a real grpc.ClientConnInterface pointing at a Stargate instance (and an
// existing table) before it can pass — confirm.
func Test(t *testing.T) {
	client, err := NewClient(nil)
	if err != nil {
		t.Fatalf("creating client: %v", err)
	}
	res, err := client.ExecuteQuery("SELECT v FROM table WHERE k = ?", "key")
	if err != nil {
		t.Fatalf("executing query: %v", err)
	}
	s := res.Scanner()
	var v string
	for s.Scan(&v) {
		fmt.Println(v)
	}
	// Scan returns false both at the end of the rows and on error.
	if err := s.Err(); err != nil {
		t.Fatalf("scanning rows: %v", err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment