Skip to content

Instantly share code, notes, and snippets.

@toravir
Last active December 3, 2018 18:59
Show Gist options
  • Save toravir/feb9c591e71f28f7fe773383150d635e to your computer and use it in GitHub Desktop.
Save toravir/feb9c591e71f28f7fe773383150d635e to your computer and use it in GitHub Desktop.
REST interface to CockroachDB (Go talks to CockroachDB using GORM)
package main
import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"regexp"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/mux"
	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/postgres"
)
// db is the process-wide GORM handle to CockroachDB. It is opened lazily by
// getDbHandle; code should obtain it through that function rather than reading
// this variable directly.
var db *gorm.DB

// dbOnce guards the one-time initialisation of db. net/http serves every
// request on its own goroutine and each handler calls getDbHandle, so a plain
// "if db == nil" check would let concurrent first requests race on db.
var dbOnce sync.Once

// getDbHandle returns the shared database handle, opening it on first use.
// If the connection cannot be opened the process exits via log.Fatal.
func getDbHandle() *gorm.DB {
	dbOnce.Do(func() {
		const addr = "postgresql://root@127.0.0.1:26257/ccdb?sslmode=disable"
		conn, err := gorm.Open("postgres", addr)
		if err != nil {
			log.Fatal(err)
		}
		db = conn
	})
	return db
}
// createTableIdentRE matches an unquoted SQL identifier, optionally qualified
// with dots (e.g. "users" or "ccdb.users"). Identifiers cannot be bound as
// query parameters, so table and column names are spliced into the statement
// text; anything outside this pattern is rejected to prevent SQL injection.
var createTableIdentRE = regexp.MustCompile(`^[\p{L}_][\p{L}\p{N}_]*(\.[\p{L}_][\p{L}\p{N}_]*)*$`)

// createTableTypeRE matches the characters accepted in a column definition,
// e.g. "INT", "STRING NOT NULL" or "DECIMAL(10,2)". Semicolons, quotes and
// comment markers are deliberately excluded so a definition cannot end the
// CREATE TABLE statement early or append another statement.
var createTableTypeRE = regexp.MustCompile(`^[\p{L}\p{N}_ ,()\[\]]+$`)

// buildCreateTableSQL validates the table name and column definitions and
// returns a "CREATE TABLE name(col type, ...);" statement. Columns are emitted
// in sorted order so the generated SQL is deterministic.
func buildCreateTableSQL(table string, fields map[string]interface{}) (string, error) {
	if !createTableIdentRE.MatchString(table) {
		return "", fmt.Errorf("invalid table name %q", table)
	}
	if len(fields) == 0 {
		return "", fmt.Errorf("no columns given for table %q", table)
	}

	names := make([]string, 0, len(fields))
	for name := range fields {
		names = append(names, name)
	}
	sort.Strings(names)

	defs := make([]string, 0, len(names))
	for _, name := range names {
		if !createTableIdentRE.MatchString(name) {
			return "", fmt.Errorf("invalid column name %q", name)
		}
		typ, ok := fields[name].(string)
		if !ok || !createTableTypeRE.MatchString(typ) {
			return "", fmt.Errorf("invalid definition %v for column %q", fields[name], name)
		}
		defs = append(defs, name+" "+typ)
	}
	return fmt.Sprintf("CREATE TABLE %s(%s);", table, strings.Join(defs, ", ")), nil
}

// CreateTable handles POST /db/{table}. The JSON body maps column names to
// column definitions, e.g. {"id": "INT PRIMARY KEY", "name": "STRING"}.
// It responds with the JSON string "Success" (200) once the statement has
// executed, or "Failed" with 400 for a malformed/invalid request or 500 if
// the database rejects the statement.
func CreateTable(w http.ResponseWriter, req *http.Request) {
	respond := func(status int, msg string) {
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(msg); err != nil {
			log.Printf("CreateTable: writing response: %v", err)
		}
	}

	tableName := mux.Vars(req)["table"]
	fields := make(map[string]interface{})
	if err := json.NewDecoder(req.Body).Decode(&fields); err != nil {
		log.Printf("CreateTable: decoding body for table %q: %v", tableName, err)
		respond(http.StatusBadRequest, "Failed")
		return
	}
	log.Printf("Table Name: %s", tableName)
	log.Printf("Fields: %v", fields)

	sqlStmt, err := buildCreateTableSQL(tableName, fields)
	if err != nil {
		log.Printf("CreateTable: %v", err)
		respond(http.StatusBadRequest, "Failed")
		return
	}

	log.Printf("Executing SQL Stmt: %s <EOS>", sqlStmt)
	if err := getDbHandle().Exec(sqlStmt).Error; err != nil {
		log.Printf("CreateTable: executing %q: %v", sqlStmt, err)
		respond(http.StatusInternalServerError, "Failed")
		return
	}
	respond(http.StatusOK, "Success")
}
// deleteTableIdentRE matches an unquoted SQL identifier, optionally qualified
// with dots. The table name comes from the URL and is spliced into the DROP
// statement (identifiers cannot be bound as parameters), so anything outside
// this pattern is rejected to prevent SQL injection.
var deleteTableIdentRE = regexp.MustCompile(`^[\p{L}_][\p{L}\p{N}_]*(\.[\p{L}_][\p{L}\p{N}_]*)*$`)

// DeleteTable handles DELETE /db/{table} by issuing DROP TABLE IF EXISTS.
// It responds with the JSON string "Success" (200), or "Failed" with 400 for
// an invalid table name or 500 if the database rejects the statement.
func DeleteTable(w http.ResponseWriter, req *http.Request) {
	respond := func(status int, msg string) {
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(msg); err != nil {
			log.Printf("DeleteTable: writing response: %v", err)
		}
	}

	tableName := mux.Vars(req)["table"]
	if !deleteTableIdentRE.MatchString(tableName) {
		log.Printf("DeleteTable: invalid table name %q", tableName)
		respond(http.StatusBadRequest, "Failed")
		return
	}

	sqlStmt := fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName)
	log.Printf("Executing SQL Stmt: %s <EOS>", sqlStmt)
	if err := getDbHandle().Exec(sqlStmt).Error; err != nil {
		log.Printf("DeleteTable: executing %q: %v", sqlStmt, err)
		respond(http.StatusInternalServerError, "Failed")
		return
	}
	respond(http.StatusOK, "Success")
}
// createRowIdentRE matches an unquoted SQL identifier, optionally qualified
// with dots. Table and column names are spliced into the INSERT text
// (identifiers cannot be bound as parameters), so anything outside this
// pattern is rejected to prevent SQL injection.
var createRowIdentRE = regexp.MustCompile(`^[\p{L}_][\p{L}\p{N}_]*(\.[\p{L}_][\p{L}\p{N}_]*)*$`)

// buildInsertRowSQL validates the table and column names and returns an
// "INSERT INTO name(cols) values (?, ...);" statement plus the values in
// matching column order. Columns are sorted so the SQL is deterministic.
// Values must be JSON scalars (null, boolean, number or string); they are
// returned as bind arguments for the ? placeholders instead of being
// formatted into the SQL text, so strings need no quoting by the client.
func buildInsertRowSQL(table string, values map[string]interface{}) (string, []interface{}, error) {
	if !createRowIdentRE.MatchString(table) {
		return "", nil, fmt.Errorf("invalid table name %q", table)
	}
	if len(values) == 0 {
		return "", nil, fmt.Errorf("no column values given for table %q", table)
	}

	cols := make([]string, 0, len(values))
	for col := range values {
		cols = append(cols, col)
	}
	sort.Strings(cols)

	args := make([]interface{}, 0, len(cols))
	marks := make([]string, 0, len(cols))
	for _, col := range cols {
		if !createRowIdentRE.MatchString(col) {
			return "", nil, fmt.Errorf("invalid column name %q", col)
		}
		switch v := values[col].(type) {
		case nil, bool, float64, string:
			args = append(args, v)
			marks = append(marks, "?")
		default:
			return "", nil, fmt.Errorf("value for column %q must be a JSON scalar, got %T", col, v)
		}
	}
	stmt := fmt.Sprintf("INSERT INTO %s(%s) values (%s);",
		table, strings.Join(cols, ", "), strings.Join(marks, ", "))
	return stmt, args, nil
}

// CreateRow handles POST /db/{table}/row. The JSON body maps column names to
// the values of a single row, e.g. {"id": 1, "name": "bob"}.
// It responds with the JSON string "Success" (200) once the row is inserted,
// or "Failed" with 400 for a malformed/invalid request or 500 if the
// database rejects the statement.
func CreateRow(w http.ResponseWriter, req *http.Request) {
	respond := func(status int, msg string) {
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(msg); err != nil {
			log.Printf("CreateRow: writing response: %v", err)
		}
	}

	tableName := mux.Vars(req)["table"]
	values := make(map[string]interface{})
	if err := json.NewDecoder(req.Body).Decode(&values); err != nil {
		log.Printf("CreateRow: decoding body for table %q: %v", tableName, err)
		respond(http.StatusBadRequest, "Failed")
		return
	}
	log.Printf("Table Name: %s", tableName)
	log.Printf("Fields: %v", values)

	sqlStmt, args, err := buildInsertRowSQL(tableName, values)
	if err != nil {
		log.Printf("CreateRow: %v", err)
		respond(http.StatusBadRequest, "Failed")
		return
	}

	log.Printf("Executing SQL Stmt: %s <EOS> args: %v", sqlStmt, args)
	if err := getDbHandle().Exec(sqlStmt, args...).Error; err != nil {
		log.Printf("CreateRow: executing %q: %v", sqlStmt, err)
		respond(http.StatusInternalServerError, "Failed")
		return
	}
	respond(http.StatusOK, "Success")
}
// createRowsIdentRE matches an unquoted SQL identifier, optionally qualified
// with dots. Table and column names are spliced into the INSERT text
// (identifiers cannot be bound as parameters), so anything outside this
// pattern is rejected to prevent SQL injection.
var createRowsIdentRE = regexp.MustCompile(`^[\p{L}_][\p{L}\p{N}_]*(\.[\p{L}_][\p{L}\p{N}_]*)*$`)

// buildInsertRowsSQL validates the table name and rows and returns one
// multi-row "INSERT INTO name(cols) values (?, ...),(?, ...);" statement plus
// its bind arguments in row-major order. The column set is taken from the
// first row (sorted for deterministic SQL); every row must contain exactly
// those columns and every value must be a JSON scalar.
func buildInsertRowsSQL(table string, rows []map[string]interface{}) (string, []interface{}, error) {
	if !createRowsIdentRE.MatchString(table) {
		return "", nil, fmt.Errorf("invalid table name %q", table)
	}
	if len(rows) == 0 || len(rows[0]) == 0 {
		return "", nil, fmt.Errorf("no rows given for table %q", table)
	}

	cols := make([]string, 0, len(rows[0]))
	for col := range rows[0] {
		cols = append(cols, col)
	}
	sort.Strings(cols)
	for _, col := range cols {
		if !createRowsIdentRE.MatchString(col) {
			return "", nil, fmt.Errorf("invalid column name %q", col)
		}
	}

	// Every row shares the same "(?, ?, ...)" placeholder tuple.
	marks := make([]string, len(cols))
	for i := range marks {
		marks[i] = "?"
	}
	tuple := "(" + strings.Join(marks, ", ") + ")"

	tuples := make([]string, 0, len(rows))
	args := make([]interface{}, 0, len(rows)*len(cols))
	for i, row := range rows {
		// Rejecting rows with extra keys avoids silently dropping data.
		if len(row) != len(cols) {
			return "", nil, fmt.Errorf("row %d has %d columns, want %d", i, len(row), len(cols))
		}
		for _, col := range cols {
			v, ok := row[col]
			if !ok {
				return "", nil, fmt.Errorf("row %d is missing column %q", i, col)
			}
			switch v.(type) {
			case nil, bool, float64, string:
				args = append(args, v)
			default:
				return "", nil, fmt.Errorf("row %d: value for column %q must be a JSON scalar, got %T", i, col, v)
			}
		}
		tuples = append(tuples, tuple)
	}

	stmt := fmt.Sprintf("INSERT INTO %s(%s) values %s;",
		table, strings.Join(cols, ", "), strings.Join(tuples, ","))
	return stmt, args, nil
}

// CreateMultipleRows inserts several rows into {table} with one statement.
// The JSON body must have the shape {"Rows": [{"col": value, ...}, ...]}.
// It responds with the JSON string "Success" (200) once the rows are
// inserted, or "Failed" with 400 for a malformed/invalid request (including a
// missing or empty "Rows" array) or 500 if the database rejects the statement.
func CreateMultipleRows(w http.ResponseWriter, req *http.Request) {
	respond := func(status int, msg string) {
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(msg); err != nil {
			log.Printf("CreateMultipleRows: writing response: %v", err)
		}
	}

	tableName := mux.Vars(req)["table"]
	// Decoding into a typed struct rejects non-object rows up front instead of
	// panicking on unchecked type assertions later.
	var payload struct {
		Rows []map[string]interface{} `json:"Rows"`
	}
	if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
		log.Printf("CreateMultipleRows: decoding body for table %q: %v", tableName, err)
		respond(http.StatusBadRequest, "Failed")
		return
	}
	log.Printf("Table Name: %s", tableName)
	log.Printf("Rows: %v", payload.Rows)

	sqlStmt, args, err := buildInsertRowsSQL(tableName, payload.Rows)
	if err != nil {
		log.Printf("CreateMultipleRows: %v", err)
		respond(http.StatusBadRequest, "Failed")
		return
	}

	log.Printf("Executing SQL Stmt: %s <EOS> args: %v", sqlStmt, args)
	if err := getDbHandle().Exec(sqlStmt, args...).Error; err != nil {
		log.Printf("CreateMultipleRows: executing %q: %v", sqlStmt, err)
		respond(http.StatusInternalServerError, "Failed")
		return
	}
	respond(http.StatusOK, "Success")
}
/*
Create a table in the DB:
  'POST' URL: 'http://<IP>:12345/db/[TableName]'
  Payload: '{"ColumnName1" : "Type1", "ColumnName2" : "Type2", ...}'
Delete a table from the DB:
  'DELETE' URL: 'http://<IP>:12345/db/[TableName]'
Add a single row to a table:
  'POST' URL: 'http://<IP>:12345/db/[TableName]/row'
  Payload: '{"ColumnName1" : Value1, "ColumnName2": Value2, ...}'
*/
// main registers the REST routes on a gorilla/mux router and serves them on
// port 12345 until the server fails, at which point the process exits via
// log.Fatal.
func main() {
	router := mux.NewRouter()
	router.HandleFunc("/db/{table}", CreateTable).Methods("POST")
	router.HandleFunc("/db/{table}", DeleteTable).Methods("DELETE")
	router.HandleFunc("/db/{table}/row", CreateRow).Methods("POST")
	// The multi-row route is currently disabled.
	//router.HandleFunc("/db/{table}/multirow", CreateMultipleRows).Methods("POST")

	// http.ListenAndServe uses a server with no timeouts, so a slow or idle
	// client could hold a connection open indefinitely. The limits below are
	// generous for this API's small JSON requests.
	srv := &http.Server{
		Addr:              ":12345",
		Handler:           router,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	log.Fatal(srv.ListenAndServe())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment