Skip to content

Instantly share code, notes, and snippets.

@raismaulana
Last active September 23, 2022 06:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save raismaulana/b63e06913205bb0ae912b86a0ea552eb to your computer and use it in GitHub Desktop.
Save raismaulana/b63e06913205bb0ae912b86a0ea552eb to your computer and use it in GitHub Desktop.
Avoid SQL Injection by Writing Safe Queries Using Prepared Statements and/or Character Whitelisting
func query(db *sql.DB, name []string) {
query := fmt.Sprint("select id, name, grade from students ")
where := "where "
params := interface{}{}
if len(names) > 0 {
q, p := prepareStringArray(names, len(params))
params = append(params, p...)
where += fmt.Sprintf(`AND t.id IN (%v) `, q)
}
query = query + where
rows, err := db.Query(query, params...)
if err != nil {
fmt.Println(err.Error())
return
}
}
// prepareStringArray turns a list of strings into a comma-separated run of
// positional placeholders plus the matching argument slice.
//
// Numbering starts right after lastSequence, i.e. the first placeholder is
// $(lastSequence+1).
//
// Example:
//
//	prepareStringArray([]string{"A", "B"}, 1)
//	// q == "$2,$3", p == []interface{}{"A", "B"}
//
// For an empty val it returns "" and a nil slice.
func prepareStringArray(val []string, lastSequence int) (q string, p []interface{}) {
	sep := ""
	for idx := range val {
		q += sep + fmt.Sprintf("$%d", lastSequence+idx+1)
		p = append(p, val[idx])
		// Every placeholder after the first is preceded by a comma.
		sep = ","
	}
	return q, p
}
func query(db *sql.DB, orderBy []string) {
query := fmt.Sprint("select id, name, grade from students ")
paging := ""
if len(opt.OrderBy) > 0 {
validColumn := map[string]string{
"id": "id",
"created_at": "created_at",
}
paging = sanitizeOrderBy(rderBy, validColumn, paging)
}
query = query + paging
_, err := db.Query(query)
if err != nil {
fmt.Println(err.Error())
return
}
}
// sanitizeOrderBy builds an ORDER BY clause from caller-supplied sort
// expressions, emitting only whitelisted column names and the keywords
// asc/desc, so that no part of the input reaches the SQL text verbatim.
//
// Each element of cols is "<column>" or "<column> <direction>", separated by
// whitespace. The column is looked up in valid (accepted name -> SQL column)
// and the element is skipped when it is absent. The direction is matched
// case-insensitively against asc/desc and dropped when it is anything else.
// Tokens after the second are ignored.
//
// It returns "order by c1 [dir], c2 [dir] " (with a trailing space so further
// clauses can be appended), or defaultString when no element of cols names a
// valid column.
func sanitizeOrderBy(cols []string, valid map[string]string, defaultString string) (res string) {
	directions := map[string]string{
		"asc":  "asc",
		"desc": "desc",
	}
	terms := make([]string, 0, len(cols))
	for _, col := range cols {
		// Fields tolerates leading, trailing and repeated whitespace and
		// yields no tokens for a blank entry.
		tokens := strings.Fields(col)
		if len(tokens) == 0 {
			continue
		}
		column, ok := valid[tokens[0]]
		if !ok {
			continue
		}
		term := column
		if len(tokens) > 1 {
			if dir, ok := directions[strings.ToLower(tokens[1])]; ok {
				term += " " + dir
			}
		}
		terms = append(terms, term)
	}
	if len(terms) == 0 {
		return defaultString
	}
	return "order by " + strings.Join(terms, ", ") + " "
}
// query selects id, name and grade from students, optionally restricted to
// the rows whose name equals the given value.
//
// The value is bound through a positional placeholder ($1), never
// concatenated into the SQL string. When name is empty no WHERE clause is
// emitted, so all rows are selected.
//
// Database errors are printed to stdout; the function returns nothing.
func query(db *sql.DB, name string) {
	query := "select id, name, grade from students "
	var params []interface{}
	if name != "" {
		params = append(params, name)
		// The placeholder index is the position of the value just appended.
		query += fmt.Sprintf("where name=$%d ", len(params))
	}
	rows, err := db.Query(query, params...)
	if err != nil {
		fmt.Println(err.Error())
		return
	}
	// Release the underlying connection back to the pool.
	defer rows.Close()
}
func query(db *sql.DB, name []string) {
query := fmt.Sprint("select id, name, grade from students ")
where := "where "
params := interface{}{}
if len(names) > 0 {
where += `f.id IN ('` + strings.Join(names, "','") + `') `
}
query = query + where
_, err := db.Query(query, params...)
if err != nil {
fmt.Println(err.Error())
return
}
}
func query(db *sql.DB, orderBy []string) {
query := fmt.Sprint("select id, name, grade from students ")
paging := ""
if len(opt.OrderBy) > 0 {
paging += `ORDER BY ` + strings.Join(orderBy, ",") + ` `
}
query = query + paging
_, err := db.Query(query)
if err != nil {
fmt.Println(err.Error())
return
}
}
// query selects id, name and grade from the students whose name equals the
// given value.
//
// The value is bound as the positional parameter $1 rather than formatted
// into the statement, so the driver handles quoting and the input cannot
// alter the SQL.
//
// Database errors are printed to stdout; the function returns nothing.
func query(db *sql.DB, name string) {
	rows, err := db.Query("select id, name, grade from students where name = $1", name)
	if err != nil {
		fmt.Println(err.Error())
		return
	}
	// Release the underlying connection back to the pool.
	defer rows.Close()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment