-
-
Save mattn/a42f0d5c0e135d4d03ee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go | |
index 1ce679d..f831e58 100644 | |
--- a/src/database/sql/sql.go | |
+++ b/src/database/sql/sql.go | |
@@ -14,6 +14,7 @@ package sql | |
import ( | |
"database/sql/driver" | |
+ "encoding/json" | |
"errors" | |
"fmt" | |
"io" | |
@@ -24,6 +25,7 @@ import ( | |
) | |
var drivers = make(map[string]driver.Driver) | |
+// jsonNull is the JSON literal null. It is shared package-level state: it must
+// never be modified, and returning it directly hands the same backing array to every caller.
+var jsonNull = []byte("null") | |
// Register makes a database driver available by the provided name. | |
// If Register is called twice with the same name or if driver is nil, | |
@@ -94,6 +96,14 @@ func (ns NullString) Value() (driver.Value, error) { | |
return ns.String, nil | |
} | |
+// MarshalJSON implements json.Marshaler, encoding a NULL (Valid == false) as null.
+func (ns NullString) MarshalJSON() ([]byte, error) {
+	if !ns.Valid {
+		return []byte("null"), nil // fresh slice per call: callers may modify the result
+	}
+	return json.Marshal(ns.String)
+}
+ | |
// NullInt64 represents an int64 that may be null. | |
// NullInt64 implements the Scanner interface so | |
// it can be used as a scan destination, similar to NullString. | |
@@ -120,6 +130,14 @@ func (n NullInt64) Value() (driver.Value, error) { | |
return n.Int64, nil | |
} | |
+// MarshalJSON implements json.Marshaler, encoding a NULL (Valid == false) as null.
+func (n NullInt64) MarshalJSON() ([]byte, error) {
+	if !n.Valid {
+		return []byte("null"), nil // fresh slice per call: callers may modify the result
+	}
+	return json.Marshal(n.Int64)
+}
+ | |
// NullFloat64 represents a float64 that may be null. | |
// NullFloat64 implements the Scanner interface so | |
// it can be used as a scan destination, similar to NullString. | |
@@ -146,6 +164,14 @@ func (n NullFloat64) Value() (driver.Value, error) { | |
return n.Float64, nil | |
} | |
+// MarshalJSON implements json.Marshaler, encoding a NULL (Valid == false) as null.
+func (n NullFloat64) MarshalJSON() ([]byte, error) {
+	if !n.Valid {
+		return []byte("null"), nil // fresh slice per call: callers may modify the result
+	}
+	return json.Marshal(n.Float64)
+}
+ | |
// NullBool represents a bool that may be null. | |
// NullBool implements the Scanner interface so | |
// it can be used as a scan destination, similar to NullString. | |
@@ -172,6 +198,14 @@ func (n NullBool) Value() (driver.Value, error) { | |
return n.Bool, nil | |
} | |
+// MarshalJSON implements json.Marshaler, encoding a NULL (Valid == false) as null.
+func (n NullBool) MarshalJSON() ([]byte, error) {
+	if !n.Valid {
+		return []byte("null"), nil // fresh slice per call: callers may modify the result
+	}
+	return json.Marshal(n.Bool)
+}
+ | |
// Scanner is an interface used by Scan. | |
type Scanner interface { | |
// Scan assigns a value from a database driver. | |
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go | |
index 60bdefa..61bf455 100644 | |
--- a/src/database/sql/sql_test.go | |
+++ b/src/database/sql/sql_test.go | |
@@ -6,6 +6,7 @@ package sql | |
import ( | |
"database/sql/driver" | |
+ "encoding/json" | |
"errors" | |
"fmt" | |
"math/rand" | |
@@ -840,6 +841,37 @@ func nullTestRun(t *testing.T, spec nullTestSpec) { | |
} | |
} | |
+// nullJsonTests pairs Null* values with their expected json.Marshal output.
+// A valid value must encode exactly like its underlying Go value (including
+// string escaping and full int64 precision); an invalid one must encode as null
+// regardless of the payload it carries.
+var nullJsonTests = []struct {
+	value interface{}
+	json  string
+}{
+	{NullString{"not null", true}, `"not null"`},
+	{NullString{"not null", false}, `null`},
+	{NullString{"", true}, `""`},
+	{NullString{"", false}, `null`},
+	{NullString{`say "hi" <b>`, true}, `"say \"hi\" \u003cb\u003e"`},
+	{NullInt64{123, true}, `123`},
+	{NullInt64{123, false}, `null`},
+	{NullInt64{0, true}, `0`},
+	{NullInt64{-9223372036854775808, true}, `-9223372036854775808`},
+	{NullFloat64{123.4, true}, `123.4`},
+	{NullFloat64{123.4, false}, `null`},
+	{NullFloat64{0, true}, `0`},
+	{NullFloat64{-0.5, true}, `-0.5`},
+	{NullBool{true, true}, `true`},
+	{NullBool{false, true}, `false`},
+	{NullBool{true, false}, `null`},
+	{NullBool{false, false}, `null`},
+}
+ | |
+// TestNullJSON marshals every entry of nullJsonTests and reports each
+// mismatch (rather than stopping at the first) together with the input value.
+func TestNullJSON(t *testing.T) {
+	for _, jt := range nullJsonTests {
+		b, err := json.Marshal(jt.value)
+		if err != nil {
+			t.Errorf("json.Marshal(%#v) failed: %v", jt.value, err)
+			continue
+		}
+		if got := string(b); got != jt.json {
+			t.Errorf("json.Marshal(%#v) = %s; want %s", jt.value, got, jt.json)
+		}
+	}
+}
+ | |
// golang.org/issue/4859 | |
func TestQueryRowNilScanDest(t *testing.T) { | |
db := newTestDB(t, "people") |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.